junrushao1994 commented on a change in pull request #10793:
URL: https://github.com/apache/tvm/pull/10793#discussion_r838937818



##########
File path: tests/python/unittest/test_meta_schedule_tune_relay.py
##########
@@ -323,6 +325,222 @@ def get_output(data, lib):
         assert np.allclose(actual_output, expected_output, rtol=1e-4, 
atol=2e-4)
 
 
+@T.prim_func
+def dot_product_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
+    A = T.match_buffer(a, (4,), "uint8", offset_factor=1)
+    B = T.match_buffer(b, (16, 4), "int8", offset_factor=1)
+    C = T.match_buffer(c, (16,), "int32", offset_factor=1)
+
+    with T.block("root"):
+        T.reads(C[0:16], A[0:4], B[0:16, 0:4])
+        T.writes(C[0:16])
+        for i in T.serial(0, 16):
+            with T.init():
+                C[i] = T.int32(0)
+            for k in T.serial(0, 4):
+                with T.block("update"):
+                    vi, vk = T.axis.remap("SR", [i, k])
+                    C[vi] = C[vi] + T.cast(A[vk], "int32") * T.cast(B[vi, vk], 
"int32")
+
+
+@T.prim_func
+def dot_product_intrin(a: T.handle, b: T.handle, c: T.handle) -> None:
+    A = T.match_buffer(a, (4,), "uint8", offset_factor=1)
+    B = T.match_buffer(b, (16, 4), "int8", offset_factor=1)
+    C = T.match_buffer(c, (16,), "int32", offset_factor=1)
+
+    with T.block("root"):
+        T.reads(C[0:16], A[0:4], B[0:16, 0:4])
+        T.writes(C[0:16])
+
+        A_u8x4 = A.vload([0], "uint8x4")
+        A_i32 = T.reinterpret(A_u8x4, dtype="int32")
+
+        B_i8x64 = B.vload([0, 0], dtype="int8x64")
+        B_i32x16 = T.reinterpret(B_i8x64, dtype="int32x16")
+
+        C[T.ramp(T.int32(0), 1, 16)] += T.call_llvm_pure_intrin(  # Note: this 
is an update +=
+            T.llvm_lookup_intrinsic_id("llvm.x86.avx512.vpdpbusd.512"),
+            T.uint32(0),
+            T.int32x16(0),
+            T.broadcast(A_i32, 16),
+            B_i32x16,
+            dtype="int32x16",
+        )
+
+
+VNNI_INTRIN = "dot_16x1x16_uint8_int8_int32_cascadelake"
+
+
+def schedule_dense(block, M, do_tune, sch):
+    post_blocks = sch.get_consumers(block)
+
+    if len(post_blocks) > 0:
+        while True:
+            next_post_blocks = []
+            for post_block in post_blocks:
+                next_consumers = sch.get_consumers(post_block)
+
+                if len(next_consumers) > 0:
+                    sch.compute_inline(post_block)
+
+                next_post_blocks += next_consumers
+
+            if len(next_post_blocks) == 0:
+                assert len(post_blocks) == 1
+                outer_block = post_blocks[0]
+                a_y, a_x = sch.get_loops(outer_block)[-2:]
+                break
+
+            post_blocks = next_post_blocks
+    else:
+        a_y, a_x, _ = sch.get_loops(block)[-3:]
+        outer_block = block
+
+    if do_tune:
+        y_factors = sch.sample_perfect_tile(a_y, n=2, max_innermost_factor=128)
+        a_yo, a_yi = sch.split(a_y, factors=y_factors)
+    else:
+        a_yo, a_yi = sch.split(a_y, factors=[None, min(M, 64)])
+
+    a_xo, a_xi = sch.split(a_x, factors=[None, 16])
+    sch.reorder(a_yo, a_xo, a_yi, a_xi)
+    fused = sch.fuse(a_yo, a_xo)
+
+    if outer_block != block:
+        sch.vectorize(a_xi)
+        sch.compute_at(block, a_yi)
+
+    a_xi, a_k = sch.get_loops(block)[-2:]
+    a_ko, a_ki = sch.split(a_k, factors=[None, 4])
+    sch.reorder(a_ko, a_xi, a_ki)
+
+    sch.parallel(fused)
+    dec = sch.decompose_reduction(block, a_ko)
+
+    init_loop = sch.get_loops(dec)[-1]
+    sch.vectorize(init_loop)
+
+    sch.tensorize(a_xi, VNNI_INTRIN)
+
+
+def manual_tir_common(do_tune=False):
+    M, N, K = 1024, 1024, 1024
+    data_shape = (M, K)
+    weight_shape = (N, K)
+
+    data_dtype = "uint8"
+    data = relay.var("data", shape=data_shape, dtype=data_dtype)
+    weight = relay.var("weight", shape=weight_shape, dtype="int8")
+    bias = relay.var("bias", shape=(weight_shape[0],), dtype="int32")
+
+    # dense is tuned by the TIR schedule above, bmm is scheduled by TE 
(topi/x86/batch_matmul.py)
+    dense = relay.nn.dense(data, weight, out_dtype="int32")
+    bias_add = relay.nn.bias_add(dense, bias) + relay.const(1, dtype="int32")
+    out = relay.nn.batch_matmul(
+        relay.cast(relay.expand_dims(bias_add, 0), "uint8"),
+        relay.cast(relay.expand_dims(bias_add, 0), "int8"),
+        out_dtype="int32",
+    )
+
+    relay_mod = tvm.IRModule.from_expr(out)
+
+    target = "llvm -mcpu=cascadelake -num-cores 4"
+    dev = tvm.device(target, 0)
+
+    data = np.random.uniform(1, 10, size=(M, K)).astype("uint8")
+    weight_np = np.random.uniform(1, 10, size=weight_shape).astype("int8")
+    bias_np = np.random.uniform(1, 10, size=(weight_shape[0],)).astype("int32")
+
+    ref = (
+        relay.create_executor("vm", mod=relay_mod, device=dev, target=target)
+        .evaluate()(*[data, weight_np, bias_np])
+        .numpy()
+    )
+
+    params = {"weight": weight_np, "bias": bias_np}
+
+    extracted_tasks = extract_task_from_relay(relay_mod, target, params)
+
+    tune_tasks = list(
+        filter(
+            lambda task: "dense" in task.task_name,
+            extracted_tasks,
+        )
+    )
+
+    with tempfile.TemporaryDirectory() as work_dir:
+        if do_tune:
+            config = ReplayTraceConfig(
+                num_trials_per_iter=64,
+                num_trials_total=64,
+            )
+            database = tune_extracted_tasks(
+                tune_tasks, target, config, work_dir=work_dir, 
postprocs=lambda: []
+            )
+        else:
+            database = JSONDatabase(
+                path_workload=osp.join(work_dir, "database_workload.json"),
+                path_tuning_record=osp.join(work_dir, 
"database_tuning_record.json"),
+            )
+
+            for task in tune_tasks:
+                mod = Parse._mod(task.dispatched[0])
+                workload = database.commit_workload(mod)
+
+                sch = tvm.tir.Schedule(mod)
+                block = sch.get_block("compute")
+                schedule_rule = sch.get(block).annotations["schedule_rule"]
+
+                if "dense_vnni" in schedule_rule:
+                    schedule_dense(block, M, False, sch)
+
+                tune_rec = TuningRecord(sch.trace, [0.0], workload, 
tvm.target.Target(target), [])
+
+                database.commit_tuning_record(tune_rec)
+
+    with ApplyHistoryBest(database):
+        with tvm.transform.PassContext(
+            opt_level=3,
+            config={"relay.backend.use_meta_schedule": True},
+        ):
+            """
+            The log should say
+            meta_schedule/integration.cc:146: Warning: Cannot find workload: 
tvmgen_default_fused_expand_dims
+            meta_schedule/integration.cc:146: Warning: Cannot find workload: 
tvmgen_default_fused_cast
+            meta_schedule/integration.cc:146: Warning: Cannot find workload: 
tvmgen_default_fused_cast_1
+            meta_schedule/integration.cc:146: Warning: Cannot find workload: 
tvmgen_default_fused_nn_batch_matmul
+
+            This means batch matmul and others are scheduled by TE, and dense 
(the one not warned) is found in the
+            meta schedule tuning database during ApplyHistoryBest
+            """
+            lib = relay.build(relay_mod, target=target, params=params)
+
+    runtime = tvm.contrib.graph_executor.GraphModule(lib["default"](dev))
+
+    runtime.set_input("data", data)
+    runtime.run()
+
+    out = runtime.get_output(0).numpy()
+
+    np.testing.assert_equal(out, ref)
+
+
+@pytest.mark.skip("Requires cascadelake")
+def test_tune_relay_manual_tir_vnni():
+    tir.TensorIntrin.register(VNNI_INTRIN, dot_product_desc, 
dot_product_intrin)
+
+    manual_tir_common(do_tune=False)
+
+    def schedule_rule_dense_vnni(sch, block):
+        schedule_dense(block, None, True, sch)
+        return [sch]
+
+    register_func("meta_schedule.dense_vnni", schedule_rule_dense_vnni)

Review comment:
       @masahi Thanks for the comprehensive discussion about API design and 
user intention!
   
   Looks like we have 3 different proposals:
   - P1. Always annotate schedule_rule even if it's not registered; Print a 
warning if the schedule_rule isn't registered;
   - P2. Always make sure schedule_rule exists when annotating, and error out 
if it doesn't;
   - P3. Always annotate schedule_rule even if it's not registered; No warning 
if the schedule_rule isn't registered;
   
   Masa and I both agree that P3 may not be ideal, because silently ignoring 
an abnormality may not be the best user experience.
   
   > In the future when auto-tensorization is ready, I want to freely switch 
between manual and automatic scheduling.
   
   Totally agree about the future possibility of switching around! It's 
definitely going to be a lot of fun :-)
   
   Therefore, I would conclude that emitting a warning (P1) is probably what 
both @masahi and I agree on. Additionally, we might enhance the search space 
generation to selectively check whether the schedule_rule is allowed against a 
target-specific allowlist, but it's probably not high-priority for now.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscr...@tvm.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Reply via email to