This is an automated email from the ASF dual-hosted git repository.
mshr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new cdebac1961 [Relax][PyTorch] Fix torch 2.6 compatibility issues (#17807)
cdebac1961 is described below
commit cdebac19617c61f00b7296d0951136f8dfdafe50
Author: Masahiro Hiramori <[email protected]>
AuthorDate: Fri Apr 4 13:16:43 2025 +0900
[Relax][PyTorch] Fix torch 2.6 compatibility issues (#17807)
* fix test_flatten
* re-enable test_split
* fix test_to_copy
* re-enable test_batchnorm2d
---
.../frontend/torch/base_fx_graph_translator.py | 34 ++++++++++++++++++++++
.../frontend/torch/exported_program_translator.py | 4 +++
python/tvm/relax/frontend/torch/fx_translator.py | 18 ------------
.../relax/test_frontend_from_exported_program.py | 19 ++----------
4 files changed, 40 insertions(+), 35 deletions(-)
diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 890f925079..d99411bd56 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -19,6 +19,7 @@
# pylint: disable=import-outside-toplevel
"""Base class for PyTorch FX Graph importer."""
import abc
+from functools import reduce
import math
from typing import Callable, Dict, Optional, Tuple, Union
@@ -1018,6 +1019,24 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
other_shape = self.shape_of(args[1]) # the shape of 'other'
return self.block_builder.emit(relax.op.broadcast_to(data, other_shape))
+ def _flatten_impl(self, x, start_dim, end_dim) -> relax.Var:
+ shape = self.shape_of(x)
+ start_dim = start_dim if start_dim >= 0 else len(shape) + start_dim
+ end_dim = end_dim if end_dim >= 0 else len(shape) + end_dim
+ flattened = reduce(lambda x, y: x * y, [shape[i] for i in range(start_dim, end_dim + 1)])
+ new_shape = (
+ [shape[i] for i in range(0, start_dim)]
+ + [flattened]
+ + [shape[i] for i in range(end_dim + 1, len(shape))]
+ )
+ return self.block_builder.emit(relax.op.reshape(x, new_shape))
+
+ def _flatten(self, node: fx.Node) -> relax.Var:
+ x = self.env[node.args[0]]
+ start_dim = node.args[1] if len(node.args) >= 2 else node.kwargs.get("start_dim", 0)
+ end_dim = node.args[2] if len(node.args) == 3 else node.kwargs.get("end_dim", -1)
+ return self._flatten_impl(x, start_dim, end_dim)
+
def _flip(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
dims = node.args[1] if len(node.args) > 1 else node.kwargs.get("dims", None)
@@ -1233,6 +1252,21 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
)
)
+ ########## DataType ##########
+
+ def _to(self, node: fx.Node) -> relax.Var:
+ import torch
+
+ x = self.env[node.args[0]]
+ if len(node.args) == 2:
+ if isinstance(node.args[1], torch.dtype):
+ dtype = BaseFXGraphImporter._convert_data_type(node.args[1], self.env)
+ return self.block_builder.emit(relax.op.astype(x, dtype))
+ elif "dtype" in node.kwargs:
+ dtype = BaseFXGraphImporter._convert_data_type(node.kwargs["dtype"], self.env)
+ return self.block_builder.emit(relax.op.astype(x, dtype))
+ return x
+
########## Others ##########
def _getitem(self, node: fx.Node) -> relax.Var:
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 2e7c682aa3..26121ecdea 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -377,6 +377,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
"cumprod.default": self._cumprod,
"expand.default": self._expand,
"expand_as.default": self._expand_as,
+ "flatten.using_ints": self._flatten,
"flip.default": self._flip,
"gather.default": self._gather,
"permute.default": self._permute,
@@ -411,6 +412,9 @@ class ExportedProgramImporter(BaseFXGraphImporter):
"lift_fresh_copy.default": self._to_copy,
"new_ones.default": self._new_ones,
"one_hot.default": self._one_hot,
+ # datatype
+ "to.dtype": self._to,
+ "to.dtype_layout": self._to,
# other
"getitem": self._getitem,
}
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index 3ddf919c2e..e79c1dbc48 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -415,24 +415,6 @@ class TorchFXImporter(BaseFXGraphImporter):
dim = node.args[2] if len(node.args) > 2 else node.kwargs.get("dim", 0)
return self.block_builder.emit(relax.op.split(x, chunks, dim))
- def _flatten_impl(self, x, start_dim, end_dim) -> relax.Var:
- shape = self.shape_of(x)
- start_dim = start_dim if start_dim >= 0 else len(shape) + start_dim
- end_dim = end_dim if end_dim >= 0 else len(shape) + end_dim
- flattened = reduce(lambda x, y: x * y, [shape[i] for i in range(start_dim, end_dim + 1)])
- new_shape = (
- [shape[i] for i in range(0, start_dim)]
- + [flattened]
- + [shape[i] for i in range(end_dim + 1, len(shape))]
- )
- return self.block_builder.emit(relax.op.reshape(x, new_shape))
-
- def _flatten(self, node: fx.Node) -> relax.Var:
- x = self.env[node.args[0]]
- start_dim = node.args[1] if len(node.args) >= 2 else node.kwargs.get("start_dim", 0)
- end_dim = node.args[2] if len(node.args) == 3 else node.kwargs.get("end_dim", -1)
- return self._flatten_impl(x, start_dim, end_dim)
-
def _flatten_module(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
module = self.named_modules[node.target]
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 2175f9aa39..cc2f669d32 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -1021,10 +1021,6 @@ def test_binary3():
verify_model(Min1(), example_args1, {}, expected_min1)
[email protected](
- version.parse(torch_version) >= version.parse("2.6.0"),
- reason="Tests not compatible with PyTorch >= 2.6",
-)
def test_batchnorm2d():
class BatchNorm2d(Module):
def __init__(self):
@@ -2702,10 +2698,6 @@ def test_expand():
verify_model(Expand2(), example_args, {}, expected1)
[email protected](
- version.parse(torch_version) >= version.parse("2.6.0"),
- reason="Tests not compatible with PyTorch >= 2.6",
-)
def test_flatten():
class Flatten(Module):
def __init__(self):
@@ -2907,10 +2899,6 @@ def test_select_slice():
verify_model(Slice2(), example_args, {}, expected2)
[email protected](
- version.parse(torch_version) >= version.parse("2.6.0"),
- reason="Tests not compatible with PyTorch >= 2.6",
-)
def test_split():
class Chunk(Module):
def forward(self, input):
@@ -3340,10 +3328,6 @@ def test_new_ones():
verify_model(NewOnes(), example_args, {}, expected1)
[email protected](
- version.parse(torch_version) >= version.parse("2.6.0"),
- reason="Tests not compatible with PyTorch >= 2.6",
-)
def test_to_copy():
# float
class ToFloat(Module):
@@ -3394,7 +3378,8 @@ def test_to_copy():
) -> R.Tuple(R.Tensor((1, 2, 3, 4), dtype="float32")):
# block 0
with R.dataflow():
- gv: R.Tuple(R.Tensor((1, 2, 3, 4), dtype="float32")) = (x,)
+ lv: R.Tensor((1, 2, 3, 4), dtype="float32") = R.astype(x, dtype="float32")
+ gv: R.Tuple(R.Tensor((1, 2, 3, 4), dtype="float32")) = (lv,)
R.output(gv)
return gv