This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 6248b5db43 [Relax][Torch] Fixed issues related to sum op when without dim and keep dim (#18583)
6248b5db43 is described below
commit 6248b5db43505fbcfb13cc289d11877d5d2649e8
Author: Nguyen Duy Loc <[email protected]>
AuthorDate: Sat Dec 13 14:29:23 2025 +0700
[Relax][Torch] Fixed issues related to sum op when without dim and keep dim (#18583)
## Issue 1: Without Dim
### Summary:
In the _sum function (BaseFXGraphImporter), after retrieve_args the value of args[1] is [], and this empty list is still passed into relax.op.sum, so the result is incorrect.
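An empty axis list means "reduce over no axes", which is not the same as axis=None ("reduce over all axes"). NumPy exhibits the same distinction; the snippet below is only an analogy for the Relax semantics, not TVM code:
```
import numpy as np

x = np.ones((2, 3), dtype="float32")

# Reducing over an empty tuple of axes leaves the tensor unchanged...
print(np.sum(x, axis=()).shape)    # (2, 3)
# ...while axis=None reduces over every axis down to a scalar.
print(np.sum(x, axis=None).shape)  # ()
```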
### Steps to Reproduce
- Module
```
class SumWithoutDim(nn.Module):
    def forward(self, x):
        return torch.sum(x)
```
- Imported Relax module (actual):
```
class Module:
    def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tuple(R.Tensor((2, 3), dtype="float32")):
        with R.dataflow():
            lv: R.Tensor((2, 3), dtype="float32") = R.sum(x, axis=[], keepdims=False)
            gv: R.Tuple(R.Tensor((2, 3), dtype="float32")) = (lv,)
            R.output(gv)
        return gv
```
- Result:
  - Input: tensor([[1., 1., 1.], [1., 1., 1.]])
  - Torch output: tensor(6.), shape torch.Size([])
  - TVM output: [[1. 1. 1.] [1. 1. 1.]], shape (2, 3)
### Expected
```
class Module:
    def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tuple(R.Tensor((), dtype="float32")):
        with R.dataflow():
            lv: R.Tensor((), dtype="float32") = R.sum(x, axis=None, keepdims=False)
            gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv,)
            R.output(gv)
        return gv
```
- Result: TVM output: 6.0; TVM output shape: ()
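A minimal way to inspect the imported IR for this case (a sketch assuming the standard torch.export + from_exported_program flow used by these tests; running the module through the Relax VM additionally needs the usual build step, omitted here):
```
import torch
from tvm.relax.frontend.torch import from_exported_program

class SumWithoutDim(torch.nn.Module):
    def forward(self, x):
        return torch.sum(x)

x = torch.ones(2, 3)
print(torch.sum(x))  # tensor(6.)

# Import the exported program and print the generated Relax IR.
# Before the fix the module contains R.sum(x, axis=[], ...);
# after the fix it contains R.sum(x, axis=None, ...).
ep = torch.export.export(SumWithoutDim(), (x,))
mod = from_exported_program(ep)
mod.show()
```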
## Issue 2: Keep Dim
### Summary:
In the _sum function (BaseFXGraphImporter), the keepdim value was previously read only from node.kwargs and was not passed into relax.op.sum. Now keepdim is also taken from args[2] when it is given positionally, and it is passed through to relax.op.sum.
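Both places must be checked because keepdim can reach the FX node positionally or as a keyword; a minimal illustration using standard torch.sum semantics:
```
import torch

x = torch.ones(2, 3)

# keepdim passed positionally (typically recorded in node.args)
print(torch.sum(x, 1, True).shape)              # torch.Size([2, 1])
# keepdim passed as a keyword (typically recorded in node.kwargs)
print(torch.sum(x, dim=1, keepdim=True).shape)  # torch.Size([2, 1])
```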
### Steps to Reproduce
- Module
```
class SumKeepDim(nn.Module):
    def forward(self, x):
        return torch.sum(x, dim=1, keepdim=True)
```
- Imported Relax module (actual):
```
class Module:
    def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tuple(R.Tensor((2,), dtype="float32")):
        with R.dataflow():
            lv: R.Tensor((2,), dtype="float32") = R.sum(x, axis=[1], keepdims=False)
            gv: R.Tuple(R.Tensor((2,), dtype="float32")) = (lv,)
            R.output(gv)
        return gv
```
- Result:
  - Input: tensor([[1., 1., 1.], [1., 1., 1.]])
  - Torch output: tensor([[3.], [3.]]), shape torch.Size([2, 1])
  - TVM VM output: [3. 3.], shape (2,)
### Expected
```
class Module:
    def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tuple(R.Tensor((2, 1), dtype="float32")):
        with R.dataflow():
            lv: R.Tensor((2, 1), dtype="float32") = R.sum(x, axis=[1], keepdims=True)
            gv: R.Tuple(R.Tensor((2, 1), dtype="float32")) = (lv,)
            R.output(gv)
        return gv
```
- Result: TVM output: [[3.] [3.]]; TVM output shape: (2, 1)
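In short, the fix normalizes dim and keepdim before calling relax.op.sum. Stripped of the importer plumbing, the logic amounts to something like the sketch below (not the exact TVM code; see the diff that follows):
```
def normalize_sum_args(args, kwargs):
    # dim may be positional, a keyword, or absent entirely
    dim = args[1] if len(args) > 1 else kwargs.get("dim", None)
    # an empty dim list (torch.sum without dim) means "reduce over all axes"
    if isinstance(dim, (list, tuple)) and len(dim) == 0:
        dim = None
    # keepdim may likewise be positional or a keyword
    keepdim = args[2] if len(args) > 2 else kwargs.get("keepdim", False)
    return args[0], dim, keepdim

# torch.sum(x)                        -> (x, None, False)
# torch.sum(x, (2, 1), keepdim=True)  -> (x, (2, 1), True)
```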
---
.../frontend/torch/base_fx_graph_translator.py | 10 +++--
.../relax/test_frontend_from_exported_program.py | 48 +++++++++++++++++++---
2 files changed, 48 insertions(+), 10 deletions(-)
diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 47eb666210..f7d54a6216 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -1628,10 +1628,12 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
 
     def _sum(self, node: fx.Node) -> relax.Var:
         args = self.retrieve_args(node)
-        keepdim = node.kwargs["keepdim"] if "keepdim" in node.kwargs else False
-        if len(args) == 1:
-            return self.block_builder.emit(relax.op.sum(args[0], keepdims=keepdim))
-        return self.block_builder.emit(relax.op.sum(args[0], args[1]))
+        x = args[0]
+        dim = args[1] if len(node.args) > 1 else node.kwargs.get("dim", None)
+        if isinstance(dim, (list, tuple)) and len(dim) == 0:
+            dim = None
+        keepdim = args[2] if len(node.args) > 2 else node.kwargs.get("keepdim", False)
+        return self.block_builder.emit(relax.op.sum(x, dim, keepdims=keepdim))
 
     def _var(self, node: fx.Node) -> relax.Var:
         args = self.retrieve_args(node)
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 01e16e7564..4a84b50cc9 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -4945,6 +4945,14 @@ def test_sum():
         def forward(self, x):
             return torch.sum(x, (2, 1))
 
+    class SumKeepDim(Module):
+        def forward(self, x):
+            return torch.sum(x, (2, 1), keepdim=True)
+
+    class SumWithoutDim(Module):
+        def forward(self, x):
+            return torch.sum(x)
+
     @tvm.script.ir_module
     class expected1:
         @R.function
@@ -4958,8 +4966,36 @@ def test_sum():
                 R.output(gv)
             return gv
 
+    @tvm.script.ir_module
+    class expected2:
+        @R.function
+        def main(
+            inp_0: R.Tensor((1, 2, 3, 4), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 1, 1, 4), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((1, 1, 1, 4), dtype="float32") = R.sum(
+                    inp_0, axis=[2, 1], keepdims=True
+                )
+                gv: R.Tuple(R.Tensor((1, 1, 1, 4), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    @tvm.script.ir_module
+    class expected3:
+        @R.function
+        def main(
+            inp_0: R.Tensor((1, 2, 3, 4), dtype="float32")
+        ) -> R.Tuple(R.Tensor((), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((), dtype="float32") = R.sum(inp_0, axis=None, keepdims=False)
+                gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
     example_args = (torch.randn(1, 2, 3, 4, dtype=torch.float32),)
     verify_model(Sum(), example_args, {}, expected1)
+    verify_model(SumKeepDim(), example_args, {}, expected2)
+    verify_model(SumWithoutDim(), example_args, {}, expected3)
 
 
 def test_argmax_argmin():
@@ -7840,7 +7876,7 @@ def test_cross_entropy():
     @tvm.script.ir_module
     class Expected1:
         @R.function
-        def main(x: R.Tensor((4, 3), dtype="float32")) -> R.Tuple(R.Tensor((4,), dtype="float32")):
+        def main(x: R.Tensor((4, 3), dtype="float32")) -> R.Tuple(R.Tensor((), dtype="float32")):
             with R.dataflow():
                 lv: R.Tensor((4, 3), dtype="float32") = R.astype(x, dtype="float32")
                 lv1: R.Tensor((4, 3), dtype="float32") = R.nn.log_softmax(lv, axis=1)
@@ -7863,11 +7899,11 @@ def test_cross_entropy():
                 lv12: R.Tensor((4,), dtype="bool") = R.not_equal(
                     R.const([0, 1, 2, 1], dtype="int64"), R.const(-100, "int64")
                 )
-                lv13: R.Tensor((4,), dtype="bool") = R.sum(lv12, axis=[], keepdims=False)
-                lv14: R.Tensor((4,), dtype="float32") = R.astype(lv13, dtype="float32")
-                lv15: R.Tensor((4,), dtype="float32") = R.sum(lv11, axis=[], keepdims=False)
-                lv16: R.Tensor((4,), dtype="float32") = R.divide(lv15, lv14)
-                gv: R.Tuple(R.Tensor((4,), dtype="float32")) = (lv16,)
+                lv13: R.Tensor((), dtype="bool") = R.sum(lv12, axis=None, keepdims=False)
+                lv14: R.Tensor((), dtype="float32") = R.astype(lv13, dtype="float32")
+                lv15: R.Tensor((), dtype="float32") = R.sum(lv11, axis=None, keepdims=False)
+                lv16: R.Tensor((), dtype="float32") = R.divide(lv15, lv14)
+                gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv16,)
                 R.output(gv)
             return gv
 