Hello, I am trying to fuse one layer's convolution computation, together with
its ReLU, into the next layer's convolution computation. I tried two methods:
one was to use a te.sum expression as an argument of another te.sum, and the
other was to use s.compute_inline(). Both failed. Is it possible to combine two
reduce stages (te.sum) into a single reduce stage in TE? If not, can Relay or
TIR express this kind of fusion? Here is the current TIR without fusion:
primfn(args: handle, arg_type_ids: handle, num_args: int32,
out_ret_value: handle, out_ret_tcode: handle, resource_handle: handle) -> int32
attr = {"target": meta[Target][0], "tir.noalias": True, "global_symbol":
"myfunc_fusion", "from_legacy_te_schedule": True, "tir.is_entry_func": True,
"calling_conv": 1} {
assert((num_args == 4), "myfunc_fusion: num_args should be 4")
let arg0: handle = @tir.tvm_struct_get(args, 0, 12, dtype=handle)
let arg0.code: int32 = (int32*)arg_type_ids[0]
let arg1: handle = @tir.tvm_struct_get(args, 1, 12, dtype=handle)
let arg1.code: int32 = (int32*)arg_type_ids[1]
let arg2: handle = @tir.tvm_struct_get(args, 2, 12, dtype=handle)
let arg2.code: int32 = (int32*)arg_type_ids[2]
let arg3: handle = @tir.tvm_struct_get(args, 3, 12, dtype=handle)
let arg3.code: int32 = (int32*)arg_type_ids[3]
let A: Pointer(float32) = @tir.tvm_struct_get(arg0, 0, 1, dtype=handle)
attr [A] "storage_alignment" = 128;
let arg0.shape: handle = @tir.tvm_struct_get(arg0, 0, 2, dtype=handle)
let arg0.strides: handle = @tir.tvm_struct_get(arg0, 0, 3, dtype=handle)
let dev_id: int32 = @tir.tvm_struct_get(arg0, 0, 9, dtype=int32)
let W: Pointer(float32) = @tir.tvm_struct_get(arg1, 0, 1, dtype=handle)
attr [W] "storage_alignment" = 128;
let arg1.shape: handle = @tir.tvm_struct_get(arg1, 0, 2, dtype=handle)
let arg1.strides: handle = @tir.tvm_struct_get(arg1, 0, 3, dtype=handle)
let W_2: Pointer(float32) = @tir.tvm_struct_get(arg2, 0, 1, dtype=handle)
attr [W_2] "storage_alignment" = 128;
let arg2.shape: handle = @tir.tvm_struct_get(arg2, 0, 2, dtype=handle)
let arg2.strides: handle = @tir.tvm_struct_get(arg2, 0, 3, dtype=handle)
let C: Pointer(float32) = @tir.tvm_struct_get(arg3, 0, 1, dtype=handle)
attr [C] "storage_alignment" = 128;
let arg3.shape: handle = @tir.tvm_struct_get(arg3, 0, 2, dtype=handle)
let arg3.strides: handle = @tir.tvm_struct_get(arg3, 0, 3, dtype=handle)
assert(((((arg0.code == 3) || (arg0.code == 13)) || (arg0.code == 7)) ||
(arg0.code == 4)), "myfunc_fusion: Expect arg[0] to be pointer")
assert(((((arg1.code == 3) || (arg1.code == 13)) || (arg1.code == 7)) ||
(arg1.code == 4)), "myfunc_fusion: Expect arg[1] to be pointer")
assert(((((arg2.code == 3) || (arg2.code == 13)) || (arg2.code == 7)) ||
(arg2.code == 4)), "myfunc_fusion: Expect arg[2] to be pointer")
assert(((((arg3.code == 3) || (arg3.code == 13)) || (arg3.code == 7)) ||
(arg3.code == 4)), "myfunc_fusion: Expect arg[3] to be pointer")
assert((4 == @tir.tvm_struct_get(arg0, 0, 4, dtype=int32)), "arg0.ndim is
expected to equal 4")
assert((4 == @tir.tvm_struct_get(arg0, 0, 4, dtype=int32)), "arg0.ndim is
expected to equal 4")
assert((((@tir.tvm_struct_get(arg0, 0, 5, dtype=uint8) == 2u8) &&
(@tir.tvm_struct_get(arg0, 0, 6, dtype=uint8) == 32u8)) &&
(@tir.tvm_struct_get(arg0, 0, 7, dtype=uint16) == 1u16)), "arg0.dtype is
expected to be float32")
assert((56 == cast(int32, (int64*)arg0.shape[0])), "Argument
arg0.shape[0] has an unsatisfied constraint: (56 == int32(arg0.shape[0]))")
assert((56 == cast(int32, (int64*)arg0.shape[1])), "Argument
arg0.shape[1] has an unsatisfied constraint: (56 == int32(arg0.shape[1]))")
assert((64 == cast(int32, (int64*)arg0.shape[2])), "Argument
arg0.shape[2] has an unsatisfied constraint: (64 == int32(arg0.shape[2]))")
assert((3 == cast(int32, (int64*)arg0.shape[3])), "Argument arg0.shape[3]
has an unsatisfied constraint: (3 == int32(arg0.shape[3]))")
{
if !@tir.isnullptr(arg0.strides, dtype=bool) {
assert(((((1 == cast(int32, (int64*)arg0.strides[3])) && (3 ==
cast(int32, (int64*)arg0.strides[2]))) && (192 == cast(int32,
(int64*)arg0.strides[1]))) && (10752 == cast(int32, (int64*)arg0.strides[0]))),
"arg0.strides: expected to be compact array")
0
}
assert((0u64 == @tir.tvm_struct_get(arg0, 0, 8, dtype=uint64)),
"Argument arg0.byte_offset has an unsatisfied constraint: ((uint64)0 ==
tir.tvm_struct_get(arg0, 0, 8))")
assert((1 == @tir.tvm_struct_get(arg0, 0, 10, dtype=int32)), "Argument
arg0.device_type has an unsatisfied constraint: (1 == tir.tvm_struct_get(arg0,
0, 10))")
assert((4 == @tir.tvm_struct_get(arg1, 0, 4, dtype=int32)), "arg1.ndim
is expected to equal 4")
assert((4 == @tir.tvm_struct_get(arg1, 0, 4, dtype=int32)), "arg1.ndim
is expected to equal 4")
assert((((@tir.tvm_struct_get(arg1, 0, 5, dtype=uint8) == 2u8) &&
(@tir.tvm_struct_get(arg1, 0, 6, dtype=uint8) == 32u8)) &&
(@tir.tvm_struct_get(arg1, 0, 7, dtype=uint16) == 1u16)), "arg1.dtype is
expected to be float32")
assert((3 == cast(int32, (int64*)arg1.shape[0])), "Argument
arg1.shape[0] has an unsatisfied constraint: (3 == int32(arg1.shape[0]))")
assert((3 == cast(int32, (int64*)arg1.shape[1])), "Argument
arg1.shape[1] has an unsatisfied constraint: (3 == int32(arg1.shape[1]))")
assert((64 == cast(int32, (int64*)arg1.shape[2])), "Argument
arg1.shape[2] has an unsatisfied constraint: (64 == int32(arg1.shape[2]))")
assert((64 == cast(int32, (int64*)arg1.shape[3])), "Argument
arg1.shape[3] has an unsatisfied constraint: (64 == int32(arg1.shape[3]))")
{
if !@tir.isnullptr(arg1.strides, dtype=bool) {
assert(((((1 == cast(int32, (int64*)arg1.strides[3])) && (64 ==
cast(int32, (int64*)arg1.strides[2]))) && (4096 == cast(int32,
(int64*)arg1.strides[1]))) && (12288 == cast(int32, (int64*)arg1.strides[0]))),
"arg1.strides: expected to be compact array")
0
}
assert((0u64 == @tir.tvm_struct_get(arg1, 0, 8, dtype=uint64)),
"Argument arg1.byte_offset has an unsatisfied constraint: ((uint64)0 ==
tir.tvm_struct_get(arg1, 0, 8))")
assert((1 == @tir.tvm_struct_get(arg1, 0, 10, dtype=int32)),
"Argument arg1.device_type has an unsatisfied constraint: (1 ==
tir.tvm_struct_get(arg1, 0, 10))")
assert((dev_id == @tir.tvm_struct_get(arg1, 0, 9, dtype=int32)),
"Argument arg1.device_id has an unsatisfied constraint: (dev_id ==
tir.tvm_struct_get(arg1, 0, 9))")
assert((4 == @tir.tvm_struct_get(arg2, 0, 4, dtype=int32)),
"arg2.ndim is expected to equal 4")
assert((4 == @tir.tvm_struct_get(arg2, 0, 4, dtype=int32)),
"arg2.ndim is expected to equal 4")
assert((((@tir.tvm_struct_get(arg2, 0, 5, dtype=uint8) == 2u8) &&
(@tir.tvm_struct_get(arg2, 0, 6, dtype=uint8) == 32u8)) &&
(@tir.tvm_struct_get(arg2, 0, 7, dtype=uint16) == 1u16)), "arg2.dtype is
expected to be float32")
assert((3 == cast(int32, (int64*)arg2.shape[0])), "Argument
arg2.shape[0] has an unsatisfied constraint: (3 == int32(arg2.shape[0]))")
assert((3 == cast(int32, (int64*)arg2.shape[1])), "Argument
arg2.shape[1] has an unsatisfied constraint: (3 == int32(arg2.shape[1]))")
assert((64 == cast(int32, (int64*)arg2.shape[2])), "Argument
arg2.shape[2] has an unsatisfied constraint: (64 == int32(arg2.shape[2]))")
assert((64 == cast(int32, (int64*)arg2.shape[3])), "Argument
arg2.shape[3] has an unsatisfied constraint: (64 == int32(arg2.shape[3]))")
{
if !@tir.isnullptr(arg2.strides, dtype=bool) {
assert(((((1 == cast(int32, (int64*)arg2.strides[3])) && (64 ==
cast(int32, (int64*)arg2.strides[2]))) && (4096 == cast(int32,
(int64*)arg2.strides[1]))) && (12288 == cast(int32, (int64*)arg2.strides[0]))),
"arg2.strides: expected to be compact array")
0
}
assert((0u64 == @tir.tvm_struct_get(arg2, 0, 8, dtype=uint64)),
"Argument arg2.byte_offset has an unsatisfied constraint: ((uint64)0 ==
tir.tvm_struct_get(arg2, 0, 8))")
assert((1 == @tir.tvm_struct_get(arg2, 0, 10, dtype=int32)),
"Argument arg2.device_type has an unsatisfied constraint: (1 ==
tir.tvm_struct_get(arg2, 0, 10))")
assert((dev_id == @tir.tvm_struct_get(arg2, 0, 9, dtype=int32)),
"Argument arg2.device_id has an unsatisfied constraint: (dev_id ==
tir.tvm_struct_get(arg2, 0, 9))")
assert((4 == @tir.tvm_struct_get(arg3, 0, 4, dtype=int32)),
"arg3.ndim is expected to equal 4")
assert((4 == @tir.tvm_struct_get(arg3, 0, 4, dtype=int32)),
"arg3.ndim is expected to equal 4")
assert((((@tir.tvm_struct_get(arg3, 0, 5, dtype=uint8) == 2u8) &&
(@tir.tvm_struct_get(arg3, 0, 6, dtype=uint8) == 32u8)) &&
(@tir.tvm_struct_get(arg3, 0, 7, dtype=uint16) == 1u16)), "arg3.dtype is
expected to be float32")
assert((54 == cast(int32, (int64*)arg3.shape[0])), "Argument
arg3.shape[0] has an unsatisfied constraint: (54 == int32(arg3.shape[0]))")
assert((54 == cast(int32, (int64*)arg3.shape[1])), "Argument
arg3.shape[1] has an unsatisfied constraint: (54 == int32(arg3.shape[1]))")
assert((64 == cast(int32, (int64*)arg3.shape[2])), "Argument
arg3.shape[2] has an unsatisfied constraint: (64 == int32(arg3.shape[2]))")
assert((3 == cast(int32, (int64*)arg3.shape[3])), "Argument
arg3.shape[3] has an unsatisfied constraint: (3 == int32(arg3.shape[3]))")
{
if !@tir.isnullptr(arg3.strides, dtype=bool) {
assert(((((1 == cast(int32, (int64*)arg3.strides[3])) && (3 ==
cast(int32, (int64*)arg3.strides[2]))) && (192 == cast(int32,
(int64*)arg3.strides[1]))) && (10368 == cast(int32, (int64*)arg3.strides[0]))),
"arg3.strides: expected to be compact array")
0
}
assert((0u64 == @tir.tvm_struct_get(arg3, 0, 8, dtype=uint64)),
"Argument arg3.byte_offset has an unsatisfied constraint: ((uint64)0 ==
tir.tvm_struct_get(arg3, 0, 8))")
assert((1 == @tir.tvm_struct_get(arg3, 0, 10, dtype=int32)),
"Argument arg3.device_type has an unsatisfied constraint: (1 ==
tir.tvm_struct_get(arg3, 0, 10))")
assert((dev_id == @tir.tvm_struct_get(arg3, 0, 9, dtype=int32)),
"Argument arg3.device_id has an unsatisfied constraint: (dev_id ==
tir.tvm_struct_get(arg3, 0, 9))")
attr [0] "compute_scope" = "myfunc_fusion_compute_";
attr [R: Pointer(global float32)] "storage_alignment" = 128 {
let R = @tir.TVMBackendAllocWorkspace(1, dev_id, 2239488u64, 2,
32, dtype=handle)
{
if @tir.isnullptr(R, dtype=bool) {
@tir.tvm_throw_last_error(, dtype=int32)
}
allocate(B: Pointer(global float32), float32, [1]),
storage_scope = global {
for (yy: int32, 0, 54) {
for (xx: int32, 0, 54) {
for (cc: int32, 0, 64) {
for (batch: int32, 0, 3) {
B[0] = 0f32
for (ry: int32, 0, 3) {
for (rx: int32, 0, 3) {
for (rc: int32, 0, 64) {
B[0] = @tir.call_llvm_pure_intrin(134u32,
3u32, (float32*)A[((((((yy*10752) + (ry*10752)) + (xx*192)) + (rx*192)) +
(rc*3)) + batch)], (float32*)W[((((ry*12288) + (rx*4096)) + (rc*64)) + cc)],
(float32*)B[0], dtype=float32)
}
}
}
R[((((yy*10368) + (xx*192)) + (cc*3)) + batch)] =
max(0f32, (float32*)B[0])
}
}
}
}
for (yy_1: int32, 0, 54) {
for (xx_1: int32, 0, 54) {
for (ff: int32, 0, 64) {
for (nn: int32, 0, 3) {
C[((((yy_1*10368) + (xx_1*192)) + (ff*3)) + nn)] =
0f32
for (ry_2: int32, 0, 3) {
for (rx_2: int32, 0, 3) {
for (rc_2: int32, 0, 64) {
C[((((yy_1*10368) + (xx_1*192)) + (ff*3)) +
nn)] = @tir.call_llvm_pure_intrin(134u32, 3u32, (float32*)R[((((((yy_1*10368) +
(ry_2*10368)) + (xx_1*192)) + (rx_2*192)) + (rc_2*3)) + nn)],
(float32*)W_2[((((ry_2*12288) + (rx_2*4096)) + (rc_2*64)) + ff)],
(float32*)C[((((yy_1*10368) + (xx_1*192)) + (ff*3)) + nn)], dtype=float32)
}
}
}
}
}
}
}
}
}
if (@tir.TVMBackendFreeWorkspace(1, dev_id, R, dtype=int32) !=
0) {
@tir.tvm_throw_last_error(, dtype=int32)
}
}
}
}
}
}
}
---
[Visit
Topic](https://discuss.tvm.apache.org/t/can-one-reduce-stage-fuse-into-another-reduce-stage/12367/1)
to respond.
You are receiving this because you enabled mailing list mode.
To unsubscribe from these emails, [click
here](https://discuss.tvm.apache.org/email/unsubscribe/75e39fd21722d84c7a39212d0f263c73b6db84798a8e4a1929f107fc6ce53cb1).