tqchen commented on code in PR #15141:
URL: https://github.com/apache/tvm/pull/15141#discussion_r1238660532
##########
main.py:
##########
@@ -0,0 +1,267 @@
+# pylint: disable=missing-docstring
+from typing import List, Optional
+
+from tvm import IRModule
+from tvm import meta_schedule as ms
+from tvm import tir
+from tvm.dlight import ScheduleGenerator, ScheduleRule, auto_inline_consumers
+
+
class Decode(ScheduleRule):
    """Schedule rule that tiles a standalone weight-decoding block for GPU.

    Binds the two loops of the "decode" block onto a 2-D grid of thread
    blocks, unrolls the innermost tile, and stages the result through shared
    memory so the write-back is fused, thread-bound and vectorized.
    """

    def __init__(self):
        ...

    def _initialize_with_tune_context(self, context: ms.TuneContext) -> None:
        pass

    def clone(self) -> ScheduleRule:
        return Decode()

    def apply(self, mod: IRModule) -> Optional[tir.Schedule]:  # pylint: disable=too-many-locals
        """Return a GPU schedule for `mod` if it contains a "decode" block, else None."""
        sch = tir.Schedule(mod)
        try:
            decode = sch.get_block("decode")
        except Exception:  # pylint: disable=broad-except
            # No "decode" block: this rule does not apply to the workload.
            # (Was a bare `except:`, which would also swallow KeyboardInterrupt.)
            return None
        len_tx: int = 8  # threads along x
        len_ty: int = 8  # threads along y
        len_yi: int = 1  # serial inner tile along y
        len_yc: int = 8  # unrolled inner tile along y

        # Step 1. Tile the decoding
        i, j = sch.get_loops(decode)
        by, ty, yi, yc = sch.split(  # pylint: disable=invalid-name
            i, factors=[None, len_ty, len_yi, len_yc]
        )
        bx, tx = sch.split(j, factors=[None, len_tx])  # pylint: disable=invalid-name
        sch.reorder(by, bx, ty, tx, yi, yc)
        sch.bind(by, "blockIdx.y")
        sch.bind(bx, "blockIdx.x")
        sch.bind(ty, "threadIdx.y")
        sch.bind(tx, "threadIdx.x")
        sch.unroll(yc)
        # Step 2. Cache results in shared memory
        rb = sch.cache_write(decode, 0, "shared")  # pylint: disable=invalid-name
        consumers = sch.get_consumers(rb)
        if consumers:
            # If the cache block has a (single) consumer, inline everything
            # downstream of it and let the consumer take over the write-back.
            (consumer,) = consumers
            auto_inline_consumers(sch, consumer)
            sch.compute_inline(rb)
            rb = consumer  # pylint: disable=invalid-name
        # Step 3. Schedule the shared memory write back
        sch.reverse_compute_at(rb, bx, preserve_unit_loops=True)
        loop = sch.fuse(*sch.get_loops(rb)[-2:])
        _, ty, tx, vec = sch.split(  # pylint: disable=invalid-name
            loop, factors=[None, len_ty, len_tx, 4]
        )
        sch.bind(ty, "threadIdx.y")
        sch.bind(tx, "threadIdx.x")
        sch.vectorize(vec)
        # Pad the shared-memory buffer stride (commonly done to mitigate
        # bank conflicts on write).
        sch.storage_align(decode, buffer_index=0, axis=0, factor=32, offset=1)
        return sch
