This is an automated email from the ASF dual-hosted git repository. tqchen pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push: new 64911ab5da [Runtime] Implemented Datatype.itemsize() (#16880) 64911ab5da is described below commit 64911ab5da3640be4d9fb675513e57b742e188b1 Author: Wuwei Lin <wu...@apache.org> AuthorDate: Sat Apr 13 18:33:12 2024 -0700 [Runtime] Implemented Datatype.itemsize() (#16880) * [Runtime] Implemented Datatype.itemsize() --- python/tvm/_ffi/runtime_ctypes.py | 14 ++++++++++++ python/tvm/dlight/gpu/gemv.py | 2 +- python/tvm/dlight/gpu/low_batch_gemv.py | 8 +++---- tests/python/ir/test_dtype.py | 40 +++++++++++++++++++++++++++++++++ 4 files changed, 58 insertions(+), 6 deletions(-) diff --git a/python/tvm/_ffi/runtime_ctypes.py b/python/tvm/_ffi/runtime_ctypes.py index dc5582d045..099cbe972a 100644 --- a/python/tvm/_ffi/runtime_ctypes.py +++ b/python/tvm/_ffi/runtime_ctypes.py @@ -212,6 +212,20 @@ class DataType(ctypes.Structure): def __ne__(self, other): return not self.__eq__(other) + def itemsize(self): + """Get the number of bytes of a single element of this data type. When the number of lanes + is greater than 1, the itemsize is the size of the vector type. + + Returns + ------- + itemsize : int + The number of bytes of a single element of this data type + """ + lanes_as_int = ctypes.c_int16(self.lanes).value + if lanes_as_int < 0: + raise ValueError("Cannot determine itemsize for scalable vector types") + return (self.bits * self.lanes + 7) // 8 + if ml_dtypes is not None: DataType.NUMPY2STR[np.dtype(ml_dtypes.bfloat16)] = "bfloat16" diff --git a/python/tvm/dlight/gpu/gemv.py b/python/tvm/dlight/gpu/gemv.py index c1ce876620..644f4e6dfa 100644 --- a/python/tvm/dlight/gpu/gemv.py +++ b/python/tvm/dlight/gpu/gemv.py @@ -57,7 +57,7 @@ def get_extent(sch: tir.Schedule, loop_rv: tir.schedule.LoopRV): def get_bytes(dtype: Union[DataType, str]) -> int: if isinstance(dtype, str): dtype = DataType(dtype) - return dtype.bits * dtype.lanes // 8 + return dtype.itemsize() def is_gemv(sch: tir.Schedule, block_info: BlockInfo) -> Optional[List[tir.Buffer]]: diff --git a/python/tvm/dlight/gpu/low_batch_gemv.py b/python/tvm/dlight/gpu/low_batch_gemv.py index 9a92c9e0e9..696722c3f0 100644 --- a/python/tvm/dlight/gpu/low_batch_gemv.py +++ b/python/tvm/dlight/gpu/low_batch_gemv.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. """A rule for low-batch GEMM / decode-GEMM using GEMV schedule.""" -import re from functools import reduce from typing import List, Optional, Set, Union @@ -55,10 +54,9 @@ def get_extent(sch: tir.Schedule, loop_rv: tir.schedule.LoopRV): def get_bytes(dtype: Union[DataType, str]) -> int: - num = re.findall(r"\d+", dtype) - if len(num) != 1: - raise ValueError(f"Cannot get bytes from {dtype}") - return int(num[0]) // 8 + if isinstance(dtype, str): + dtype = DataType(dtype) + return dtype.itemsize() def is_gemv(sch: tir.Schedule, block_info: BlockInfo) -> Optional[List[tir.Buffer]]: diff --git a/tests/python/ir/test_dtype.py b/tests/python/ir/test_dtype.py new file mode 100644 index 0000000000..77cd1d7e4b --- /dev/null +++ b/tests/python/ir/test_dtype.py @@ -0,0 +1,40 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Test data type related API""" +import tvm +from tvm import DataType +import tvm.testing +import pytest + + +@pytest.mark.parametrize( + "dtype_str, expected_size", + [("float32", 4), ("float32x4", 16), ("e5m2_float8x4", 4), ("uint8", 1)], +) +def test_dtype_itemsize(dtype_str, expected_size): + dtype = DataType(dtype_str) + assert dtype.itemsize() == expected_size + + +@pytest.mark.parametrize("dtype_str", [("int32xvscalex4")]) +def test_dtype_itemmize_error(dtype_str): + with pytest.raises(ValueError): + size = DataType(dtype_str).itemsize() + + +if __name__ == "__main__": + tvm.testing.main()