tqchen commented on code in PR #15141: URL: https://github.com/apache/tvm/pull/15141#discussion_r1240021306
########## tests/python/dlight/test_modules.py: ########## @@ -0,0 +1,2621 @@ +from tvm.script import ir as I +from tvm.script import tir as T + +# fmt: off + [email protected]_module +class Decode: + + @T.prim_func + def func1(A: T.Buffer((T.int64(320), T.int64(2560)), "uint32"), B: T.Buffer((T.int64(80), T.int64(2560)), "float16"), T_transpose: T.Buffer((T.int64(2560), T.int64(2560)), "float16")): + T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)}) + # with T.block("root"): + decode = T.alloc_buffer((T.int64(2560), T.int64(2560)), "float16") + for i, j in T.grid(T.int64(2560), T.int64(2560)): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(A[v_i // T.int64(8), v_j], B[v_i // T.int64(32), v_j]) + T.writes(decode[v_i, v_j]) + decode[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(A[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8)) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * B[v_i // T.int64(32), v_j] + for ax0, ax1 in T.grid(T.int64(2560), T.int64(2560)): + with T.block("T_transpose"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(decode[v_ax1, v_ax0]) + T.writes(T_transpose[v_ax0, v_ax1]) + T_transpose[v_ax0, v_ax1] = decode[v_ax1, v_ax0] + + @T.prim_func + def func2(A: T.Buffer((T.int64(320), T.int64(10240)), "uint32"), B: T.Buffer((T.int64(80), T.int64(10240)), "float16"), T_transpose: T.Buffer((T.int64(10240), T.int64(2560)), "float16")): + T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)}) + # with T.block("root"): + decode = T.alloc_buffer((T.int64(2560), T.int64(10240)), "float16") + for i, j in T.grid(T.int64(2560), T.int64(10240)): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(A[v_i // T.int64(8), v_j], B[v_i // T.int64(32), v_j]) + T.writes(decode[v_i, v_j]) + decode[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(A[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8)) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * B[v_i // 
T.int64(32), v_j] + for ax0, ax1 in T.grid(T.int64(10240), T.int64(2560)): + with T.block("T_transpose"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(decode[v_ax1, v_ax0]) + T.writes(T_transpose[v_ax0, v_ax1]) + T_transpose[v_ax0, v_ax1] = decode[v_ax1, v_ax0] + + @T.prim_func + def func3(A: T.Buffer((T.int64(1280), T.int64(2560)), "uint32"), B: T.Buffer((T.int64(320), T.int64(2560)), "float16"), T_transpose: T.Buffer((T.int64(2560), T.int64(10240)), "float16")): + T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)}) + # with T.block("root"): + decode = T.alloc_buffer((T.int64(10240), T.int64(2560)), "float16") + for i, j in T.grid(T.int64(10240), T.int64(2560)): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(A[v_i // T.int64(8), v_j], B[v_i // T.int64(32), v_j]) + T.writes(decode[v_i, v_j]) + decode[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(A[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8)) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * B[v_i // T.int64(32), v_j] + for ax0, ax1 in T.grid(T.int64(2560), T.int64(10240)): + with T.block("T_transpose"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(decode[v_ax1, v_ax0]) + T.writes(T_transpose[v_ax0, v_ax1]) + T_transpose[v_ax0, v_ax1] = decode[v_ax1, v_ax0] + + @T.prim_func + def func4(A: T.Buffer((T.int64(512), T.int64(4096)), "uint32"), B: T.Buffer((T.int64(128), T.int64(4096)), "float16"), T_transpose: T.Buffer((T.int64(4096), T.int64(4096)), "float16")): + T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)}) + # with T.block("root"): + decode = T.alloc_buffer((T.int64(4096), T.int64(4096)), "float16") + for i, j in T.grid(T.int64(4096), T.int64(4096)): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(A[v_i // T.int64(8), v_j], B[v_i // T.int64(32), v_j]) + T.writes(decode[v_i, v_j]) + decode[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(A[v_i // T.int64(8), v_j], 
T.Cast("uint32", v_i % T.int64(8)) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * B[v_i // T.int64(32), v_j] + for ax0, ax1 in T.grid(T.int64(4096), T.int64(4096)): + with T.block("T_transpose"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(decode[v_ax1, v_ax0]) + T.writes(T_transpose[v_ax0, v_ax1]) + T_transpose[v_ax0, v_ax1] = decode[v_ax1, v_ax0] + + @T.prim_func + def func5(A: T.Buffer((T.int64(512), T.int64(11008)), "uint32"), B: T.Buffer((T.int64(128), T.int64(11008)), "float16"), T_transpose: T.Buffer((T.int64(11008), T.int64(4096)), "float16")): + T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)}) + # with T.block("root"): + decode = T.alloc_buffer((T.int64(4096), T.int64(11008)), "float16") + for i, j in T.grid(T.int64(4096), T.int64(11008)): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(A[v_i // T.int64(8), v_j], B[v_i // T.int64(32), v_j]) + T.writes(decode[v_i, v_j]) + decode[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(A[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8)) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * B[v_i // T.int64(32), v_j] + for ax0, ax1 in T.grid(T.int64(11008), T.int64(4096)): + with T.block("T_transpose"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(decode[v_ax1, v_ax0]) + T.writes(T_transpose[v_ax0, v_ax1]) + T_transpose[v_ax0, v_ax1] = decode[v_ax1, v_ax0] + + @T.prim_func + def func6(A: T.Buffer((T.int64(1376), T.int64(4096)), "uint32"), B: T.Buffer((T.int64(344), T.int64(4096)), "float16"), T_transpose: T.Buffer((T.int64(4096), T.int64(11008)), "float16")): + T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)}) + # with T.block("root"): + decode = T.alloc_buffer((T.int64(11008), T.int64(4096)), "float16") + for i, j in T.grid(T.int64(11008), T.int64(4096)): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(A[v_i // T.int64(8), v_j], B[v_i // T.int64(32), v_j]) + T.writes(decode[v_i, v_j]) + 
decode[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(A[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8)) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * B[v_i // T.int64(32), v_j] + for ax0, ax1 in T.grid(T.int64(4096), T.int64(11008)): + with T.block("T_transpose"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(decode[v_ax1, v_ax0]) + T.writes(T_transpose[v_ax0, v_ax1]) + T_transpose[v_ax0, v_ax1] = decode[v_ax1, v_ax0] + + [email protected]_module +class DecodeGemv: + + @T.prim_func + def func1(lv2515: T.Buffer((T.int64(320), T.int64(50432)), "uint32"), lv2516: T.Buffer((T.int64(80), T.int64(50432)), "float32"), lv705: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float32"), var_matmul_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(50432)), "float32")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + var_decode_intermediate = T.alloc_buffer((T.int64(2560), T.int64(50432))) + for i, j in T.grid(T.int64(2560), T.int64(50432)): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(lv2515[v_i // T.int64(8), v_j], lv2516[v_i // T.int64(32), v_j]) + T.writes(var_decode_intermediate[v_i, v_j]) + var_decode_intermediate[v_i, v_j] = T.Cast("float32", T.Cast("float16", T.bitwise_and(T.shift_right(lv2515[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8)) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv2516[v_i // T.int64(32), v_j] + for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(50432), T.int64(2560)): + with T.block("matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv705[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2]) + T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0) + var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv705[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2] + + @T.prim_func + 
def func2(lv1363: T.Buffer((T.int64(320), T.int64(2560)), "uint32"), lv1364: T.Buffer((T.int64(80), T.int64(2560)), "float16"), lv2067: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float16"), linear_bias192: T.Buffer((T.int64(2560),), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float16")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + var_decode_intermediate = T.alloc_buffer((T.int64(2560), T.int64(2560)), "float16") + var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(2560)), "float16") + for i, j in T.grid(T.int64(2560), T.int64(2560)): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(lv1363[v_i // T.int64(8), v_j], lv1364[v_i // T.int64(32), v_j]) + T.writes(var_decode_intermediate[v_i, v_j]) + var_decode_intermediate[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv1363[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8)) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv1364[v_i // T.int64(32), v_j] + for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(2560), T.int64(2560)): + with T.block("matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv2067[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2]) + T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) + var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv2067[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2] + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)): + with T.block("T_add"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(var_matmul_intermediate[v_ax0, v_ax1, v_ax2], linear_bias192[v_ax2]) + T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2]) + p_output0_intermediate[v_ax0, v_ax1, v_ax2] = var_matmul_intermediate[v_ax0, v_ax1, v_ax2] + linear_bias192[v_ax2] 
+ + @T.prim_func + def func3(lv1381: T.Buffer((T.int64(320), T.int64(2560)), "uint32"), lv1382: T.Buffer((T.int64(80), T.int64(2560)), "float16"), lv328: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float16"), linear_bias195: T.Buffer((T.int64(2560),), "float16"), lv2062: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float16")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + var_decode_intermediate = T.alloc_buffer((T.int64(2560), T.int64(2560)), "float16") + var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(2560)), "float16") + var_T_add_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(2560)), "float16") + for i, j in T.grid(T.int64(2560), T.int64(2560)): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(lv1381[v_i // T.int64(8), v_j], lv1382[v_i // T.int64(32), v_j]) + T.writes(var_decode_intermediate[v_i, v_j]) + var_decode_intermediate[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv1381[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8)) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv1382[v_i // T.int64(32), v_j] + for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(2560), T.int64(2560)): + with T.block("matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv328[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2]) + T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) + var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv328[v_i0, v_i1, v_k] * var_decode_intermediate[v_k, v_i2] + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)): + with T.block("T_add"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(var_matmul_intermediate[v_ax0, v_ax1, v_ax2], 
linear_bias195[v_ax2]) + T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) + var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = var_matmul_intermediate[v_ax0, v_ax1, v_ax2] + linear_bias195[v_ax2] + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)): + with T.block("T_add_1"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(var_T_add_intermediate[v_ax0, v_ax1, v_ax2], lv2062[v_ax0, v_ax1, v_ax2]) + T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2]) + p_output0_intermediate[v_ax0, v_ax1, v_ax2] = var_T_add_intermediate[v_ax0, v_ax1, v_ax2] + lv2062[v_ax0, v_ax1, v_ax2] + + @T.prim_func + def func4(lv1387: T.Buffer((T.int64(320), T.int64(10240)), "uint32"), lv1388: T.Buffer((T.int64(80), T.int64(10240)), "float16"), lv2115: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float16"), linear_bias196: T.Buffer((T.int64(10240),), "float32"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(10240)), "float16")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + var_decode_intermediate = T.alloc_buffer((T.int64(2560), T.int64(10240)), "float16") + var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(10240))) + var_T_add_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(10240))) + T_multiply = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(10240))) + compute = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(10240))) + T_multiply_1 = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(10240))) + T_add = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(10240))) + var_T_multiply_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(10240))) + for i, j in T.grid(T.int64(2560), T.int64(10240)): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(lv1387[v_i // T.int64(8), v_j], lv1388[v_i // T.int64(32), v_j]) + T.writes(var_decode_intermediate[v_i, v_j]) + var_decode_intermediate[v_i, v_j] = (T.Cast("float16", 
T.bitwise_and(T.shift_right(lv1387[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8)) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv1388[v_i // T.int64(32), v_j] + for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(10240), T.int64(2560)): + with T.block("matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv2115[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2]) + T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0) + var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + T.Cast("float32", lv2115[v_i0, v_i1, v_k]) * T.Cast("float32", var_decode_intermediate[v_k, v_i2]) + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(10240)): + with T.block("T_add"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(var_matmul_intermediate[v_ax0, v_ax1, v_ax2], linear_bias196[v_ax2]) + T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) + var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = var_matmul_intermediate[v_ax0, v_ax1, v_ax2] + linear_bias196[v_ax2] + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(10240)): + with T.block("T_multiply"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) + T.writes(T_multiply[v_ax0, v_ax1, v_ax2]) + T_multiply[v_ax0, v_ax1, v_ax2] = var_T_add_intermediate[v_ax0, v_ax1, v_ax2] * T.float32(0.70710678118654757) + for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(10240)): + with T.block("compute"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(T_multiply[v_i0, v_i1, v_i2]) + T.writes(compute[v_i0, v_i1, v_i2]) + compute[v_i0, v_i1, v_i2] = T.erf(T_multiply[v_i0, v_i1, v_i2]) + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(10240)): + with T.block("T_multiply_1"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + 
T.reads(compute[v_ax0, v_ax1, v_ax2]) + T.writes(T_multiply_1[v_ax0, v_ax1, v_ax2]) + T_multiply_1[v_ax0, v_ax1, v_ax2] = compute[v_ax0, v_ax1, v_ax2] * T.float32(0.5) + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(10240)): + with T.block("T_add_1"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(T_multiply_1[v_ax0, v_ax1, v_ax2]) + T.writes(T_add[v_ax0, v_ax1, v_ax2]) + T_add[v_ax0, v_ax1, v_ax2] = T.float32(0.5) + T_multiply_1[v_ax0, v_ax1, v_ax2] + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(10240)): + with T.block("T_multiply_2"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(var_T_add_intermediate[v_ax0, v_ax1, v_ax2], T_add[v_ax0, v_ax1, v_ax2]) + T.writes(var_T_multiply_intermediate[v_ax0, v_ax1, v_ax2]) + var_T_multiply_intermediate[v_ax0, v_ax1, v_ax2] = var_T_add_intermediate[v_ax0, v_ax1, v_ax2] * T_add[v_ax0, v_ax1, v_ax2] + for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(10240)): + with T.block("compute_1"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(var_T_multiply_intermediate[v_i0, v_i1, v_i2]) + T.writes(p_output0_intermediate[v_i0, v_i1, v_i2]) + p_output0_intermediate[v_i0, v_i1, v_i2] = T.Cast("float16", var_T_multiply_intermediate[v_i0, v_i1, v_i2]) + + @T.prim_func + def func5(lv1393: T.Buffer((T.int64(1280), T.int64(2560)), "uint32"), lv1394: T.Buffer((T.int64(320), T.int64(2560)), "float16"), lv2121: T.Buffer((T.int64(1), T.int64(1), T.int64(10240)), "float16"), linear_bias197: T.Buffer((T.int64(2560),), "float32"), lv329: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float16")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + var_decode_intermediate = T.alloc_buffer((T.int64(10240), T.int64(2560)), "float16") + var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(2560))) + var_T_add_intermediate = 
T.alloc_buffer((T.int64(1), T.int64(1), T.int64(2560))) + var_compute_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(2560)), "float16") + var_compute_intermediate_1 = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(2560)), "float16") + for i, j in T.grid(T.int64(10240), T.int64(2560)): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(lv1393[v_i // T.int64(8), v_j], lv1394[v_i // T.int64(32), v_j]) + T.writes(var_decode_intermediate[v_i, v_j]) + var_decode_intermediate[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv1393[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8)) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv1394[v_i // T.int64(32), v_j] + for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(2560), T.int64(10240)): + with T.block("matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv2121[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2]) + T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0) + var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + T.Cast("float32", lv2121[v_i0, v_i1, v_k]) * T.Cast("float32", var_decode_intermediate[v_k, v_i2]) + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)): + with T.block("T_add"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(var_matmul_intermediate[v_ax0, v_ax1, v_ax2], linear_bias197[v_ax2]) + T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) + var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = var_matmul_intermediate[v_ax0, v_ax1, v_ax2] + linear_bias197[v_ax2] + for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)): + with T.block("compute"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(var_T_add_intermediate[v_i0, v_i1, v_i2]) + T.writes(var_compute_intermediate[v_i0, v_i1, v_i2]) + var_compute_intermediate[v_i0, v_i1, v_i2] = 
T.Cast("float16", var_T_add_intermediate[v_i0, v_i1, v_i2]) + for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)): + with T.block("compute_1"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(var_compute_intermediate[v_i0, v_i1, v_i2]) + T.writes(var_compute_intermediate_1[v_i0, v_i1, v_i2]) + var_compute_intermediate_1[v_i0, v_i1, v_i2] = var_compute_intermediate[v_i0, v_i1, v_i2] + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)): + with T.block("T_add_1"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(var_compute_intermediate_1[v_ax0, v_ax1, v_ax2], lv329[v_ax0, v_ax1, v_ax2]) + T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2]) + p_output0_intermediate[v_ax0, v_ax1, v_ax2] = var_compute_intermediate_1[v_ax0, v_ax1, v_ax2] + lv329[v_ax0, v_ax1, v_ax2] + + @T.prim_func + def func6(lv2509: T.Buffer((T.int64(1280), T.int64(2560)), "uint32"), lv2510: T.Buffer((T.int64(320), T.int64(2560)), "float16"), lv4105: T.Buffer((T.int64(1), T.int64(1), T.int64(10240)), "float16"), linear_bias383: T.Buffer((T.int64(2560),), "float32"), lv701: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float32")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + var_decode_intermediate = T.alloc_buffer((T.int64(10240), T.int64(2560)), "float16") + var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(2560))) + var_T_add_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(2560))) + var_compute_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(2560)), "float16") + var_compute_intermediate_1 = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(2560)), "float16") + var_T_add_intermediate_1 = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(2560)), "float16") + for i, j in T.grid(T.int64(10240), T.int64(2560)): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", 
[i, j]) + T.reads(lv2509[v_i // T.int64(8), v_j], lv2510[v_i // T.int64(32), v_j]) + T.writes(var_decode_intermediate[v_i, v_j]) + var_decode_intermediate[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv2509[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8)) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv2510[v_i // T.int64(32), v_j] + for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(2560), T.int64(10240)): + with T.block("matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv4105[v_i0, v_i1, v_k], var_decode_intermediate[v_k, v_i2]) + T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0) + var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + T.Cast("float32", lv4105[v_i0, v_i1, v_k]) * T.Cast("float32", var_decode_intermediate[v_k, v_i2]) + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)): + with T.block("T_add"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(var_matmul_intermediate[v_ax0, v_ax1, v_ax2], linear_bias383[v_ax2]) + T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2]) + var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = var_matmul_intermediate[v_ax0, v_ax1, v_ax2] + linear_bias383[v_ax2] + for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)): + with T.block("compute"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(var_T_add_intermediate[v_i0, v_i1, v_i2]) + T.writes(var_compute_intermediate[v_i0, v_i1, v_i2]) + var_compute_intermediate[v_i0, v_i1, v_i2] = T.Cast("float16", var_T_add_intermediate[v_i0, v_i1, v_i2]) + for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)): + with T.block("compute_1"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(var_compute_intermediate[v_i0, v_i1, v_i2]) + T.writes(var_compute_intermediate_1[v_i0, v_i1, v_i2]) + var_compute_intermediate_1[v_i0, 
v_i1, v_i2] = var_compute_intermediate[v_i0, v_i1, v_i2] + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)): + with T.block("T_add_1"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(var_compute_intermediate_1[v_ax0, v_ax1, v_ax2], lv701[v_ax0, v_ax1, v_ax2]) + T.writes(var_T_add_intermediate_1[v_ax0, v_ax1, v_ax2]) + var_T_add_intermediate_1[v_ax0, v_ax1, v_ax2] = var_compute_intermediate_1[v_ax0, v_ax1, v_ax2] + lv701[v_ax0, v_ax1, v_ax2] + for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)): + with T.block("compute_2"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) Review Comment: It would be good to delete this for now; as we add new test cases like `test_normalize.py`, we can add these modules as before/after pairs in each test file. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
