This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
     new 6252fa5802 [TIR] Enhance CLZ intrinsic support (#16952)
6252fa5802 is described below

commit 6252fa5802c94df522306519da94b874b3a45eda
Author: Siyuan Feng <hzfen...@sjtu.edu.cn>
AuthorDate: Tue Apr 30 14:14:44 2024 +0800

    [TIR] Enhance CLZ intrinsic support (#16952)
---
 .github/workflows/main.yml                     |  2 +
 src/target/intrin_rule.h                       | 18 +++++--
 src/target/source/intrin_rule_cuda.cc          | 12 +++++
 src/target/source/intrin_rule_metal.cc         |  3 ++
 src/target/source/intrin_rule_opencl.cc        |  3 ++
 src/tir/ir/data_type_rewriter.cc               |  6 ++-
 .../codegen/test_target_codegen_gpu_common.py  | 55 ++++++++++++++++++++++
 7 files changed, 94 insertions(+), 5 deletions(-)

diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml
index d63af560d7..759acd1fa5 100644
--- a/.github/workflows/main.yml
+++ b/.github/workflows/main.yml
@@ -77,6 +77,8 @@ jobs:
       - name: Minimal Metal Compile-and-Run
         shell: bash -l {0}
         run: >-
+          python -m pytest -v -s 'tests/python/codegen/test_target_codegen_metal.py'
+          python -m pytest -v -s 'tests/python/codegen/test_target_codegen_gpu_common.py'
           python -m pytest -v -s 'tests/python/codegen/test_gpu_codegen_allreduce.py::test_allreduce_sum[dims0-metal]'
       # - name: Test iOS RPC
       #   shell: bash -l {0}
diff --git a/src/target/intrin_rule.h b/src/target/intrin_rule.h
index 2695c43173..ea8ccd98b1 100644
--- a/src/target/intrin_rule.h
+++ b/src/target/intrin_rule.h
@@ -53,8 +53,13 @@ struct Direct {
   std::string operator()(DataType t, std::string name) const { return name; }
 };
 
-// Call pure extern function.
-template <typename T>
+/*!
+ * \brief Dispatch pure extern function.
+ * \param e The call expression.
+ * \tparam T The function to dispatch.
+ * \tparam dtype_from_arg Whether the dtype is from the first argument or the call node
+ */
+template <typename T, bool dtype_from_arg = false>
 inline PrimExpr DispatchPureExtern(const PrimExpr& e) {
   const CallNode* call = e.as<CallNode>();
   ICHECK(call != nullptr);
@@ -64,7 +69,14 @@ inline PrimExpr DispatchPureExtern(const PrimExpr& e) {
   ICHECK(op != nullptr);
   std::string name = op->name;
   ICHECK_EQ(name.substr(0, 4), "tir.");
-  name = T()(call->dtype, name.substr(4));
+  DataType dtype;
+  if (dtype_from_arg) {
+    ICHECK_EQ(call->args.size(), 1U);
+    dtype = call->args[0].dtype();
+  } else {
+    dtype = call->dtype;
+  }
+  name = T()(dtype, name.substr(4));
 
   if (name.length() != 0) {
     Array<PrimExpr> new_args = {StringImm(name)};
diff --git a/src/target/source/intrin_rule_cuda.cc b/src/target/source/intrin_rule_cuda.cc
index 95fbf7f1a5..79ea7a458f 100644
--- a/src/target/source/intrin_rule_cuda.cc
+++ b/src/target/source/intrin_rule_cuda.cc
@@ -54,6 +54,15 @@ struct CUDAMath {
       }
     } else if (t.is_bfloat16()) {
       return 'h' + name;
+    } else if (t.is_int() || t.is_uint()) {
+      switch (t.bits()) {
+        case 32:
+          return "__" + name;
+        case 64:
+          return "__" + name + "ll";
+        default:
+          return "";
+      }
     }
     return "";
   }
@@ -133,6 +142,9 @@ static PrimExpr DispatchCUDAShuffle(const PrimExpr& e) {
   return Call(call->dtype, T()(call->dtype, Downcast<Op>(call->op)), cuda_args);
 }
 
+TVM_REGISTER_OP("tir.clz").set_attr<FLowerIntrinsic>(
+    "cuda.FLowerIntrinsic", DispatchPureExtern<CUDAMath, /*dtype_from_arg=*/true>);
+
 TVM_REGISTER_OP("tir.floor")
     .set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic", DispatchPureExtern<CUDAMath>);
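Note on the CUDA rule above: tir.clz produces its count as an int32 (the call's own dtype)
regardless of how wide the input is, so call->dtype cannot tell the 32-bit and 64-bit
lowerings apart; only the argument dtype can, which is what /*dtype_from_arg=*/true selects.
A small illustrative Python sketch (printed values are the expected ones, not taken from a run):

    # tir.clz keeps an int32 result dtype, so the CUDA name mangling
    # (__clz vs __clzll) has to be chosen from the argument dtype.
    import tvm
    from tvm import tir

    call = tir.clz(tir.const(1, "int64"))
    print(call.dtype)          # expected: int32, independent of the input width
    print(call.args[0].dtype)  # expected: int64, which maps "clz" to "__clzll"
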
diff --git a/src/target/source/intrin_rule_metal.cc b/src/target/source/intrin_rule_metal.cc
index 50685f6ef2..b7561e8671 100644
--- a/src/target/source/intrin_rule_metal.cc
+++ b/src/target/source/intrin_rule_metal.cc
@@ -52,6 +52,9 @@ static PrimExpr DispatchMetalShuffle(const PrimExpr& e) {
   return Call(call->dtype, T()(call->dtype, Downcast<Op>(call->op)), metal_args);
 }
 
+TVM_REGISTER_OP("tir.clz").set_attr<FLowerIntrinsic>("metal.FLowerIntrinsic",
+                                                     DispatchPureExtern<Direct>);
+
 TVM_REGISTER_OP("tir.floor")
     .set_attr<FLowerIntrinsic>("metal.FLowerIntrinsic", DispatchPureExtern<Direct>);
diff --git a/src/target/source/intrin_rule_opencl.cc b/src/target/source/intrin_rule_opencl.cc
index 94ab9d8b9d..bd9e148b18 100644
--- a/src/target/source/intrin_rule_opencl.cc
+++ b/src/target/source/intrin_rule_opencl.cc
@@ -31,6 +31,9 @@ namespace codegen {
 namespace intrin {
 using tir::FLowerIntrinsic;
 
+TVM_REGISTER_OP("tir.clz").set_attr<FLowerIntrinsic>("opencl.FLowerIntrinsic",
+                                                      DispatchPureExtern<Direct>);
+
 TVM_REGISTER_OP("tir.floor")
     .set_attr<FLowerIntrinsic>("opencl.FLowerIntrinsic", DispatchPureExtern<Direct>);
diff --git a/src/tir/ir/data_type_rewriter.cc b/src/tir/ir/data_type_rewriter.cc
index a613b8d4bb..c03e19137e 100644
--- a/src/tir/ir/data_type_rewriter.cc
+++ b/src/tir/ir/data_type_rewriter.cc
@@ -238,10 +238,12 @@ PrimExpr DataTypeLegalizer::VisitExpr_(const CallNode* op) {
   } else if (op->op.same_as(Op::Get("tir.clz"))) {
     DataType before_dtype = before->args[0]->dtype;
     DataType after_dtype = op->args[0]->dtype;
-    CHECK(before_dtype.is_int() && (before_dtype.bits() == 32 || before_dtype.bits() == 64))
+    CHECK((before_dtype.is_int() || before_dtype.is_uint()) &&
+          (before_dtype.bits() == 32 || before_dtype.bits() == 64))
         << "clz only supports 32 or 64 bit integer types, but get type before legalizing: "
         << before_dtype;
-    CHECK(after_dtype.is_int() && (after_dtype.bits() == 32 || after_dtype.bits() == 64))
+    CHECK((after_dtype.is_int() || after_dtype.is_uint()) &&
+          (after_dtype.bits() == 32 || after_dtype.bits() == 64))
         << "clz only supports 32 or 64 bit integer types, but get type after legalizing: "
         << after_dtype;
     return e - after_dtype.bits() + before_dtype.bits();
diff --git a/tests/python/codegen/test_target_codegen_gpu_common.py b/tests/python/codegen/test_target_codegen_gpu_common.py
new file mode 100644
index 0000000000..2941f366a4
--- /dev/null
+++ b/tests/python/codegen/test_target_codegen_gpu_common.py
@@ -0,0 +1,55 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from functools import partial
+
+import numpy as np
+import pytest
+
+import tvm
+import tvm.testing
+from tvm import te
+
+
+@tvm.testing.requires_gpu
+@tvm.testing.parametrize_targets("cuda", "metal", "vulkan -supports_int64=1", "opencl")
+@pytest.mark.parametrize("dtype", ["int32", "uint32", "int64", "uint64"])
+def test_int_intrin(target, dev, dtype):
+    test_funcs = [
+        (tvm.tir.clz, lambda x, dtype: int(dtype[-2:]) - (len(bin(x)) - 2)),
+    ]
+
+    def run_test(tvm_intrin, np_func, dtype):
+        n = 128
+        A = te.placeholder((n,), name="A", dtype=dtype)
+        B = te.compute(A.shape, lambda *i: tvm_intrin(A(*i)), name="B")
+        func = te.create_prim_func([A, B])
+        sch = tvm.tir.Schedule(func)
+        (x,) = sch.get_loops(sch.get_block("B"))
+        sch.bind(x, "threadIdx.x")
+        f = tvm.build(sch.mod, target=target)
+        a = tvm.nd.array(np.random.randint(0, 100000, size=n).astype(A.dtype), dev)
+        b = tvm.nd.array(np.zeros(shape=(n,)).astype(B.dtype), dev)
+        f(a, b)
+        ref = np.vectorize(partial(np_func, dtype=dtype))(a.numpy())
+        tvm.testing.assert_allclose(b.numpy(), ref)
+
+    for func in test_funcs:
+        run_test(*func, dtype)
+
+
+if __name__ == "__main__":
+    tvm.testing.main()
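
A side note on the DataTypeLegalizer change above: the rewrite
"e - after_dtype.bits() + before_dtype.bits()" works because widening a non-negative
integer only prepends zero bits, so the wider clz overcounts by exactly the number of
extra bits. A self-contained sketch of that identity in plain Python (the helper name
clz is made up for illustration; no TVM involved):

    # Widening x from 32 to 64 bits prepends 32 zero bits for non-negative x,
    # hence clz64(x) == clz32(x) + 32, i.e. clz_before == clz_after - 64 + 32.
    def clz(x: int, bits: int) -> int:
        """Leading-zero count of x viewed as a `bits`-wide non-negative integer."""
        return bits - x.bit_length()

    for x in (1, 100000, 2**31 - 1):
        assert clz(x, 32) == clz(x, 64) - 64 + 32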