This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 64911ab5da [Runtime] Implemented Datatype.itemsize() (#16880)
64911ab5da is described below

commit 64911ab5da3640be4d9fb675513e57b742e188b1
Author: Wuwei Lin <wu...@apache.org>
AuthorDate: Sat Apr 13 18:33:12 2024 -0700

    [Runtime] Implemented Datatype.itemsize() (#16880)
    
    * [Runtime] Implemented Datatype.itemsize()
---
 python/tvm/_ffi/runtime_ctypes.py       | 14 ++++++++++++
 python/tvm/dlight/gpu/gemv.py           |  2 +-
 python/tvm/dlight/gpu/low_batch_gemv.py |  8 +++----
 tests/python/ir/test_dtype.py           | 40 +++++++++++++++++++++++++++++++++
 4 files changed, 58 insertions(+), 6 deletions(-)

diff --git a/python/tvm/_ffi/runtime_ctypes.py b/python/tvm/_ffi/runtime_ctypes.py
index dc5582d045..099cbe972a 100644
--- a/python/tvm/_ffi/runtime_ctypes.py
+++ b/python/tvm/_ffi/runtime_ctypes.py
@@ -212,6 +212,20 @@ class DataType(ctypes.Structure):
     def __ne__(self, other):
         return not self.__eq__(other)
 
+    def itemsize(self):
+        """Get the number of bytes of a single element of this data type. When the number of lanes
+        is greater than 1, the itemsize is the size of the vector type.
+
+        Returns
+        -------
+        itemsize : int
+            The number of bytes of a single element of this data type
+        """
+        lanes_as_int = ctypes.c_int16(self.lanes).value
+        if lanes_as_int < 0:
+            raise ValueError("Cannot determine itemsize for scalable vector types")
+        return (self.bits * self.lanes + 7) // 8
+
 
 if ml_dtypes is not None:
     DataType.NUMPY2STR[np.dtype(ml_dtypes.bfloat16)] = "bfloat16"
diff --git a/python/tvm/dlight/gpu/gemv.py b/python/tvm/dlight/gpu/gemv.py
index c1ce876620..644f4e6dfa 100644
--- a/python/tvm/dlight/gpu/gemv.py
+++ b/python/tvm/dlight/gpu/gemv.py
@@ -57,7 +57,7 @@ def get_extent(sch: tir.Schedule, loop_rv: tir.schedule.LoopRV):
 def get_bytes(dtype: Union[DataType, str]) -> int:
     if isinstance(dtype, str):
         dtype = DataType(dtype)
-    return dtype.bits * dtype.lanes // 8
+    return dtype.itemsize()
 
 
 def is_gemv(sch: tir.Schedule, block_info: BlockInfo) -> Optional[List[tir.Buffer]]:
diff --git a/python/tvm/dlight/gpu/low_batch_gemv.py b/python/tvm/dlight/gpu/low_batch_gemv.py
index 9a92c9e0e9..696722c3f0 100644
--- a/python/tvm/dlight/gpu/low_batch_gemv.py
+++ b/python/tvm/dlight/gpu/low_batch_gemv.py
@@ -15,7 +15,6 @@
 # specific language governing permissions and limitations
 # under the License.
 """A rule for low-batch GEMM / decode-GEMM using GEMV schedule."""
-import re
 from functools import reduce
 from typing import List, Optional, Set, Union
 
@@ -55,10 +54,9 @@ def get_extent(sch: tir.Schedule, loop_rv: tir.schedule.LoopRV):
 
 
 def get_bytes(dtype: Union[DataType, str]) -> int:
-    num = re.findall(r"\d+", dtype)
-    if len(num) != 1:
-        raise ValueError(f"Cannot get bytes from {dtype}")
-    return int(num[0]) // 8
+    if isinstance(dtype, str):
+        dtype = DataType(dtype)
+    return dtype.itemsize()
 
 
 def is_gemv(sch: tir.Schedule, block_info: BlockInfo) -> Optional[List[tir.Buffer]]:
diff --git a/tests/python/ir/test_dtype.py b/tests/python/ir/test_dtype.py
new file mode 100644
index 0000000000..77cd1d7e4b
--- /dev/null
+++ b/tests/python/ir/test_dtype.py
@@ -0,0 +1,40 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Test data type related API"""
+import tvm
+from tvm import DataType
+import tvm.testing
+import pytest
+
+
+@pytest.mark.parametrize(
+    "dtype_str, expected_size",
+    [("float32", 4), ("float32x4", 16), ("e5m2_float8x4", 4), ("uint8", 1)],
+)
+def test_dtype_itemsize(dtype_str, expected_size):
+    dtype = DataType(dtype_str)
+    assert dtype.itemsize() == expected_size
+
+
+@pytest.mark.parametrize("dtype_str", [("int32xvscalex4")])
+def test_dtype_itemmize_error(dtype_str):
+    with pytest.raises(ValueError):
+        size = DataType(dtype_str).itemsize()
+
+
+if __name__ == "__main__":
+    tvm.testing.main()

Reply via email to