This is an automated email from the ASF dual-hosted git repository.

mshr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new cdebac1961 [Relax][PyTorch] Fix torch 2.6 compatibility issues (#17807)
cdebac1961 is described below

commit cdebac19617c61f00b7296d0951136f8dfdafe50
Author: Masahiro Hiramori <[email protected]>
AuthorDate: Fri Apr 4 13:16:43 2025 +0900

    [Relax][PyTorch] Fix torch 2.6 compatibility issues (#17807)
    
    * fix test_flatten
    
    * re-enable test_split
    
    * fix test_to_copy
    
    * re-enable test_batchnorm2d
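    
    For reference, a minimal sketch of the shape arithmetic performed by the
    shared _flatten_impl helper (standalone Python, independent of TVM; the
    example shape below is made up for illustration):
    
        from functools import reduce
        
        def flatten_shape(shape, start_dim=0, end_dim=-1):
            # Normalize negative dims, then collapse [start_dim, end_dim]
            # into one product dimension, mirroring torch.flatten semantics.
            start = start_dim if start_dim >= 0 else len(shape) + start_dim
            end = end_dim if end_dim >= 0 else len(shape) + end_dim
            flattened = reduce(lambda a, b: a * b, shape[start : end + 1])
            return list(shape[:start]) + [flattened] + list(shape[end + 1 :])
        
        assert flatten_shape((1, 2, 3, 4), start_dim=1) == [1, 24]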
---
 .../frontend/torch/base_fx_graph_translator.py     | 34 ++++++++++++++++++++++
 .../frontend/torch/exported_program_translator.py  |  4 +++
 python/tvm/relax/frontend/torch/fx_translator.py   | 18 ------------
 .../relax/test_frontend_from_exported_program.py   | 19 ++----------
 4 files changed, 40 insertions(+), 35 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 890f925079..d99411bd56 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -19,6 +19,7 @@
 # pylint: disable=import-outside-toplevel
 """Base class for PyTorch FX Graph importer."""
 import abc
+from functools import reduce
 import math
 from typing import Callable, Dict, Optional, Tuple, Union
 
@@ -1018,6 +1019,24 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
         other_shape = self.shape_of(args[1])  # the shape of 'other'
         return self.block_builder.emit(relax.op.broadcast_to(data, other_shape))
 
+    def _flatten_impl(self, x, start_dim, end_dim) -> relax.Var:
+        shape = self.shape_of(x)
+        start_dim = start_dim if start_dim >= 0 else len(shape) + start_dim
+        end_dim = end_dim if end_dim >= 0 else len(shape) + end_dim
+        flattened = reduce(lambda x, y: x * y, [shape[i] for i in range(start_dim, end_dim + 1)])
+        new_shape = (
+            [shape[i] for i in range(0, start_dim)]
+            + [flattened]
+            + [shape[i] for i in range(end_dim + 1, len(shape))]
+        )
+        return self.block_builder.emit(relax.op.reshape(x, new_shape))
+
+    def _flatten(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        start_dim = node.args[1] if len(node.args) >= 2 else node.kwargs.get("start_dim", 0)
+        end_dim = node.args[2] if len(node.args) == 3 else node.kwargs.get("end_dim", -1)
+        return self._flatten_impl(x, start_dim, end_dim)
+
     def _flip(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         dims = node.args[1] if len(node.args) > 1 else node.kwargs.get("dims", None)
@@ -1233,6 +1252,21 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
             )
         )
 
+    ########## DataType ##########
+
+    def _to(self, node: fx.Node) -> relax.Var:
+        import torch
+
+        x = self.env[node.args[0]]
+        if len(node.args) == 2:
+            if isinstance(node.args[1], torch.dtype):
+                dtype = BaseFXGraphImporter._convert_data_type(node.args[1], self.env)
+                return self.block_builder.emit(relax.op.astype(x, dtype))
+        elif "dtype" in node.kwargs:
+            dtype = BaseFXGraphImporter._convert_data_type(node.kwargs["dtype"], self.env)
+            return self.block_builder.emit(relax.op.astype(x, dtype))
+        return x
+
     ########## Others ##########
 
     def _getitem(self, node: fx.Node) -> relax.Var:
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 2e7c682aa3..26121ecdea 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -377,6 +377,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             "cumprod.default": self._cumprod,
             "expand.default": self._expand,
             "expand_as.default": self._expand_as,
+            "flatten.using_ints": self._flatten,
             "flip.default": self._flip,
             "gather.default": self._gather,
             "permute.default": self._permute,
@@ -411,6 +412,9 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             "lift_fresh_copy.default": self._to_copy,
             "new_ones.default": self._new_ones,
             "one_hot.default": self._one_hot,
+            # datatype
+            "to.dtype": self._to,
+            "to.dtype_layout": self._to,
             # other
             "getitem": self._getitem,
         }
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index 3ddf919c2e..e79c1dbc48 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -415,24 +415,6 @@ class TorchFXImporter(BaseFXGraphImporter):
         dim = node.args[2] if len(node.args) > 2 else node.kwargs.get("dim", 0)
         return self.block_builder.emit(relax.op.split(x, chunks, dim))
 
-    def _flatten_impl(self, x, start_dim, end_dim) -> relax.Var:
-        shape = self.shape_of(x)
-        start_dim = start_dim if start_dim >= 0 else len(shape) + start_dim
-        end_dim = end_dim if end_dim >= 0 else len(shape) + end_dim
-        flattened = reduce(lambda x, y: x * y, [shape[i] for i in range(start_dim, end_dim + 1)])
-        new_shape = (
-            [shape[i] for i in range(0, start_dim)]
-            + [flattened]
-            + [shape[i] for i in range(end_dim + 1, len(shape))]
-        )
-        return self.block_builder.emit(relax.op.reshape(x, new_shape))
-
-    def _flatten(self, node: fx.Node) -> relax.Var:
-        x = self.env[node.args[0]]
-        start_dim = node.args[1] if len(node.args) >= 2 else node.kwargs.get("start_dim", 0)
-        end_dim = node.args[2] if len(node.args) == 3 else node.kwargs.get("end_dim", -1)
-        return self._flatten_impl(x, start_dim, end_dim)
-
     def _flatten_module(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         module = self.named_modules[node.target]
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 2175f9aa39..cc2f669d32 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -1021,10 +1021,6 @@ def test_binary3():
     verify_model(Min1(), example_args1, {}, expected_min1)
 
 
[email protected](
-    version.parse(torch_version) >= version.parse("2.6.0"),
-    reason="Tests not compatible with PyTorch >= 2.6",
-)
 def test_batchnorm2d():
     class BatchNorm2d(Module):
         def __init__(self):
@@ -2702,10 +2698,6 @@ def test_expand():
     verify_model(Expand2(), example_args, {}, expected1)
 
 
[email protected](
-    version.parse(torch_version) >= version.parse("2.6.0"),
-    reason="Tests not compatible with PyTorch >= 2.6",
-)
 def test_flatten():
     class Flatten(Module):
         def __init__(self):
@@ -2907,10 +2899,6 @@ def test_select_slice():
     verify_model(Slice2(), example_args, {}, expected2)
 
 
[email protected](
-    version.parse(torch_version) >= version.parse("2.6.0"),
-    reason="Tests not compatible with PyTorch >= 2.6",
-)
 def test_split():
     class Chunk(Module):
         def forward(self, input):
@@ -3340,10 +3328,6 @@ def test_new_ones():
     verify_model(NewOnes(), example_args, {}, expected1)
 
 
[email protected](
-    version.parse(torch_version) >= version.parse("2.6.0"),
-    reason="Tests not compatible with PyTorch >= 2.6",
-)
 def test_to_copy():
     # float
     class ToFloat(Module):
@@ -3394,7 +3378,8 @@ def test_to_copy():
         ) -> R.Tuple(R.Tensor((1, 2, 3, 4), dtype="float32")):
             # block 0
             with R.dataflow():
-                gv: R.Tuple(R.Tensor((1, 2, 3, 4), dtype="float32")) = (x,)
+                lv: R.Tensor((1, 2, 3, 4), dtype="float32") = R.astype(x, dtype="float32")
+                gv: R.Tuple(R.Tensor((1, 2, 3, 4), dtype="float32")) = (lv,)
                 R.output(gv)
             return gv
 

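For anyone wanting to try the newly supported paths, here is a minimal
end-to-end sketch (the module, shapes, and dtypes are illustrative, and it
assumes a PyTorch version with torch.export plus TVM's Relax
exported-program frontend):

    import torch
    from torch.export import export
    from tvm.relax.frontend.torch import from_exported_program

    class FlattenAndCast(torch.nn.Module):
        def forward(self, x):
            # Exercises flatten.using_ints and to.dtype, the two aten ops
            # this commit wires up in the exported-program importer.
            return torch.flatten(x, start_dim=1).to(torch.float32)

    example_args = (torch.randn(1, 2, 3, 4, dtype=torch.float16),)
    exported_program = export(FlattenAndCast(), example_args)
    mod = from_exported_program(exported_program)
    mod.show()  # the Relax module should contain R.reshape and R.astype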