This is an automated email from the ASF dual-hosted git repository.

ziheng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new dc1f189  [AutoTVM] [TOPI] Support AutoTVM for int4 tensorcore (#7831)
dc1f189 is described below

commit dc1f189207dd78f576d9da7c61f124147a488c42
Author: Andrew Liu <andrewl...@gmail.com>
AuthorDate: Sat May 1 01:27:36 2021 -0700

    [AutoTVM] [TOPI] Support AutoTVM for int4 tensorcore (#7831)
    
    * initial
    
    * int4 asnumpy
    
    * remove
    
    * random test
    
    * format
    
    * random
    
    * remove unused import
    
    * change dist range
    
    * add fuse_pack in
    
    * random engine
    
    * reformat
    
    * remove import
    
    * add cuda context
    
    * refactor code
---
 python/tvm/autotvm/measure/measure_methods.py      |  3 +-
 python/tvm/runtime/ndarray.py                      | 12 +++++
 python/tvm/topi/cuda/conv2d_hwnc_tensorcore.py     | 37 +++-----------
 src/runtime/contrib/random/mt_random_engine.cc     |  6 +++
 src/target/source/codegen_cuda.cc                  | 14 ++++++
 tests/python/contrib/test_random.py                |  7 ++-
 .../python/test_topi_conv2d_hwnc_tensorcore.py     | 56 +++++++++++++++++++++-
 tests/python/unittest/test_target_codegen_cuda.py  | 22 +++++++++
 8 files changed, 119 insertions(+), 38 deletions(-)

diff --git a/python/tvm/autotvm/measure/measure_methods.py b/python/tvm/autotvm/measure/measure_methods.py
index aa072cf..6d01140 100644
--- a/python/tvm/autotvm/measure/measure_methods.py
+++ b/python/tvm/autotvm/measure/measure_methods.py
@@ -32,7 +32,6 @@ import typing
 from random import getrandbits
 from collections import namedtuple
 import tempfile
-import numpy as np
 
 import tvm._ffi
 import tvm.ir.transform
@@ -583,7 +582,7 @@ def run_through_rpc(
                 raise AttributeError(
                     "Please make sure USE_RANDOM is ON in the config.cmake " 
"on the remote devices"
                 )
-            args = [nd.array(np.zeros(x[0], dtype=x[1]), device=dev) for x in build_result.arg_info]
+            args = [nd.empty(x[0], x[1], dev) for x in build_result.arg_info]
             if "scatter" not in measure_input.task.name:
                 # the index tensor of scatter op cannot be randomly initialized
                 for arg in args:
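
The hunk above swaps the host-side np.zeros staging buffer for a direct device allocation, so no NumPy array has to be created for dtypes NumPy cannot represent (such as "int4"); the buffer is then populated in place by the contrib random_fill packed function. A minimal sketch of that allocate-then-fill pattern, outside the RPC runner and with an arbitrary shape and dtype chosen purely for illustration:

    import tvm

    dev = tvm.cpu(0)  # any device whose runtime was built with USE_RANDOM=ON
    arg = tvm.nd.empty((64, 64), "int8", dev)   # allocate directly on the device
    random_fill = tvm.get_global_func("tvm.contrib.random.random_fill")
    random_fill(arg)  # fill the NDArray in place instead of copying zeros from the host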
diff --git a/python/tvm/runtime/ndarray.py b/python/tvm/runtime/ndarray.py
index 980f70d..1b0f130 100644
--- a/python/tvm/runtime/ndarray.py
+++ b/python/tvm/runtime/ndarray.py
@@ -170,15 +170,27 @@ class NDArray(NDArrayBase):
         """
         t = DataType(self.dtype)
         shape, dtype = self.shape, self.dtype
+        old_dtype = dtype
         if t.lanes > 1:
             shape = shape + (t.lanes,)
             t.lanes = 1
             dtype = str(t)
+        if dtype == "int4":
+            dtype = "int8"
         np_arr = np.empty(shape, dtype=dtype)
         assert np_arr.flags["C_CONTIGUOUS"]
         data = np_arr.ctypes.data_as(ctypes.c_void_p)
         nbytes = ctypes.c_size_t(np_arr.size * np_arr.dtype.itemsize)
         check_call(_LIB.TVMArrayCopyToBytes(self.handle, data, nbytes))
+        if old_dtype == "int4":
+            length = np_arr.size
+            np_arr_ret = np.empty((length,), dtype="int8")
+            np_arr = np_arr.reshape((length,))
+            old_index = np.bitwise_and(np_arr, 0x0F)
+            even_index = np.bitwise_and(np_arr >> 4, 0x0F)
+            np_arr_ret[1::2] = old_index[0 : length // 2]
+            np_arr_ret[0::2] = even_index[0 : length // 2]
+            return np_arr_ret.reshape(shape)
         return np_arr
 
     def copyto(self, target):
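
The asnumpy() addition above copies the packed device buffer out as int8 bytes and then splits every byte into two values: the high nibble goes to the even output index and the low nibble to the odd one (nibbles are kept unsigned; int4 sign extension is not applied). A standalone NumPy sketch of the same unpacking, using a hypothetical helper name:

    import numpy as np

    def unpack_int4_bytes(packed, length):
        # mirror the NDArray.asnumpy() logic for dtype "int4": each byte holds two
        # 4-bit values; high nibble -> even index, low nibble -> odd index
        out = np.empty((length,), dtype="int8")
        low = np.bitwise_and(packed, 0x0F)
        high = np.bitwise_and(packed >> 4, 0x0F)
        out[0::2] = high[: length // 2]
        out[1::2] = low[: length // 2]
        return out

    # two packed bytes 0x21 and 0x43 unpack to [2, 1, 4, 3]
    print(unpack_int4_bytes(np.array([0x21, 0x43], dtype="int8"), 4))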
diff --git a/python/tvm/topi/cuda/conv2d_hwnc_tensorcore.py b/python/tvm/topi/cuda/conv2d_hwnc_tensorcore.py
index e2d3cd9..b3d8397 100644
--- a/python/tvm/topi/cuda/conv2d_hwnc_tensorcore.py
+++ b/python/tvm/topi/cuda/conv2d_hwnc_tensorcore.py
@@ -184,9 +184,9 @@ def hwnc_tensorcore_cuda(cfg, Input, Filter, stride, padding, dilation, out_dtyp
 
 def schedule_hwnc_tensorcore_cuda(cfg, s, Conv):
     """Schedule tensorcore template"""
-    packed_data, packed_kernel = s[Conv].op.input_tensors
+    pad_data, packed_kernel = s[Conv].op.input_tensors
     ic, kh, kw, ii = s[Conv].op.reduce_axis
-    pad_data = s[packed_data].op.input_tensors[0]
+    packed_data = s[pad_data].op.input_tensors[0]
 
     block_x = te.thread_axis("blockIdx.x")
     block_y = te.thread_axis("blockIdx.y")
@@ -196,7 +196,7 @@ def schedule_hwnc_tensorcore_cuda(cfg, s, Conv):
     thread_z = te.thread_axis("threadIdx.z")
 
     # Designate the memory hierarchy
-    AS = s.cache_read(packed_data, "shared", [Conv])
+    AS = s.cache_read(pad_data, "shared", [Conv])
     WS = s.cache_read(packed_kernel, "shared", [Conv])
     AF = s.cache_read(AS, "wmma.matrix_a", [Conv])
     WF = s.cache_read(WS, "wmma.matrix_b", [Conv])
@@ -241,7 +241,6 @@ def schedule_hwnc_tensorcore_cuda(cfg, s, Conv):
     cfg.define_knob("warp_row_tiles", [1, 2, 4, 8, 16])
     cfg.define_knob("warp_col_tiles", [1, 2, 4, 8, 16])
     cfg.define_knob("chunk", [1, 2, 4, 8])
-    cfg.define_knob("fuse_pack", [0, 1])
     cfg.define_knob("split_block_k_nums", [1, 2, 4, 8, 16, 32])
     cfg.define_knob("vector_ws", [1, 8])
     cfg.define_knob("vector_as", [1, 8, 16])
@@ -254,13 +253,8 @@ def schedule_hwnc_tensorcore_cuda(cfg, s, Conv):
     vector_as = cfg["vector_as"].val
     vector_ws = cfg["vector_ws"].val
     split_block_k_nums = cfg["split_block_k_nums"].val
-    fuse_pack = cfg["fuse_pack"].val
 
-    if not fuse_pack:
-        s[packed_data].compute_inline()
-    else:
-        with Target("cuda"):
-            schedule_injective_from_existing(s, packed_data)
+    s[packed_data].compute_inline()
 
     if data_dtype in ["int4", "uint4"]:
         wmma_m = wmma_n = 8
@@ -324,24 +318,13 @@ def schedule_hwnc_tensorcore_cuda(cfg, s, Conv):
     cfg["reorder_inner"].apply(s, ConvF, [ko, kh])
     cfg["reorder_inner"].apply(s, ConvF, [ki, kw])
 
-    cfg.define_knob("compute_at_AS", [0, 1, 2, 3])
-    cfg.define_knob("compute_at_WS", [0, 1, 2, 3])
-    compute_at_AS = cfg["compute_at_AS"].val
-    compute_at_WS = cfg["compute_at_WS"].val
-
     # Move intermediate computation into each output compute tile
     s[AF].compute_at(s[ConvF], kw)
     s[WF].compute_at(s[ConvF], kw)
 
     # Schedule for A's share memory
-    if compute_at_AS == 0:
-        s[AS].compute_at(s[ConvF], ki)
-    elif compute_at_AS == 1:
-        s[AS].compute_at(s[ConvF], kw)
-    elif compute_at_AS == 2:
-        s[AS].compute_at(s[ConvF], ko)
-    else:
-        s[AS].compute_at(s[ConvF], kh)
+    s[AS].compute_at(s[ConvF], ko)
+
     _, _, n, _, nn, ii = AS.op.axis
     tx, xo = s[AS].split(n, nparts=block_row_warps)
     ty, _ = s[AS].split(xo, nparts=block_col_warps)
@@ -354,14 +337,6 @@ def schedule_hwnc_tensorcore_cuda(cfg, s, Conv):
     s[AS].vectorize(_t)
 
     # Schedule for W's share memory
-    if compute_at_WS == 0:
-        s[WS].compute_at(s[ConvF], ki)
-    elif compute_at_WS == 1:
-        s[WS].compute_at(s[ConvF], kw)
-    elif compute_at_WS == 2:
-        s[WS].compute_at(s[ConvF], ko)
-    else:
-        s[WS].compute_at(s[ConvF], kh)
     s[WS].compute_at(s[ConvF], kw)
     kh, kw, ic, o, ii, oo = WS.op.axis
     tx, xo = s[WS].split(o, nparts=block_row_warps)
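
Removing the fuse_pack, compute_at_AS and compute_at_WS knobs fixes the loop structure the template generates, which is presumably what lets the new verify_feature_length test further down assert that two randomly sampled configurations produce itervar feature vectors of equal length. As a rough count (my arithmetic, not taken from the commit), the dropped knobs also shrink the tuning space per remaining configuration:

    # choices contributed by the knobs this hunk removes
    removed = {"fuse_pack": 2, "compute_at_AS": 4, "compute_at_WS": 4}
    factor = 1
    for knob, choices in removed.items():
        factor *= choices
    print(factor)  # 32: the config space is ~32x smaller per remaining combination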
diff --git a/src/runtime/contrib/random/mt_random_engine.cc b/src/runtime/contrib/random/mt_random_engine.cc
index a1c6dc2..161ae62 100644
--- a/src/runtime/contrib/random/mt_random_engine.cc
+++ b/src/runtime/contrib/random/mt_random_engine.cc
@@ -140,6 +140,12 @@ class RandomEngine {
     // Use float representation could make us work well on float / int type too.
     if (tensor->dtype.bits == 1) {
       std::generate_n(static_cast<bool*>(tensor->data), size, [&]() { return dist(rnd_engine_); });
+    } else if (tensor->dtype.bits == 4) {
+      // For uint4/int4 we pack two values into a single byte.
+      // Thus, to ensure both values are non-zero, we use a distribution of 17 - 30.
+      std::uniform_real_distribution<> packed_dist(17.0, 30.0);
+      std::generate_n(reinterpret_cast<uint8_t*>(tensor->data), size,
+                      [&]() { return packed_dist(rnd_engine_); });
     } else if (tensor->dtype.bits == 8) {
       std::generate_n(static_cast<uint8_t*>(tensor->data), size,
                       [&]() { return dist(rnd_engine_); });
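
For 4-bit tensors, size here counts packed bytes, and each generated byte must decode to two non-zero 4-bit values. Drawing bytes from 17 - 30 guarantees that: 17 is 0x11 (nibbles 1 and 1) and 30 is 0x1E (nibbles 1 and 14), so the high nibble is always 1 and the low nibble stays in 1..14. A small Python check of that claim (illustrative only, not part of the commit):

    # every byte the C++ distribution can produce splits into two non-zero nibbles
    for byte in range(17, 31):  # uniform_real_distribution<>(17.0, 30.0), truncated to a byte
        high, low = byte >> 4, byte & 0x0F
        assert high != 0 and low != 0, byte
    print("all bytes in [17, 30] yield non-zero nibble pairs")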
diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc
index cd02dcc..4cc999b 100644
--- a/src/target/source/codegen_cuda.cc
+++ b/src/target/source/codegen_cuda.cc
@@ -809,6 +809,20 @@ void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) {  // NO
     return;
   }
 
+  if ((op->dtype.is_int() || op->dtype.is_uint()) && op->dtype.bits() == 4 && op->lanes == 8) {
+    // make_int4x8
+    const int64_t* p = as_const_int(op->value);
+    ICHECK(p);
+    int64_t v = *p & 0xF;
+    v = (v << 28) | (v << 24) | (v << 20) | (v << 16) | (v << 12) | (v << 8) | (v << 4) | v;
+    if (op->dtype.is_uint()) {
+      os << "(uint)" << v;
+    } else {
+      os << "(int)" << v;
+    }
+    return;
+  }
+
   std::string v = PrintExpr(op->value);
   os << "make_";
   PrintType(op->dtype, os);
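
The new BroadcastNode case above emits a single 32-bit literal whose eight nibbles all hold the broadcast value, i.e. an int4x8 vector is materialized as one (u)int. A short Python sketch of the same shift/or replication (illustrative, mirroring the C++ above):

    def make_int4x8(value):
        # replicate a 4-bit value into all eight nibbles of a 32-bit word
        v = value & 0xF
        return (v << 28) | (v << 24) | (v << 20) | (v << 16) | (v << 12) | (v << 8) | (v << 4) | v

    assert make_int4x8(7) == 0x77777777
    assert make_int4x8(1) == 0x11111111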
diff --git a/tests/python/contrib/test_random.py b/tests/python/contrib/test_random.py
index 0740521..0ebf255 100644
--- a/tests/python/contrib/test_random.py
+++ b/tests/python/contrib/test_random.py
@@ -102,8 +102,7 @@ def test_random_fill():
         if not tvm.get_global_func("tvm.contrib.random.random_fill", True):
             print("skip because extern function is not available")
             return
-        np_ones = np.ones((512, 512), dtype=dtype)
-        value = tvm.nd.empty(np_ones.shape, np_ones.dtype, dev)
+        value = tvm.nd.empty((512, 512), dtype, dev)
         random_fill = tvm.get_global_func("tvm.contrib.random.random_fill")
         random_fill(value)
 
@@ -119,10 +118,9 @@ def test_random_fill():
             return
         if not tvm.testing.device_enabled("rpc") or not tvm.runtime.enabled("llvm"):
             return
-        np_ones = np.ones((512, 512), dtype=dtype)
         server = rpc.Server("localhost")
         remote = rpc.connect(server.host, server.port)
-        value = tvm.nd.empty(np_ones.shape, np_ones.dtype, remote.cpu())
+        value = tvm.nd.empty((512, 512), dtype, remote.cpu())
         random_fill = remote.get_function("tvm.contrib.random.random_fill")
         random_fill(value)
 
@@ -134,6 +132,7 @@ def test_random_fill():
 
     for dtype in [
         "bool",
+        "int4",
         "int8",
         "uint8",
         "int16",
diff --git a/tests/python/topi/python/test_topi_conv2d_hwnc_tensorcore.py b/tests/python/topi/python/test_topi_conv2d_hwnc_tensorcore.py
index bb11a56..1b35fe8 100644
--- a/tests/python/topi/python/test_topi_conv2d_hwnc_tensorcore.py
+++ b/tests/python/topi/python/test_topi_conv2d_hwnc_tensorcore.py
@@ -22,7 +22,7 @@ import tvm
 import os
 import tvm.testing
 import tvm.topi.testing
-from tvm import te, autotvm, topi
+from tvm import te, autotvm, topi, relay
 from tvm.contrib.pickle_memoize import memoize
 from tvm.contrib import nvcc
 from tvm.topi.nn.utils import get_pad_tuple
@@ -136,6 +136,59 @@ def verify_conv2d_hwnc(
     check_target("cuda")
 
 
+def verify_feature_length():
+    np.random.seed(123)
+    target = "cuda"
+    ctx = tvm.device(target)
+
+    batch_size = 32
+
+    input_shape = (32, 512, 7, 7)
+    kernel_shape = (512, 512, 3, 3)
+
+    def get_mod():
+        x = relay.var("x", relay.TensorType(input_shape, "float32"))
+        y = relay.var("y", relay.TensorType(kernel_shape, "float32"))
+        f = relay.Function(
+            [x, y], relay.nn.conv2d(x, y, padding=[1, 1, 1, 1], channels=512, kernel_size=[3, 3])
+        )
+        mod = tvm.IRModule()
+        mod["main"] = f
+        mod = relay.transform.InferType()(mod)
+        return mod, {}
+
+    mod, params = get_mod()
+    layout_config = relay.transform.LayoutConfig()
+    desired_layouts = {"nn.conv2d": ["HWNC", "default"]}
+    with layout_config:
+        seq = tvm.transform.Sequential([relay.transform.ConvertLayout(desired_layouts)])
+        with tvm.transform.PassContext(opt_level=3):
+            mod = seq(mod)
+    mod = relay.transform.recast(mod, "int4", "int32")
+
+    tasks = autotvm.task.extract_from_program(
+        mod, target=target, params=params, ops=(relay.op.get("nn.conv2d"),)
+    )
+
+    assert len(tasks) == 1
+    task = tasks[0]
+
+    space = task.config_space
+
+    idx1 = np.random.randint(len(space))
+    idx2 = np.random.randint(len(space))
+
+    cfg = space.get(idx1)
+    sch, arg_bufs = task.instantiate(cfg)
+    fea1 = autotvm.feature.get_itervar_feature_flatten(sch, arg_bufs, take_log=True)
+
+    cfg = space.get(idx2)
+    sch, arg_bufs = task.instantiate(cfg)
+    fea2 = autotvm.feature.get_itervar_feature_flatten(sch, arg_bufs, take_log=True)
+
+    assert len(fea1) == len(fea2)
+
+
 @tvm.testing.requires_tensorcore
 def test_conv2d_hwnc_tensorcore():
     """Test the conv2d with tensorcore for hwnc layout"""
@@ -150,6 +203,7 @@ def test_conv2d_hwnc_tensorcore():
     verify_conv2d_hwnc(8, 256, 14, 512, 3, 2, 1)
     verify_conv2d_hwnc(8, 256, 14, 512, 1, 2, 0)
     verify_conv2d_hwnc(8, 512, 9, 512, 3, 1, 1)
+    verify_feature_length()
 
 
 if __name__ == "__main__":
diff --git a/tests/python/unittest/test_target_codegen_cuda.py b/tests/python/unittest/test_target_codegen_cuda.py
index a63aeaa..e639e6b 100644
--- a/tests/python/unittest/test_target_codegen_cuda.py
+++ b/tests/python/unittest/test_target_codegen_cuda.py
@@ -206,6 +206,27 @@ def test_cuda_make_int8():
 
 @tvm.testing.requires_gpu
 @tvm.testing.requires_cuda
+def test_cuda_make_int4():
+    def check_cuda(n, value, lanes):
+        dtype = "int4"
+        dev = tvm.gpu(0)
+        A = te.compute((n, lanes), lambda i, j: tvm.tir.const(value, dtype=dtype))
+        s = te.create_schedule(A.op)
+        y, x = s[A].op.axis
+        s[A].vectorize(x)
+        s[A].bind(y, bx)
+        fun = tvm.build(s, [A], "cuda", name="make_int4x8")
+        np_a = np.full((n, lanes), value, dtype="int8")
+        a = tvm.nd.empty((n, lanes), dtype, dev)
+        fun(a)
+        np.testing.assert_equal(a.asnumpy(), np_a)
+
+    check_cuda(64, 1, 8)
+    check_cuda(64, 7, 8)
+
+
+@tvm.testing.requires_gpu
+@tvm.testing.requires_cuda
 def test_cuda_inf_nan():
     target = "cuda"
 
@@ -972,6 +993,7 @@ if __name__ == "__main__":
     test_cuda_bf16_vectorize_add()
     test_cuda_multiply_add()
     test_cuda_vectorize_load()
+    test_cuda_make_int4()
     test_cuda_make_int8()
     test_cuda_inf_nan()
     test_cuda_shuffle()
