comaniac opened a new issue, #13508:
URL: https://github.com/apache/tvm/issues/13508
### Expected behavior
The lowering time of the given case should be around 10 seconds.
### Actual behavior
The lowering time is more than 550 seconds.
### Environment
Any environment with commit commit 101e3a4ade226a2b9cdef6437a285af18aef9cf8
(#13217) or later.
### Steps to reproduce
The script:
```pyhton
import time
import tvm
from tvm import topi
class Timer:
def __init__(self, msg):
self.msg = msg
print(f"{msg}...", flush=True)
def __enter__(self):
self.start = time.time()
def __exit__(self, *args):
print(f"{self.msg}...{time.time() - self.start:.2f}s", flush=True)
def resize2d_dx_compute(inp, dy):
"""compute definition for resize2d_dx op"""
size = (64, 32)
layout = "NCHW"
method = "cubic"
coord_trans = "half_pixel"
rounding_method = ""
cubic_alpha = -0.75
cubic_exclude = 0
out_dtype = "float32"
out = topi.image.resize2d(
inp,
(None, None, None, None),
size,
layout,
method,
coord_trans,
rounding_method,
bicubic_alpha=cubic_alpha,
bicubic_exclude=cubic_exclude,
out_dtype=out_dtype,
)
grads = tvm.te.gradient(out, [inp], head=dy)
return grads
inp = tvm.te.placeholder((32, 3, 32, 32), name="inp")
dy = tvm.te.placeholder((32, 3, 64, 32), name="dy")
with Timer("te.gradient"):
grads = resize2d_dx_compute(inp, dy)
# This problem is platform-independent.
with Timer("schedule"):
sch = topi.x86.injective.schedule_injective(grads)
with Timer("lower"):
print(tvm.lower(sch, [inp, dy, grads[0]], simple_mode=True))
```
1. Switch to a commit before 101e3a4ade226a2b9cdef6437a285af18aef9cf8
(#13217) and run the script.
2. Checkout the commit 101e3a4ade226a2b9cdef6437a285af18aef9cf8 (#13217) and
run again.
Here are also the lowered IR without and with this commit:
Without this commit:
```
@main = primfn(inp_1: handle, dy_1: handle, resize.inp.grad_1: handle) -> ()
attr = {"from_legacy_te_schedule": True, "global_symbol": "main",
"tir.noalias": True}
buffers = {inp: Buffer(inp_2: Pointer(float32), float32, [32, 3, 32, 32],
[]),
dy: Buffer(dy_2: Pointer(float32), float32, [32, 3, 64, 32],
[]),
resize.inp.grad: Buffer(resize.inp.grad_2: Pointer(float32),
float32, [32, 3, 32, 32], [])}
buffer_map = {inp_1: inp, dy_1: dy, resize.inp.grad_1: resize.inp.grad} {
for (ax0.ax1.fused: int32, 0, 96) "parallel" {
for (ax2: int32, 0, 32) {
for (ax3.outer: int32, 0, 2) {
resize.inp.grad_3: Buffer(resize.inp.grad_2, float32, [98304],
[])[ramp((((ax0.ax1.fused*1024) + (ax2*32)) + (ax3.outer*16)), 1, 16)] =
broadcast(0f32, 16)
for (n0_n0_k2.shifted.shifted: int32, 0, 64) {
for (n1_n1_k3.shifted.shifted: int32, 0, 32) {
for (ax3.inner.s: int32, 0, 16) {
let cse_var_3: float32 = cast(float32,
n1_n1_k3.shifted.shifted)
let cse_var_2: int32 = ((ax3.outer*16) + ax3.inner.s)
let cse_var_1: float32 = (((cast(float32,
n0_n0_k2.shifted.shifted) + 0.5f32)*0.5f32) - 0.5f32)
if ((((((ax2 - max(min((cast(int32, @tir.floor(cse_var_1,
dtype=float32)) - 1), 31), 0)) == 0) && (((((cse_var_2 - max(min((cast(int32,
@tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) == 0) || ((cse_var_2 -
max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)) == 0)) ||
((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1),
31), 0)) == 0)) || ((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3,
dtype=float32)) + 2), 31), 0)) == 0))) || (((ax2 - max(min(cast(int32,
@tir.floor(cse_var_1, dtype=float32)), 31), 0)) == 0) && (((((cse_var_2 -
max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) == 0)
|| ((cse_var_2 - max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)),
31), 0)) == 0)) || ((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3,
dtype=float32)) + 1), 31), 0)) == 0)) || ((cse_var_2 - max(min((cast(int32,
@tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)) == 0)))) || (((ax2 -
max(min((
cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 1), 31), 0)) == 0) &&
(((((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) -
1), 31), 0)) == 0) || ((cse_var_2 - max(min(cast(int32, @tir.floor(cse_var_3,
dtype=float32)), 31), 0)) == 0)) || ((cse_var_2 - max(min((cast(int32,
@tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)) == 0)) || ((cse_var_2 -
max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)) ==
0)))) || (((ax2 - max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) +
2), 31), 0)) == 0) && (((((cse_var_2 - max(min((cast(int32,
@tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) == 0) || ((cse_var_2 -
max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)) == 0)) ||
((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1),
31), 0)) == 0)) || ((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3,
dtype=float32)) + 2), 31), 0)) == 0)))) {
let cse_var_4: int32 = ((((ax0.ax1.fused*1024) + (ax2*32)) +
(ax3.outer*16)) + ax3.inner.s)
resize.inp.grad_3[cse_var_4] = (resize.inp.grad_3[cse_var_4]
+ (dy_3: Buffer(dy_2, float32, [196608], [])[(((ax0.ax1.fused*2048) +
(n0_n0_k2.shifted.shifted*32)) + n1_n1_k3.shifted.shifted)]*(select((((((((ax2
== max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) &&
(cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1),
31), 0))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1,
dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min(cast(int32,
@tir.floor(cse_var_3, dtype=float32)), 31), 0)))) || ((ax2 ==
max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) &&
(cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1),
31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1,
dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min((cast(int32,
@tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)))) || (((((ax2 ==
max(min(cast(int32, @tir.floor(cse_var_1, dtype=float3
2)), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3,
dtype=float32)) - 1), 31), 0))) || ((ax2 == max(min(cast(int32,
@tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 ==
max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)))) || ((ax2
== max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) &&
(cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1),
31), 0)))) || ((ax2 == max(min(cast(int32, @tir.floor(cse_var_1,
dtype=float32)), 31), 0)) && (cse_var_2 == max(min((cast(int32,
@tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0))))) || (((((ax2 ==
max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 1), 31), 0)) &&
(cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1),
31), 0))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1,
dtype=float32)) + 1), 31), 0)) && (cse_var_2 == max(min(cast(int32,
@tir.floor(cse_var_3, dtype=float32)), 31), 0)))) || ((ax2 == max(min((ca
st(int32, @tir.floor(cse_var_1, dtype=float32)) + 1), 31), 0)) && (cse_var_2
== max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0))))
|| ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 1),
31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3,
dtype=float32)) + 2), 31), 0))))), (select(((((((ax2 == max(min((cast(int32,
@tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 ==
max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0))) ||
((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31),
0)) && (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)),
31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1,
dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min((cast(int32,
@tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)))) || ((ax2 ==
max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) &&
(cse_var_2 == max(min((cast(int3
2, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)))) || (((((ax2 ==
max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) &&
(cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1),
31), 0))) || ((ax2 == max(min(cast(int32, @tir.floor(cse_var_1,
dtype=float32)), 31), 0)) && (cse_var_2 == max(min(cast(int32,
@tir.floor(cse_var_3, dtype=float32)), 31), 0)))) || ((ax2 ==
max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) &&
(cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1),
31), 0)))) || ((ax2 == max(min(cast(int32, @tir.floor(cse_var_1,
dtype=float32)), 31), 0)) && (cse_var_2 == max(min((cast(int32,
@tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0))))), (select((((((ax2 ==
max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) &&
(cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1),
31), 0))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=fl
oat32)) - 1), 31), 0)) && (cse_var_2 == max(min(cast(int32,
@tir.floor(cse_var_3, dtype=float32)), 31), 0)))) || ((ax2 ==
max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) &&
(cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1),
31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1,
dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min((cast(int32,
@tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)))), ((select((((cse_var_2
== max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) ||
(cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31),
0))) || (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3,
dtype=float32)) + 1), 31), 0))), (select(((cse_var_2 == max(min((cast(int32,
@tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) || (cse_var_2 ==
max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0))),
(select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, d
type=float32)) - 1), 31), 0)), (-0.75f32*(((((cse_var_3 -
@tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3,
dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))) -
(2f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 -
@tir.floor(cse_var_3, dtype=float32))))) + (cse_var_3 - @tir.floor(cse_var_3,
dtype=float32)))), 0f32) + select((cse_var_2 == max(min(cast(int32,
@tir.floor(cse_var_3, dtype=float32)), 31), 0)), (((1.25f32*(((cse_var_3 -
@tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3,
dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) -
(2.25f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 -
@tir.floor(cse_var_3, dtype=float32))))) + 1f32), 0f32)), 0f32) +
select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32))
+ 1), 31), 0)), (((-1.25f32*(((cse_var_3 - @tir.floor(cse_var_3,
dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_
var_3 - @tir.floor(cse_var_3, dtype=float32)))) + (1.5f32*((cse_var_3 -
@tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3,
dtype=float32))))) - (-0.75f32*(cse_var_3 - @tir.floor(cse_var_3,
dtype=float32)))), 0f32)), 0f32) + select((cse_var_2 == max(min((cast(int32,
@tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)), ((0.75f32*(((cse_var_3 -
@tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3,
dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) +
(-0.75f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 -
@tir.floor(cse_var_3, dtype=float32))))), 0f32))*(-0.75f32*(((((cse_var_1 -
@tir.floor(cse_var_1, dtype=float32))*(cse_var_1 - @tir.floor(cse_var_1,
dtype=float32)))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32))) -
(2f32*((cse_var_1 - @tir.floor(cse_var_1, dtype=float32))*(cse_var_1 -
@tir.floor(cse_var_1, dtype=float32))))) + (cse_var_1 - @tir.floor(cse_var_1,
dtype=float32))))), 0f32) + select((((((a
x2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) &&
(cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1),
31), 0))) || ((ax2 == max(min(cast(int32, @tir.floor(cse_var_1,
dtype=float32)), 31), 0)) && (cse_var_2 == max(min(cast(int32,
@tir.floor(cse_var_3, dtype=float32)), 31), 0)))) || ((ax2 ==
max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) &&
(cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1),
31), 0)))) || ((ax2 == max(min(cast(int32, @tir.floor(cse_var_1,
dtype=float32)), 31), 0)) && (cse_var_2 == max(min((cast(int32,
@tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)))), ((select((((cse_var_2
== max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) ||
(cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31),
0))) || (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3,
dtype=float32)) + 1), 31), 0))), (select(((cse_var_2 == max(min((cast(in
t32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) || (cse_var_2 ==
max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0))),
(select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3,
dtype=float32)) - 1), 31), 0)), (-0.75f32*(((((cse_var_3 -
@tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3,
dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))) -
(2f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 -
@tir.floor(cse_var_3, dtype=float32))))) + (cse_var_3 - @tir.floor(cse_var_3,
dtype=float32)))), 0f32) + select((cse_var_2 == max(min(cast(int32,
@tir.floor(cse_var_3, dtype=float32)), 31), 0)), (((1.25f32*(((cse_var_3 -
@tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3,
dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) -
(2.25f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 -
@tir.floor(cse_var_3, dtype=float32))))) + 1f32), 0f32)), 0f32) + select(
(cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1),
31), 0)), (((-1.25f32*(((cse_var_3 - @tir.floor(cse_var_3,
dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3
- @tir.floor(cse_var_3, dtype=float32)))) + (1.5f32*((cse_var_3 -
@tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3,
dtype=float32))))) - (-0.75f32*(cse_var_3 - @tir.floor(cse_var_3,
dtype=float32)))), 0f32)), 0f32) + select((cse_var_2 == max(min((cast(int32,
@tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)), ((0.75f32*(((cse_var_3 -
@tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3,
dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) +
(-0.75f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 -
@tir.floor(cse_var_3, dtype=float32))))), 0f32))*(((1.25f32*(((cse_var_1 -
@tir.floor(cse_var_1, dtype=float32))*(cse_var_1 - @tir.floor(cse_var_1,
dtype=float32)))*(cse_var_1 - @tir.floor(cse_var_1,
dtype=float32)))) - (2.25f32*((cse_var_1 - @tir.floor(cse_var_1,
dtype=float32))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32))))) + 1f32)),
0f32)), 0f32) + select((((((ax2 == max(min((cast(int32, @tir.floor(cse_var_1,
dtype=float32)) + 1), 31), 0)) && (cse_var_2 == max(min((cast(int32,
@tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0))) || ((ax2 ==
max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 1), 31), 0)) &&
(cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31),
0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) +
1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3,
dtype=float32)) + 1), 31), 0)))) || ((ax2 == max(min((cast(int32,
@tir.floor(cse_var_1, dtype=float32)) + 1), 31), 0)) && (cse_var_2 ==
max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)))),
((select((((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3,
dtype=float32)) - 1), 31), 0)) || (cse_var_2 == ma
x(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0))) ||
(cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1),
31), 0))), (select(((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3,
dtype=float32)) - 1), 31), 0)) || (cse_var_2 == max(min(cast(int32,
@tir.floor(cse_var_3, dtype=float32)), 31), 0))), (select((cse_var_2 ==
max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)),
(-0.75f32*(((((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 -
@tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3,
dtype=float32))) - (2f32*((cse_var_3 - @tir.floor(cse_var_3,
dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) +
(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))), 0f32) + select((cse_var_2
== max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)),
(((1.25f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 -
@tir.floor(cse_var_3, dtype=float32)))*(cs
e_var_3 - @tir.floor(cse_var_3, dtype=float32)))) - (2.25f32*((cse_var_3 -
@tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3,
dtype=float32))))) + 1f32), 0f32)), 0f32) + select((cse_var_2 ==
max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)),
(((-1.25f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 -
@tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3,
dtype=float32)))) + (1.5f32*((cse_var_3 - @tir.floor(cse_var_3,
dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) -
(-0.75f32*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))), 0f32)), 0f32) +
select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32))
+ 2), 31), 0)), ((0.75f32*(((cse_var_3 - @tir.floor(cse_var_3,
dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3
- @tir.floor(cse_var_3, dtype=float32)))) + (-0.75f32*((cse_var_3 -
@tir.floor(cse_var_3, dtype=float32))*(cse_var_3 -
@tir.floor(cse_var_3, dtype=float32))))), 0f32))*(((-1.25f32*(((cse_var_1 -
@tir.floor(cse_var_1, dtype=float32))*(cse_var_1 - @tir.floor(cse_var_1,
dtype=float32)))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32)))) +
(1.5f32*((cse_var_1 - @tir.floor(cse_var_1, dtype=float32))*(cse_var_1 -
@tir.floor(cse_var_1, dtype=float32))))) - (-0.75f32*(cse_var_1 -
@tir.floor(cse_var_1, dtype=float32))))), 0f32)), 0f32) + select((((((ax2 ==
max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 2), 31), 0)) &&
(cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1),
31), 0))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1,
dtype=float32)) + 2), 31), 0)) && (cse_var_2 == max(min(cast(int32,
@tir.floor(cse_var_3, dtype=float32)), 31), 0)))) || ((ax2 ==
max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 2), 31), 0)) &&
(cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1),
31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floo
r(cse_var_1, dtype=float32)) + 2), 31), 0)) && (cse_var_2 ==
max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)))),
((select((((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3,
dtype=float32)) - 1), 31), 0)) || (cse_var_2 == max(min(cast(int32,
@tir.floor(cse_var_3, dtype=float32)), 31), 0))) || (cse_var_2 ==
max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0))),
(select(((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3,
dtype=float32)) - 1), 31), 0)) || (cse_var_2 == max(min(cast(int32,
@tir.floor(cse_var_3, dtype=float32)), 31), 0))), (select((cse_var_2 ==
max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)),
(-0.75f32*(((((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 -
@tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3,
dtype=float32))) - (2f32*((cse_var_3 - @tir.floor(cse_var_3,
dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) +
(cse_var_3 - @t
ir.floor(cse_var_3, dtype=float32)))), 0f32) + select((cse_var_2 ==
max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)),
(((1.25f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 -
@tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3,
dtype=float32)))) - (2.25f32*((cse_var_3 - @tir.floor(cse_var_3,
dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) + 1f32),
0f32)), 0f32) + select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3,
dtype=float32)) + 1), 31), 0)), (((-1.25f32*(((cse_var_3 -
@tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3,
dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) +
(1.5f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 -
@tir.floor(cse_var_3, dtype=float32))))) - (-0.75f32*(cse_var_3 -
@tir.floor(cse_var_3, dtype=float32)))), 0f32)), 0f32) + select((cse_var_2 ==
max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2),
31), 0)), ((0.75f32*(((cse_var_3 - @tir.floor(cse_var_3,
dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3
- @tir.floor(cse_var_3, dtype=float32)))) + (-0.75f32*((cse_var_3 -
@tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3,
dtype=float32))))), 0f32))*((0.75f32*(((cse_var_1 - @tir.floor(cse_var_1,
dtype=float32))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32)))*(cse_var_1
- @tir.floor(cse_var_1, dtype=float32)))) + (-0.75f32*((cse_var_1 -
@tir.floor(cse_var_1, dtype=float32))*(cse_var_1 - @tir.floor(cse_var_1,
dtype=float32)))))), 0f32))))
}
}
}
}
}
}
}
}
```
With this commit:
```
@main = primfn(inp_1: handle, dy_1: handle, resize.inp.grad_1: handle) -> ()
attr = {"from_legacy_te_schedule": True, "global_symbol": "main",
"tir.noalias": True}
buffers = {inp: Buffer(inp_2: Pointer(float32), float32, [32, 3, 32, 32],
[]),
dy: Buffer(dy_2: Pointer(float32), float32, [32, 3, 64, 32],
[]),
resize.inp.grad: Buffer(resize.inp.grad_2: Pointer(float32),
float32, [32, 3, 32, 32], [])}
buffer_map = {inp_1: inp, dy_1: dy, resize.inp.grad_1: resize.inp.grad} {
for (ax0.ax1.fused: int32, 0, 96) "parallel" {
for (ax2: int32, 0, 32) {
for (ax3.outer: int32, 0, 2) {
resize.inp.grad_3: Buffer(resize.inp.grad_2, float32, [98304],
[])[ramp((((ax0.ax1.fused*1024) + (ax2*32)) + (ax3.outer*16)), 1, 16)] =
broadcast(0f32, 16)
for (n0_n0_k2.shifted.shifted: int32, 0, 64) {
for (n1_n1_k3.shifted.shifted: int32, 0, 32) {
for (ax3.inner.s: int32, 0, 16) {
let cse_var_3: float32 = cast(float32,
n1_n1_k3.shifted.shifted)
let cse_var_2: int32 = ((ax3.outer*16) + ax3.inner.s)
let cse_var_1: float32 = (((cast(float32,
n0_n0_k2.shifted.shifted) + 0.5f32)*0.5f32) - 0.5f32)
if ((((((ax2 - max(min((cast(int32, @tir.floor(cse_var_1,
dtype=float32)) - 1), 31), 0)) == 0) && (((((cse_var_2 - max(min((cast(int32,
@tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) == 0) || ((cse_var_2 -
max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)) == 0)) ||
((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1),
31), 0)) == 0)) || ((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3,
dtype=float32)) + 2), 31), 0)) == 0))) || (((ax2 - max(min(cast(int32,
@tir.floor(cse_var_1, dtype=float32)), 31), 0)) == 0) && (((((cse_var_2 -
max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) == 0)
|| ((cse_var_2 - max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)),
31), 0)) == 0)) || ((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3,
dtype=float32)) + 1), 31), 0)) == 0)) || ((cse_var_2 - max(min((cast(int32,
@tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)) == 0)))) || (((ax2 -
max(min((
cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 1), 31), 0)) == 0) &&
(((((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) -
1), 31), 0)) == 0) || ((cse_var_2 - max(min(cast(int32, @tir.floor(cse_var_3,
dtype=float32)), 31), 0)) == 0)) || ((cse_var_2 - max(min((cast(int32,
@tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)) == 0)) || ((cse_var_2 -
max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)) ==
0)))) || (((ax2 - max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) +
2), 31), 0)) == 0) && (((((cse_var_2 - max(min((cast(int32,
@tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) == 0) || ((cse_var_2 -
max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)) == 0)) ||
((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1),
31), 0)) == 0)) || ((cse_var_2 - max(min((cast(int32, @tir.floor(cse_var_3,
dtype=float32)) + 2), 31), 0)) == 0)))) {
let cse_var_4: int32 = ((((ax0.ax1.fused*1024) + (ax2*32)) +
(ax3.outer*16)) + ax3.inner.s)
resize.inp.grad_3[cse_var_4] = (resize.inp.grad_3[cse_var_4]
+ (dy_3: Buffer(dy_2, float32, [196608], [])[(((ax0.ax1.fused*2048) +
(n0_n0_k2.shifted.shifted*32)) + n1_n1_k3.shifted.shifted)]*(select((((((((ax2
== max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) &&
(cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1),
31), 0))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1,
dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min(cast(int32,
@tir.floor(cse_var_3, dtype=float32)), 31), 0)))) || ((ax2 ==
max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) &&
(cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1),
31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1,
dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min((cast(int32,
@tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)))) || (((((ax2 ==
max(min(cast(int32, @tir.floor(cse_var_1, dtype=float3
2)), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3,
dtype=float32)) - 1), 31), 0))) || ((ax2 == max(min(cast(int32,
@tir.floor(cse_var_1, dtype=float32)), 31), 0)) && (cse_var_2 ==
max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)))) || ((ax2
== max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) &&
(cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1),
31), 0)))) || ((ax2 == max(min(cast(int32, @tir.floor(cse_var_1,
dtype=float32)), 31), 0)) && (cse_var_2 == max(min((cast(int32,
@tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0))))) || (((((ax2 ==
max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 1), 31), 0)) &&
(cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1),
31), 0))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1,
dtype=float32)) + 1), 31), 0)) && (cse_var_2 == max(min(cast(int32,
@tir.floor(cse_var_3, dtype=float32)), 31), 0)))) || ((ax2 == max(min((ca
st(int32, @tir.floor(cse_var_1, dtype=float32)) + 1), 31), 0)) && (cse_var_2
== max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0))))
|| ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 1),
31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3,
dtype=float32)) + 2), 31), 0))))), (select(((((((ax2 == max(min((cast(int32,
@tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) && (cse_var_2 ==
max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0))) ||
((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31),
0)) && (cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)),
31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1,
dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min((cast(int32,
@tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)))) || ((ax2 ==
max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) &&
(cse_var_2 == max(min((cast(int3
2, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)))) || (((((ax2 ==
max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) &&
(cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1),
31), 0))) || ((ax2 == max(min(cast(int32, @tir.floor(cse_var_1,
dtype=float32)), 31), 0)) && (cse_var_2 == max(min(cast(int32,
@tir.floor(cse_var_3, dtype=float32)), 31), 0)))) || ((ax2 ==
max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) &&
(cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1),
31), 0)))) || ((ax2 == max(min(cast(int32, @tir.floor(cse_var_1,
dtype=float32)), 31), 0)) && (cse_var_2 == max(min((cast(int32,
@tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0))))), (select((((((ax2 ==
max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) &&
(cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1),
31), 0))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=fl
oat32)) - 1), 31), 0)) && (cse_var_2 == max(min(cast(int32,
@tir.floor(cse_var_3, dtype=float32)), 31), 0)))) || ((ax2 ==
max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) - 1), 31), 0)) &&
(cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1),
31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1,
dtype=float32)) - 1), 31), 0)) && (cse_var_2 == max(min((cast(int32,
@tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)))), ((select((((cse_var_2
== max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) ||
(cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31),
0))) || (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3,
dtype=float32)) + 1), 31), 0))), (select(((cse_var_2 == max(min((cast(int32,
@tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) || (cse_var_2 ==
max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0))),
(select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, d
type=float32)) - 1), 31), 0)), (-0.75f32*(((((cse_var_3 -
@tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3,
dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))) -
(2f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 -
@tir.floor(cse_var_3, dtype=float32))))) + (cse_var_3 - @tir.floor(cse_var_3,
dtype=float32)))), 0f32) + select((cse_var_2 == max(min(cast(int32,
@tir.floor(cse_var_3, dtype=float32)), 31), 0)), (((1.25f32*(((cse_var_3 -
@tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3,
dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) -
(2.25f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 -
@tir.floor(cse_var_3, dtype=float32))))) + 1f32), 0f32)), 0f32) +
select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32))
+ 1), 31), 0)), (((-1.25f32*(((cse_var_3 - @tir.floor(cse_var_3,
dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_
var_3 - @tir.floor(cse_var_3, dtype=float32)))) + (1.5f32*((cse_var_3 -
@tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3,
dtype=float32))))) - (-0.75f32*(cse_var_3 - @tir.floor(cse_var_3,
dtype=float32)))), 0f32)), 0f32) + select((cse_var_2 == max(min((cast(int32,
@tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)), ((0.75f32*(((cse_var_3 -
@tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3,
dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) +
(-0.75f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 -
@tir.floor(cse_var_3, dtype=float32))))), 0f32))*(-0.75f32*(((((cse_var_1 -
@tir.floor(cse_var_1, dtype=float32))*(cse_var_1 - @tir.floor(cse_var_1,
dtype=float32)))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32))) -
(2f32*((cse_var_1 - @tir.floor(cse_var_1, dtype=float32))*(cse_var_1 -
@tir.floor(cse_var_1, dtype=float32))))) + (cse_var_1 - @tir.floor(cse_var_1,
dtype=float32))))), 0f32) + select((((((a
x2 == max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) &&
(cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1),
31), 0))) || ((ax2 == max(min(cast(int32, @tir.floor(cse_var_1,
dtype=float32)), 31), 0)) && (cse_var_2 == max(min(cast(int32,
@tir.floor(cse_var_3, dtype=float32)), 31), 0)))) || ((ax2 ==
max(min(cast(int32, @tir.floor(cse_var_1, dtype=float32)), 31), 0)) &&
(cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1),
31), 0)))) || ((ax2 == max(min(cast(int32, @tir.floor(cse_var_1,
dtype=float32)), 31), 0)) && (cse_var_2 == max(min((cast(int32,
@tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)))), ((select((((cse_var_2
== max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) ||
(cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31),
0))) || (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3,
dtype=float32)) + 1), 31), 0))), (select(((cse_var_2 == max(min((cast(in
t32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)) || (cse_var_2 ==
max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0))),
(select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3,
dtype=float32)) - 1), 31), 0)), (-0.75f32*(((((cse_var_3 -
@tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3,
dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))) -
(2f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 -
@tir.floor(cse_var_3, dtype=float32))))) + (cse_var_3 - @tir.floor(cse_var_3,
dtype=float32)))), 0f32) + select((cse_var_2 == max(min(cast(int32,
@tir.floor(cse_var_3, dtype=float32)), 31), 0)), (((1.25f32*(((cse_var_3 -
@tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3,
dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) -
(2.25f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 -
@tir.floor(cse_var_3, dtype=float32))))) + 1f32), 0f32)), 0f32) + select(
(cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1),
31), 0)), (((-1.25f32*(((cse_var_3 - @tir.floor(cse_var_3,
dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3
- @tir.floor(cse_var_3, dtype=float32)))) + (1.5f32*((cse_var_3 -
@tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3,
dtype=float32))))) - (-0.75f32*(cse_var_3 - @tir.floor(cse_var_3,
dtype=float32)))), 0f32)), 0f32) + select((cse_var_2 == max(min((cast(int32,
@tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)), ((0.75f32*(((cse_var_3 -
@tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3,
dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) +
(-0.75f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 -
@tir.floor(cse_var_3, dtype=float32))))), 0f32))*(((1.25f32*(((cse_var_1 -
@tir.floor(cse_var_1, dtype=float32))*(cse_var_1 - @tir.floor(cse_var_1,
dtype=float32)))*(cse_var_1 - @tir.floor(cse_var_1,
dtype=float32)))) - (2.25f32*((cse_var_1 - @tir.floor(cse_var_1,
dtype=float32))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32))))) + 1f32)),
0f32)), 0f32) + select((((((ax2 == max(min((cast(int32, @tir.floor(cse_var_1,
dtype=float32)) + 1), 31), 0)) && (cse_var_2 == max(min((cast(int32,
@tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0))) || ((ax2 ==
max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 1), 31), 0)) &&
(cse_var_2 == max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31),
0)))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) +
1), 31), 0)) && (cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3,
dtype=float32)) + 1), 31), 0)))) || ((ax2 == max(min((cast(int32,
@tir.floor(cse_var_1, dtype=float32)) + 1), 31), 0)) && (cse_var_2 ==
max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)))),
((select((((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3,
dtype=float32)) - 1), 31), 0)) || (cse_var_2 == ma
x(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0))) ||
(cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1),
31), 0))), (select(((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3,
dtype=float32)) - 1), 31), 0)) || (cse_var_2 == max(min(cast(int32,
@tir.floor(cse_var_3, dtype=float32)), 31), 0))), (select((cse_var_2 ==
max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)),
(-0.75f32*(((((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 -
@tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3,
dtype=float32))) - (2f32*((cse_var_3 - @tir.floor(cse_var_3,
dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) +
(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))), 0f32) + select((cse_var_2
== max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)),
(((1.25f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 -
@tir.floor(cse_var_3, dtype=float32)))*(cs
e_var_3 - @tir.floor(cse_var_3, dtype=float32)))) - (2.25f32*((cse_var_3 -
@tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3,
dtype=float32))))) + 1f32), 0f32)), 0f32) + select((cse_var_2 ==
max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0)),
(((-1.25f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 -
@tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3,
dtype=float32)))) + (1.5f32*((cse_var_3 - @tir.floor(cse_var_3,
dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) -
(-0.75f32*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))), 0f32)), 0f32) +
select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32))
+ 2), 31), 0)), ((0.75f32*(((cse_var_3 - @tir.floor(cse_var_3,
dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3
- @tir.floor(cse_var_3, dtype=float32)))) + (-0.75f32*((cse_var_3 -
@tir.floor(cse_var_3, dtype=float32))*(cse_var_3 -
@tir.floor(cse_var_3, dtype=float32))))), 0f32))*(((-1.25f32*(((cse_var_1 -
@tir.floor(cse_var_1, dtype=float32))*(cse_var_1 - @tir.floor(cse_var_1,
dtype=float32)))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32)))) +
(1.5f32*((cse_var_1 - @tir.floor(cse_var_1, dtype=float32))*(cse_var_1 -
@tir.floor(cse_var_1, dtype=float32))))) - (-0.75f32*(cse_var_1 -
@tir.floor(cse_var_1, dtype=float32))))), 0f32)), 0f32) + select((((((ax2 ==
max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 2), 31), 0)) &&
(cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1),
31), 0))) || ((ax2 == max(min((cast(int32, @tir.floor(cse_var_1,
dtype=float32)) + 2), 31), 0)) && (cse_var_2 == max(min(cast(int32,
@tir.floor(cse_var_3, dtype=float32)), 31), 0)))) || ((ax2 ==
max(min((cast(int32, @tir.floor(cse_var_1, dtype=float32)) + 2), 31), 0)) &&
(cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1),
31), 0)))) || ((ax2 == max(min((cast(int32, @tir.floo
r(cse_var_1, dtype=float32)) + 2), 31), 0)) && (cse_var_2 ==
max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2), 31), 0)))),
((select((((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3,
dtype=float32)) - 1), 31), 0)) || (cse_var_2 == max(min(cast(int32,
@tir.floor(cse_var_3, dtype=float32)), 31), 0))) || (cse_var_2 ==
max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 1), 31), 0))),
(select(((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3,
dtype=float32)) - 1), 31), 0)) || (cse_var_2 == max(min(cast(int32,
@tir.floor(cse_var_3, dtype=float32)), 31), 0))), (select((cse_var_2 ==
max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) - 1), 31), 0)),
(-0.75f32*(((((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 -
@tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3,
dtype=float32))) - (2f32*((cse_var_3 - @tir.floor(cse_var_3,
dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) +
(cse_var_3 - @t
ir.floor(cse_var_3, dtype=float32)))), 0f32) + select((cse_var_2 ==
max(min(cast(int32, @tir.floor(cse_var_3, dtype=float32)), 31), 0)),
(((1.25f32*(((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 -
@tir.floor(cse_var_3, dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3,
dtype=float32)))) - (2.25f32*((cse_var_3 - @tir.floor(cse_var_3,
dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32))))) + 1f32),
0f32)), 0f32) + select((cse_var_2 == max(min((cast(int32, @tir.floor(cse_var_3,
dtype=float32)) + 1), 31), 0)), (((-1.25f32*(((cse_var_3 -
@tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3,
dtype=float32)))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))) +
(1.5f32*((cse_var_3 - @tir.floor(cse_var_3, dtype=float32))*(cse_var_3 -
@tir.floor(cse_var_3, dtype=float32))))) - (-0.75f32*(cse_var_3 -
@tir.floor(cse_var_3, dtype=float32)))), 0f32)), 0f32) + select((cse_var_2 ==
max(min((cast(int32, @tir.floor(cse_var_3, dtype=float32)) + 2),
31), 0)), ((0.75f32*(((cse_var_3 - @tir.floor(cse_var_3,
dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3, dtype=float32)))*(cse_var_3
- @tir.floor(cse_var_3, dtype=float32)))) + (-0.75f32*((cse_var_3 -
@tir.floor(cse_var_3, dtype=float32))*(cse_var_3 - @tir.floor(cse_var_3,
dtype=float32))))), 0f32))*((0.75f32*(((cse_var_1 - @tir.floor(cse_var_1,
dtype=float32))*(cse_var_1 - @tir.floor(cse_var_1, dtype=float32)))*(cse_var_1
- @tir.floor(cse_var_1, dtype=float32)))) + (-0.75f32*((cse_var_1 -
@tir.floor(cse_var_1, dtype=float32))*(cse_var_1 - @tir.floor(cse_var_1,
dtype=float32)))))), 0f32))))
}
}
}
}
}
}
}
}
```
cc @Lunderberg @masahi
### Triage
* needs-triage
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]