+
+
class DecodeGemv(ScheduleRule):
    """Schedule rule for a fused weight-decode + GEMV workload on GPU.

    Distributes the GEMV output dimension over blockIdx.x / vthread.x /
    threadIdx.x, tiles the reduction so the decode step runs per-thread in
    local scope, and cooperatively stages the GEMV input through shared
    memory with a sampled vectorization width.
    """

    def __init__(self):
        ...

    def _initialize_with_tune_context(self, context: ms.TuneContext) -> None:
        pass

    def clone(self) -> ScheduleRule:
        return DecodeGemv()

    def apply(self, mod: IRModule) -> Optional[tir.Schedule]:  # pylint: disable=too-many-locals
        """Return a schedule if `mod` has both "matmul" and "decode" blocks, else None."""
        sch = tir.Schedule(mod)
        try:
            gemv = sch.get_block("matmul")
            decode = sch.get_block("decode")
        except Exception:  # pylint: disable=broad-except
            # Required blocks are absent: this rule does not apply.
            # (Was a bare `except:`, which would also swallow KeyboardInterrupt.)
            return None
        len_vx: int = 2  # virtual threads along x
        len_tx: int = 64  # threads along x
        len_km: int = 2  # middle reduction tile
        len_ki: int = 1 * 8  # innermost (unrolled) reduction tile
        # Step 1. Schedule GEMV
        # [b=1, i=1, j, k]
        # split j => [b=1, i=1, (bx, vx, tx), k]
        # fuse (b, i, bx) => [bx, vx, tx, (k)]
        # split k => [bx, vx, tx, (ko, k_m, ki * 8)]
        rb = sch.cache_write(gemv, 0, "local")  # pylint: disable=invalid-name
        b, i, j, k = sch.get_loops(gemv)  # pylint: disable=invalid-name
        assert sch.get(b).extent.value == 1
        assert sch.get(i).extent.value == 1
        bx, vx, tx = sch.split(j, [None, len_vx, len_tx])  # pylint: disable=invalid-name
        bx = sch.fuse(b, i, bx)  # pylint: disable=invalid-name
        k_o, k_m, k_i = sch.split(k, [None, len_km, len_ki])
        sch.bind(bx, thread_axis="blockIdx.x")
        sch.bind(vx, thread_axis="vthread.x")
        sch.bind(tx, thread_axis="threadIdx.x")
        sch.reorder(bx, vx, tx, k_o, k_m, k_i)
        sch.unroll(k_i)
        # Step 2. Schedule decode: move to under threadIdx.x and fetch
        # separately for each thread
        sch.compute_at(decode, k_m, preserve_unit_loops=True)
        sch.set_scope(decode, 0, "local")
        _, unroll = sch.split(sch.get_loops(decode)[-2], [None, 8])
        sch.unroll(unroll)

        # Step 3. Cooperative fetch GEMV: cache the first read buffer in
        # shared memory and let all threads of the block fill it together.
        def cooperative_fetch(block, tx):  # pylint: disable=invalid-name
            block = sch.cache_read(block, 0, "shared")
            sch.compute_at(block, tx, preserve_unit_loops=True)
            loop = sch.fuse(*sch.get_loops(block)[-2:])
            # Vectorization width is left to the tuner to sample.
            len_vector = sch.sample_categorical(
                [1, 2, 3, 4],
                probs=[0.25, 0.25, 0.25, 0.25],
            )
            _, tx, vec = sch.split(loop, [None, len_tx, len_vector])
            sch.bind(tx, thread_axis="threadIdx.x")
            sch.vectorize(vec)
            sch.storage_align(block, buffer_index=0, axis=-2, factor=32, offset=8)

        cooperative_fetch(gemv, k_o)
        # Step 4. Schedule epilogue
        auto_inline_consumers(sch, rb)
        sch.reverse_compute_at(rb, tx, preserve_unit_loops=True)
        # Step 5. Postprocess: decompose reduction
        sch.decompose_reduction(gemv, k_o)
        # Return the schedule itself, not `[sch]`: the declared return type is
        # Optional[tir.Schedule], and sibling rules (Decode, Normalization)
        # return the bare schedule.
        return sch
+
+
class Normalization(ScheduleRule):
    """Schedule rule for normalization-style reduce-then-broadcast workloads.

    Fuses the spatial consumer's outer loops onto blockIdx.x, binds the last
    loop to threadIdx.x, and computes the reduction into shared memory within
    the same thread block so the consumer can read it directly.
    """

    def __init__(self):
        ...

    def _initialize_with_tune_context(self, context: ms.TuneContext) -> None:
        pass

    def clone(self) -> ScheduleRule:
        return Normalization()

    def apply(self, mod: IRModule) -> Optional[tir.Schedule]:  # pylint: disable=too-many-locals
        """Return a schedule if `mod` contains a recognized reduction block, else None."""
        sch = tir.Schedule(mod)
        # Different frontends emit different names for the reduction temp;
        # probe the known candidates in order.
        b_reduce = None
        for name in ["Ared_temp", "A_red_temp"]:
            try:
                b_reduce = sch.get_block(name)
            except Exception:  # pylint: disable=broad-except
                # This candidate name is absent; try the next one.
                # (Was a bare `except:`, which would also swallow KeyboardInterrupt.)
                continue
            else:
                break
        if b_reduce is None:
            return None
        len_tx: int = 256  # threads per block
        unroll_depth: int = 256  # auto-unroll budget

        (b_spatial,) = sch.get_consumers(b_reduce)
        loops = sch.get_loops(b_spatial)
        bx = sch.fuse(*loops[:-1])  # pylint: disable=invalid-name
        _, tx = sch.split(loops[-1], [None, len_tx])  # pylint: disable=invalid-name
        sch.bind(bx, "blockIdx.x")
        sch.bind(tx, "threadIdx.x")

        # Keep every output buffer of the reduction in shared scope so the
        # spatial block in the same thread block can consume it.
        for i, _ in enumerate(sch.get(b_reduce).writes):
            sch.set_scope(b_reduce, buffer_index=i, storage_scope="shared")
        sch.compute_at(b_reduce, bx, preserve_unit_loops=True)
        _, tx = sch.split(  # pylint: disable=invalid-name
            sch.get_loops(b_reduce)[-1],
            [None, len_tx],
        )
        sch.bind(tx, "threadIdx.x")
        auto_inline_consumers(sch, b_spatial)
        sch.annotate(bx, ann_key="pragma_auto_unroll_max_step", ann_val=unroll_depth)
        sch.annotate(bx, ann_key="pragma_unroll_explicit", ann_val=1)

        return sch
