gemini-code-assist[bot] commented on code in PR #18410:
URL: https://github.com/apache/tvm/pull/18410#discussion_r2480208310
##########
tests/python/relax/test_frontend_from_exported_program.py:
##########
@@ -5890,12 +5890,18 @@ def main(x: R.Tensor((5, 3), dtype="float32")) -> R.Tuple(R.Tensor((5, 3), dtype
                 lv: R.Tensor((5, 3), dtype="int32") = R.argsort(
                     x, axis=1, descending=True, dtype="int32"
                 )
-                gv: R.Tuple(R.Tensor((5, 3), dtype="int32")) = (lv,)
+                lv1: R.Tensor((5, 3), dtype="float32") = R.gather_elements(x, lv, axis=1)
+                lv2: R.Tuple(R.Tensor((5, 3), dtype="float32"), R.Tensor((5, 3), dtype="int32")) = (
+                    lv1,
+                    lv,
+                )
+                lv3: R.Tensor((5, 3), dtype="int32") = lv2[1]
+                gv: R.Tuple(R.Tensor((5, 3), dtype="int32")) = (lv3,)
Review Comment:

The generated IR for `argsort` is unnecessarily complex. It computes both
sorted values (`lv1`) and indices (`lv`), then creates a tuple (`lv2`), only to
extract the indices (`lv3`). The sorted values are computed via
`gather_elements` but are never used for the final result. This seems to be a
result of decomposing `torch.argsort` into `torch.sort` and then taking the
indices.
While a Dead Code Elimination (DCE) pass might clean this up, a more direct
translation of `argsort` that never computes the sorted values in the first
place would be more efficient; a sketch of such a converter follows.
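
Assuming the exported program still carries an `argsort` node at this point
(rather than only the decomposed `sort`), a direct converter could follow the
shape of the existing handlers in `base_fx_graph_translator.py`, such as
`_squeeze` below. The `_argsort` name and its registration are hypothetical; this
is a sketch, not the PR's implementation:
```python
def _argsort(self, node: fx.Node) -> relax.Var:
    x = self.env[node.args[0]]
    dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", -1)
    descending = (
        node.args[2] if len(node.args) > 2 else node.kwargs.get("descending", False)
    )
    # Emit R.argsort directly: no sorted values, no tuple, nothing for DCE to remove.
    return self.block_builder.emit(
        relax.op.argsort(x, axis=dim, descending=descending, dtype="int32")
    )
```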
##########
tests/python/relax/test_frontend_from_exported_program.py:
##########
@@ -6216,13 +6266,24 @@ def main(
             input: R.Tensor((9, 9), dtype="float32")
         ) -> R.Tuple(R.Tensor((9,), dtype="float32")):
             with R.dataflow():
-                lv: R.Tensor((9,), dtype="float32") = R.arange(0, 1.0625, 0.125, dtype="float32")
-                gv: R.Tuple(R.Tensor((9,), dtype="float32")) = (lv,)
+                lv: R.Tensor((9,), dtype="int64") = R.arange(
+                    R.prim_value(0), R.prim_value(9), R.prim_value(1), dtype="int64"
+                )
+                lv1: R.Tensor((9,), dtype="bool") = R.less(lv, R.const(4, "int64"))
+                lv2: R.Tensor((9,), dtype="float32") = R.astype(lv, dtype="float32")
+                lv3: R.Tensor((9,), dtype="float32") = R.multiply(lv2, R.const(0.125, "float32"))
+                lv4: R.Tensor((9,), dtype="float32") = R.add(lv3, R.const(0.0, "float32"))
+                lv5: R.Tensor((9,), dtype="int64") = R.subtract(R.const(8, "int64"), lv)
+                lv6: R.Tensor((9,), dtype="float32") = R.astype(lv5, dtype="float32")
+                lv7: R.Tensor((9,), dtype="float32") = R.multiply(lv6, R.const(0.125, "float32"))
+                lv8: R.Tensor((9,), dtype="float32") = R.subtract(R.const(1.0, "float32"), lv7)
+                lv9: R.Tensor((9,), dtype="float32") = R.where(lv1, lv4, lv8)
+                gv: R.Tuple(R.Tensor((9,), dtype="float32")) = (lv9,)
Review Comment:

The generated IR for `linspace` is highly inefficient. It computes two
expressions, `lv4` and `lv8`, which are mathematically equivalent for the given
inputs (`i * 0.125` for `i` in `[0, 8]`). Then it uses `R.where` to select
between these identical values. The entire `where` operation and the
computation of `lv5` through `lv8` are redundant. The IR could be simplified to
just compute `lv4` and use that as the result. This suggests an issue in the
PyTorch decomposition logic for `linspace` that should be investigated.
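
To make the redundancy concrete, here is a quick numeric check (plain Python,
not part of the PR) that the two branches selected by `R.where` coincide for
every index in this test:
```python
# lv4 counts up from start; lv8 counts down from end (step = 0.125, n = 9).
for i in range(9):
    up = 0.0 + i * 0.125          # lv2..lv4: start + i * step
    down = 1.0 - (8 - i) * 0.125  # lv5..lv8: end - (n - 1 - i) * step
    assert up == down             # exact here, since 0.125 is a power of two
```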
##########
tests/python/relax/test_frontend_from_exported_program.py:
##########
@@ -6187,21 +6220,38 @@ def forward(self, x):
         @tvm.script.ir_module
         class Expected1:
             @R.function
-            def main(x: R.Tensor((4, 3), dtype="float32")) -> R.Tuple(R.Tensor((), dtype="float32")):
+            def main(x: R.Tensor((4, 3), dtype="float32")) -> R.Tuple(R.Tensor((4,), dtype="float32")):
                 with R.dataflow():
-                    lv: R.Tensor((4, 3), dtype="float32") = R.nn.log_softmax(x, axis=-1)
-                    lv1: R.Tensor((), dtype="float32") = R.nn.nll_loss(
-                        lv,
-                        targets=R.const([0, 1, 2, 1], dtype="int64"),
-                        reduction="mean",
-                        ignore_index=-100,
+                    lv: R.Tensor((4, 3), dtype="float32") = R.nn.log_softmax(x, axis=1)
+                    lv1: R.Tensor((4,), dtype="bool") = R.not_equal(
+                        R.const([0, 1, 2, 1], dtype="int64"), R.const(-100, "int64")
+                    )
+                    lv2: R.Tensor((), dtype="int64") = R.const(0, "int64")
+                    lv3: R.Tensor((4,), dtype="int64") = R.where(
+                        lv1, R.const([0, 1, 2, 1], dtype="int64"), lv2
                     )
-                    gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv1,)
+                    lv4: R.Tensor((4, 1), dtype="int64") = R.expand_dims(lv3, axis=[1])
+                    lv5: R.Tensor((4, 1), dtype="float32") = R.gather_elements(lv, lv4, axis=1)
+                    lv6: R.Tensor((4,), dtype="float32") = R.squeeze(lv5, axis=[1])
+                    lv7: R.Tensor((4,), dtype="float32") = R.negative(lv6)
+                    lv8: R.Tensor((4,), dtype="bool") = R.not_equal(
+                        R.const([0, 1, 2, 1], dtype="int64"), R.const(-100, "int64")
+                    )
+                    lv9: R.Tensor((), dtype="float32") = R.const(0.0, "float32")
+                    lv10: R.Tensor((4,), dtype="float32") = R.where(lv8, lv7, lv9)
+                    lv11: R.Tensor((4,), dtype="bool") = R.not_equal(
+                        R.const([0, 1, 2, 1], dtype="int64"), R.const(-100, "int64")
+                    )
+                    lv12: R.Tensor((4,), dtype="bool") = R.sum(lv11, axis=[], keepdims=False)
+                    lv13: R.Tensor((4,), dtype="float32") = R.astype(lv12, dtype="float32")
+                    lv14: R.Tensor((4,), dtype="float32") = R.sum(lv10, axis=[], keepdims=False)
+                    lv15: R.Tensor((4,), dtype="float32") = R.divide(lv14, lv13)
+                    gv: R.Tuple(R.Tensor((4,), dtype="float32")) = (lv15,)
                 R.output(gv)
                 return gv
Review Comment:

The expected IR for `test_cross_entropy` has incorrect shape annotations.
`torch.nn.CrossEntropyLoss` with the default `reduction='mean'` should return a
scalar tensor. However, the return type of the `main` function is annotated as
`R.Tuple(R.Tensor((4,), dtype="float32"))`.
Looking at the IR, `lv12` and `lv14` are results of `R.sum` with `axis=[]`,
which should produce scalar tensors (shape `()`), but they are annotated with
shape `(4,)`. Consequently, `lv15` (the final result) is also annotated with
shape `(4,)` instead of `()`. The function signature and intermediate type
annotations should be corrected to reflect that a scalar is being computed.
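
If `R.sum` with `axis=[]` does perform a full reduction, as this comment
assumes, the corrected tail of the function would read along these lines (a
sketch only; dtypes kept exactly as generated, shapes changed to scalars):
```python
# Signature: def main(x: ...) -> R.Tuple(R.Tensor((), dtype="float32")):
lv12: R.Tensor((), dtype="bool") = R.sum(lv11, axis=[], keepdims=False)
lv13: R.Tensor((), dtype="float32") = R.astype(lv12, dtype="float32")
lv14: R.Tensor((), dtype="float32") = R.sum(lv10, axis=[], keepdims=False)
lv15: R.Tensor((), dtype="float32") = R.divide(lv14, lv13)
gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv15,)
```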
##########
python/tvm/relax/frontend/torch/base_fx_graph_translator.py:
##########
@@ -1722,6 +1722,9 @@ def _split(self, node: fx.Node) -> relax.Var:
     def _squeeze(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", None)
+        # Support both "dim" and "dims" parameters
+        if dim is None:
+            dim = node.kwargs.get("dims", None)
Review Comment:

The logic to get `dim` can be simplified into a single line by chaining
`dict.get` calls. This makes the code more concise and easier to read.
```suggestion
        dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", node.kwargs.get("dims"))
```
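
For reference, the chained lookup resolves as follows (hypothetical `kwargs`
dicts, not taken from the PR). The inner `get` is evaluated eagerly even when
`"dim"` is present, which is harmless for a plain dict lookup:
```python
kwargs = {"dims": [1]}
assert kwargs.get("dim", kwargs.get("dims")) == [1]  # falls back to "dims"

kwargs = {"dim": 0, "dims": [1]}
assert kwargs.get("dim", kwargs.get("dims")) == 0    # "dim" takes priority
```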
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]