gemini-code-assist[bot] commented on code in PR #18417:
URL: https://github.com/apache/tvm/pull/18417#discussion_r2485351378
##########
tests/python/relax/test_frontend_from_exported_program.py:
##########
@@ -4580,47 +4580,51 @@ def forward(self, x, y):
class Expected0:
@R.function
def main(
- inp_0: R.Tensor((2, 3), dtype="float32"),
- inp_1: R.Tensor((2, 3), dtype="float32"),
+ x: R.Tensor((2, 3), dtype="float32"),
+ y: R.Tensor((2, 3), dtype="float32"),
) -> R.Tuple(R.Tensor((2, 2, 3), dtype="float32")):
with R.dataflow():
- lv: R.Tensor((2, 2, 3), dtype="float32") = R.stack((inp_0, inp_1), axis=0)
- gv: R.Tuple(R.Tensor((2, 2, 3), dtype="float32")) = (lv,)
+ lv: R.Tensor((4, 3), dtype="float32") = R.concat((x, y), axis=0)
+ lv1: R.Tensor((2, 2, 3), dtype="float32") = R.reshape(lv, R.shape([2, 2, 3]))
+ gv: R.Tuple(R.Tensor((2, 2, 3), dtype="float32")) = (lv1,)
Review Comment:

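The decomposition of `torch.stack` with `axis=0` into `concat` and `reshape`
is numerically correct, because concatenating along axis 0 places each input's
elements back to back, so the reshape to `(2, 2, 3)` recovers the stacked layout.
It is still less canonical than `expand_dims` followed by `concat`. The
decomposition for `axis=-1` in `Expected3` already uses `expand_dims` and
`concat`, which translates the `stack` semantics of inserting a new dimension
more directly. For consistency and clarity, it would be better to use the same
`expand_dims` + `concat` approach for `axis=0`. It is also more robust:
`concat` + `reshape` only reproduces `stack` when the new axis is inserted in
front of an existing one, while `expand_dims` + `concat` works for any axis,
including a new trailing axis. A single strategy across all axes is also easier
to understand.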
For example:
```python
lv: R.Tensor((1, 2, 3), dtype="float32") = R.expand_dims(x, axis=0)
lv1: R.Tensor((1, 2, 3), dtype="float32") = R.expand_dims(y, axis=0)
lv2: R.Tensor((2, 2, 3), dtype="float32") = R.concat((lv, lv1), axis=0)
gv: R.Tuple(R.Tensor((2, 2, 3), dtype="float32")) = (lv2,)
```
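As a quick sanity check that both decompositions agree with `stack` for the axes covered by `Expected0` and `Expected1`, here is a minimal sketch that uses NumPy as a stand-in for the Relax operators (`np.concatenate`, `np.reshape` and `np.expand_dims` play the roles of `R.concat`, `R.reshape` and `R.expand_dims`). This is an illustration under that assumption, not TVM code, and both helper names are hypothetical.

```python
import numpy as np


def stack_via_concat_reshape(tensors, axis):
    """Hypothetical helper modeling the PR's decomposition: concat, then reshape.

    Only valid for 0 <= axis < ndim, i.e. when the new axis is inserted
    in front of an existing one.
    """
    shape = tensors[0].shape
    # The size along `axis` becomes N * shape[axis]; input j fills block j.
    joined = np.concatenate(tensors, axis=axis)
    # Split that axis into (N, shape[axis]) so the block index becomes the new axis.
    return joined.reshape(shape[:axis] + (len(tensors),) + shape[axis:])


def stack_via_expand_dims_concat(tensors, axis):
    """Hypothetical helper modeling the suggested decomposition: expand_dims, then concat."""
    return np.concatenate([np.expand_dims(t, axis) for t in tensors], axis=axis)


x = np.arange(6, dtype="float32").reshape(2, 3)
y = np.arange(6, 12, dtype="float32").reshape(2, 3)

for axis in (0, 1):  # the two axes exercised by Expected0 and Expected1
    reference = np.stack((x, y), axis=axis)  # shape (2, 2, 3) in both cases
    assert np.array_equal(stack_via_concat_reshape([x, y], axis), reference)
    assert np.array_equal(stack_via_expand_dims_concat([x, y], axis), reference)
```

The two helpers give identical results here. The `expand_dims` form carries no restriction on `axis`, which is what makes it the more uniform choice.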
##########
tests/python/relax/test_frontend_from_exported_program.py:
##########
@@ -4580,47 +4580,51 @@ def forward(self, x, y):
class Expected0:
@R.function
def main(
- inp_0: R.Tensor((2, 3), dtype="float32"),
- inp_1: R.Tensor((2, 3), dtype="float32"),
+ x: R.Tensor((2, 3), dtype="float32"),
+ y: R.Tensor((2, 3), dtype="float32"),
) -> R.Tuple(R.Tensor((2, 2, 3), dtype="float32")):
with R.dataflow():
- lv: R.Tensor((2, 2, 3), dtype="float32") = R.stack((inp_0, inp_1), axis=0)
- gv: R.Tuple(R.Tensor((2, 2, 3), dtype="float32")) = (lv,)
+ lv: R.Tensor((4, 3), dtype="float32") = R.concat((x, y), axis=0)
+ lv1: R.Tensor((2, 2, 3), dtype="float32") = R.reshape(lv, R.shape([2, 2, 3]))
+ gv: R.Tuple(R.Tensor((2, 2, 3), dtype="float32")) = (lv1,)
R.output(gv)
return gv
@I.ir_module
class Expected1:
@R.function
def main(
- inp_0: R.Tensor((2, 3), dtype="float32"),
- inp_1: R.Tensor((2, 3), dtype="float32"),
+ x: R.Tensor((2, 3), dtype="float32"),
+ y: R.Tensor((2, 3), dtype="float32"),
) -> R.Tuple(R.Tensor((2, 2, 3), dtype="float32")):
with R.dataflow():
- lv: R.Tensor((2, 2, 3), dtype="float32") = R.stack((inp_0, inp_1), axis=1)
- gv: R.Tuple(R.Tensor((2, 2, 3), dtype="float32")) = (lv,)
+ lv: R.Tensor((2, 6), dtype="float32") = R.concat((x, y), axis=1)
+ lv1: R.Tensor((2, 2, 3), dtype="float32") = R.reshape(lv, R.shape([2, 2, 3]))
+ gv: R.Tuple(R.Tensor((2, 2, 3), dtype="float32")) = (lv1,)
Review Comment:

Similar to the `axis=0` case, the decomposition of `torch.stack` with
`axis=1` into `concat` and `reshape` is less canonical than using `expand_dims`
and `concat`. Using a consistent decomposition strategy across all axes would
improve the robustness and readability of the translated IR.
For example:
```python
lv: R.Tensor((2, 1, 3), dtype="float32") = R.expand_dims(x, axis=1)
lv1: R.Tensor((2, 1, 3), dtype="float32") = R.expand_dims(y, axis=1)
lv2: R.Tensor((2, 2, 3), dtype="float32") = R.concat((lv, lv1), axis=1)
gv: R.Tuple(R.Tensor((2, 2, 3), dtype="float32")) = (lv2,)
```
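The robustness argument can be made concrete with a minimal sketch that again uses NumPy as a stand-in for the Relax operators (an illustrative assumption, not TVM code). When the new axis is the trailing one (`axis=-1` for 2-D inputs), `concat` followed by `reshape` yields the right shape but the wrong element order. `expand_dims` followed by `concat` matches `np.stack` on every valid axis.

```python
import numpy as np

x = np.arange(6, dtype="float32").reshape(2, 3)
y = np.arange(6, 12, dtype="float32").reshape(2, 3)

# Reference: stack along a new trailing axis, giving shape (2, 3, 2).
reference = np.stack((x, y), axis=-1)

# concat + reshape: joining on the last axis yields rows [x[i, :], y[i, :]];
# reshaping to (2, 3, 2) keeps that order instead of pairing x[i, j] with y[i, j].
concat_reshape = np.concatenate((x, y), axis=-1).reshape(2, 3, 2)

# expand_dims + concat: add a trailing size-1 axis to each input, then join on it.
expand_concat = np.concatenate((np.expand_dims(x, -1), np.expand_dims(y, -1)), axis=-1)

print(concat_reshape.shape == reference.shape)    # True: the shapes agree
print(np.array_equal(concat_reshape, reference))  # False: the element order is wrong
print(np.array_equal(expand_concat, reference))   # True

# expand_dims + concat matches np.stack for every valid axis of 2-D inputs.
for axis in range(-3, 3):
    expected = np.stack((x, y), axis=axis)
    got = np.concatenate((np.expand_dims(x, axis), np.expand_dims(y, axis)), axis=axis)
    assert np.array_equal(got, expected)
```

A shape-only check would therefore miss this class of bug, which is one reason a single `expand_dims` + `concat` strategy is easier to trust across axes.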
--
This is an automated message from the Apache Git Service.