+
+
class Softmax(ScheduleRule):
    """Schedule rule for the three-block softmax pattern.

    Expects the TOPI-style blocks "T_softmax_maxelem", "T_softmax_expsum" and
    "T_softmax_norm". Both reductions are computed into shared memory under
    the fused spatial loop of the normalization block, with one thread block
    per softmax row group.
    """

    def __init__(self):
        ...

    def _initialize_with_tune_context(self, context: ms.TuneContext) -> None:
        pass

    def clone(self) -> ScheduleRule:
        return Softmax()

    def apply(self, mod: IRModule) -> Optional[tir.Schedule]:  # pylint: disable=too-many-locals
        """Return a schedule if `mod` matches the softmax block pattern, else None."""
        sch = tir.Schedule(mod)
        try:
            b_reduce_0 = sch.get_block("T_softmax_maxelem")
            b_reduce_1 = sch.get_block("T_softmax_expsum")
            b_spatial = sch.get_block("T_softmax_norm")
        except Exception:  # pylint: disable=broad-except
            # Not the expected softmax pattern: this rule does not apply.
            # (Was a bare `except:`, which would also swallow KeyboardInterrupt.)
            return None

        len_tx: int = 256  # threads per block
        unroll_depth: int = 256  # auto-unroll budget

        # Inline the producer of the exp-sum (the elementwise exp block).
        sch.compute_inline(sch.get_producers(b_reduce_1)[0])

        loops = sch.get_loops(b_spatial)
        bx = sch.fuse(*loops[:-1])  # pylint: disable=invalid-name
        _, tx = sch.split(loops[-1], [None, len_tx])  # pylint: disable=invalid-name
        sch.bind(bx, "blockIdx.x")
        sch.bind(tx, "threadIdx.x")

        # Compute the exp-sum reduction in shared memory under the same block.
        sch.set_scope(b_reduce_1, buffer_index=0, storage_scope="shared")
        sch.compute_at(b_reduce_1, bx, preserve_unit_loops=True)
        _, tx = sch.split(  # pylint: disable=invalid-name
            sch.get_loops(b_reduce_1)[-1],
            [None, len_tx],
        )
        sch.bind(tx, "threadIdx.x")

        # Likewise for the row-max reduction.
        sch.set_scope(b_reduce_0, buffer_index=0, storage_scope="shared")
        sch.compute_at(b_reduce_0, bx, preserve_unit_loops=True)
        _, tx = sch.split(  # pylint: disable=invalid-name
            sch.get_loops(b_reduce_0)[-1],
            [None, len_tx],
        )
        sch.bind(tx, "threadIdx.x")

        sch.annotate(bx, ann_key="pragma_auto_unroll_max_step", ann_val=unroll_depth)
        sch.annotate(bx, ann_key="pragma_unroll_explicit", ann_val=1)

        auto_inline_consumers(sch, b_spatial)
        # Return the schedule itself, not `[sch]`: the declared return type is
        # Optional[tir.Schedule], and sibling rules return the bare schedule.
        return sch
+
+
+def main():
+ from tvm.dlight.testing import mod_decode, mod_decode_gemv, mod_norm,
mod_softmax
+
+ gen = ScheduleGenerator(
+ rules=[
+ DecodeGemv(),
+ Decode(),
+ Normalization(),
+ Softmax(),
+ ]
+ )
+
+ for py_mod in [
+ mod_decode,
+ mod_decode_gemv,
+ mod_norm,
+ mod_softmax, # Needs to upstream `compute-inline`
+ ]:
+ i = 1
+ while True:
+ try:
+ func = py_mod.Module[f"func{i}"]
+ except: # pylint: disable=bare-except
+ break
+ else:
+ print(f"Working on {py_mod}::func{i}")
+ i += 1
+ mod = IRModule.from_expr(func.with_attr("global_symbol", "main"))
+ schedules = gen.generate_design_space(mod)
Review Comment:
Let us structure the unit tests the same way as a Before/After test — e.g. one softmax function before and after scheduling, compared with a structural-equality check. Most of the dlight testing can likely move to `tests/dlight/` (which can then be added to CI).
At the moment we only test that the schedule exists.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]