Thanks for your comments:)
[quote="areusch, post:3, topic:11807"]
could you say more here? is this a Relay-level thing or a TIR thing? presuming
you’ve implemented this as a pass, how do you plan to ensure that the
Relay-level pass makes the same scheduling decision as the TIR pass?
[/quote]
Perhaps I can use a toy Conv2d example to describe it:
```
fn (%arg0: Tensor[(1, 32, 224, 224), int8], %nn.conv2d_arg: Tensor[(32, 3, 7, 7), int8]) {
  %conv_fn = fn (%data: Tensor[(1, 3, 224, 224), int8], %weight: Tensor[(32, 3, 7, 7), int8], Primitive=1) {
    nn.conv2d(%data, %weight, padding=[1, 1, 1, 1], kernel_size=[7, 7], out_dtype="int32")
  };
  %conv_fn(%arg0, %nn.conv2d_arg)
}
```
and the corresponding PrimFunc for the primitive call `%conv_fn` would look like:
```python
@T.prim_func
def main(x: T.Buffer[...], weight: T.Buffer[(32, 3, 7, 7), "int8"], y: T.Buffer[...]) -> None:
    # body
```
Assume that, in order to utilize the specific hardware, we want to arrange the input/output channels into 4*4 tiles. Two extra notes:
- We only get to know the "best" weight layout after a TIR schedule/tuning is done.
- The required layout is out of the scope of common representations like "OIHW", "OHWI", etc.

The TIR schedule part would perform the following transformations on `weight`:
```python
o, i, h, w = s.get_read_buffer_axes(conv_block)
o_outer, o_inner = s.buffer_split(o, factor=4)  # [32, 3, 7, 7] -> [8, 4, 3, 7, 7]
i_outer, i_inner = s.buffer_split(i, factor=4)  # [8, 4, 3, 7, 7] -> [8, 4, 1, 4, 7, 7]
s.buffer_reorder(o_outer, o_inner, i_outer, i_inner, h, w)  # [8, 4, 1, 4, 7, 7] -> [8, 1, 4, 4, 7, 7]
```
Above we use a set of extended TensorIR schedule primitives, but they can simply be seen as sugar over the in-progress `transform_layout` schedule primitive:
https://github.com/apache/tvm-rfcs/pull/39
The point is that they are not arbitrary index remappings (in contrast to a general `transform_layout`). We ensure that every such schedule step has an exactly equivalent Relay-level transformation.
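As a rough illustration only (my own sketch, not our implementation), the combined effect of the three `buffer_*` steps above could be expressed as a single `transform_layout` index map in the style of RFC #39; here I assume `weight` is the block's second read buffer and that padding the input-channel extent from 3 up to 4 is handled separately:
```python
# Illustrative sketch, not actual code from the backend. Padding of the
# input-channel dimension (3 -> 4) is assumed to be handled elsewhere.
s.transform_layout(
    conv_block,
    ("read", 1),  # assumed: `weight` is read buffer index 1 of the conv block
    lambda o, i, h, w: (o // 4, i // 4, o % 4, i % 4, h, w),  # -> [8, 1, 4, 4, 7, 7]
)
```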
In the TIR schedule phase, we trace every buffer layout change on function parameter buffers (we can do that since these primitives are ours), generate the Relay transform (and reverse transform) for each step, and finally compose them into single layout transform (and reverse transform) functions in Relay (a small illustrative sketch of the composition follows the list below). For the example above, the per-step transforms would be:
- `s.buffer_split(o, factor=4)`
  - x -> relay.reshape(x, [-1, 4, 3, 7, 7])
  - (reverse) x -> relay.reshape(x, [32, 3, 7, 7])
- `s.buffer_split(i, factor=4)`
  - x -> relay.reshape(relay.nn.pad(x, [..., (0, 1), ...]), [8, 4, -1, 4, 7, 7])
  - (reverse) x -> relay.strided_slice(relay.reshape(x, [8, 4, 4, 7, 7]), begin=..., end=...)
- `s.buffer_reorder(...)`
  - x -> relay.transpose(x, [...])
  - (reverse) x -> relay.transpose(x, [...])
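To make the tracing idea concrete, here is a purely illustrative sketch (the helper and the step lambdas are hypothetical, not our actual implementation) of how the per-step Relay rewrites for `weight` could be recorded and composed into the forward transform:
```python
from tvm import relay

def compose_transforms(param, steps):
    """Apply the recorded per-step rewrites (each an Expr -> Expr function) in order."""
    expr = param
    for step in steps:
        expr = step(expr)
    return expr

weight = relay.var("p1", shape=(32, 3, 7, 7), dtype="int8")
steps = [
    # s.buffer_split(o, factor=4): [32, 3, 7, 7] -> [8, 4, 3, 7, 7]
    lambda x: relay.reshape(x, [-1, 4, 3, 7, 7]),
    # s.buffer_split(i, factor=4): pad 3 -> 4, then [8, 4, 4, 7, 7] -> [8, 4, 1, 4, 7, 7]
    lambda x: relay.reshape(
        relay.nn.pad(x, [(0, 0), (0, 0), (0, 1), (0, 0), (0, 0)]),
        [8, 4, 1, 4, 7, 7],
    ),
    # s.buffer_reorder(...): [8, 4, 1, 4, 7, 7] -> [8, 1, 4, 4, 7, 7]
    lambda x: relay.transpose(x, [0, 2, 1, 3, 4, 5]),
]
composed = compose_transforms(weight, steps)
forward = relay.Function([weight], composed)  # analogous to the fn shown next
```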
Finally, all transforms (and reverse transforms) are composed into two `relay.Function` objects used to rewrite the Relay-level layouts; each accepts the original Relay params and returns the updated params as a tuple:
```
fn (%p0: Tensor[..., int8], %p1: Tensor[(32, 3, 7, 7), int8]) {
  %0 = reshape(%p1, newshape=[...]);
  %1 = nn.pad(%0, pad_width=[...]);
  %2 = reshape(%1, newshape=[...]);
  %3 = transpose(%2, axes=[...]);
  (%p0, %3)
}
```
and the reverse direction is:
```
fn (%p0: Tensor[..., int8], %p1: Tensor[(8, 4, 1, 4, 7, 7), int8]) {
  %0 = transpose(%p1, axes=[...]);
  %1 = reshape(%0, newshape=[...]);
  %2 = strided_slice(%1, begin=[...], end=[...], strides=[...]);
  %3 = reshape(%2, newshape=[32, 3, 7, 7]);
  (%p0, %3)
}
```
A Relay pass can now perform a "pre"-schedule for each primitive function, fetch the layout transform functions from the schedule result, and perform the Relay-level layout update (a minimal driver sketch follows the rewritten module below). Finally, an extra `FoldConstant` pass can typically eliminate all the extra transformations outside of the primitive calls.
```
fn (%arg0: Tensor[(1, 32, 224, 224), int8], %nn.conv2d_arg: Tensor[(32, 3, 7, 7), int8]) {
  %0 = reshape(%nn.conv2d_arg, newshape=[...]);
  %1 = nn.pad(%0, pad_width=[...]);
  %2 = reshape(%1, newshape=[...]);
  %3 = transpose(%2, axes=[...]);
  %conv_fn = fn (%data: Tensor[(1, 3, 224, 224), int8], %weight: Tensor[(8, 4, 1, 4, 7, 7), int8], Primitive=1, DevicePrimFuncKey=873487) {
    %4 = transpose(%weight, axes=[...]);
    %5 = reshape(%4, newshape=[...]);
    %6 = strided_slice(%5, begin=[...], end=[...], strides=[...]);
    %7 = reshape(%6, newshape=[32, 3, 7, 7]);
    nn.conv2d(%data, %7, padding=[1, 1, 1, 1], kernel_size=[7, 7], out_dtype="int32")
  };
  %conv_fn(%arg0, %3)
}
```
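To show how such a pass might be driven, here is a minimal sketch; the class name and the comments inside `transform_function` are my own placeholders (the actual pre-schedule and transform-fetching logic is omitted), only `function_pass` and `FoldConstant` are existing TVM APIs:
```python
from tvm import relay

@relay.transform.function_pass(opt_level=0)
class WeightLayoutRewrite:
    """Placeholder pass skeleton; the real rewrite logic is not shown."""
    def transform_function(self, func, mod, ctx):
        # 1. "pre"-schedule each primitive function in TIR
        # 2. fetch the composed forward/reverse layout transform functions
        # 3. apply the forward transform to the actual params at the call site
        #    and the reverse transform to the formal params in the primitive body
        return func  # placeholder

def rewrite_weight_layouts(mod):
    mod = WeightLayoutRewrite()(mod)
    # with bound/constant weights, FoldConstant removes the forward transforms
    mod = relay.transform.FoldConstant()(mod)
    return mod
```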
The actual params are transformed before the call into `%conv_fn`, and the formal params are reverse-transformed within `%conv_fn`'s body. The reason we need the reverse transforms is that we currently cannot represent a "lowered" function call in Relay (correct me if I am wrong). It is a workaround to keep the primitive function body valid, i.e., the Relay module after the pass can still be safely evaluated on a CPU.
Everything described above targets only weights (free tensors) for now. We check that a tensor produced/consumed by other Relay calls does not get transformed. For input and output layouts, we find that Relay's `ConvertLayout` can cover the current demands. However, I think there is no essential difference between "applicable functions to transform layout" and a simple tag like "NCHW" on an input/output, so it should be possible to rewrite inputs/outputs with the same mechanism.
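For reference, a typical `ConvertLayout` invocation looks like the following; the desired layouts here are placeholders for illustration only:
```python
from tvm import relay

# ConvertLayout maps an op name to its desired [data_layout, kernel_layout];
# the concrete layouts below are just example values.
desired_layouts = {"nn.conv2d": ["NCHW", "OIHW"]}
mod = relay.transform.ConvertLayout(desired_layouts)(mod)
```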
One remaining issue is that we have to hack the `CompileEngine` (now `TECompiler`) to cache and reuse the previously scheduled PrimFuncs. We would be very glad to know if existing mechanisms (like `relay_to_tir`?) can help us :slight_smile: cc @areusch