This is an automated email from the ASF dual-hosted git repository.

ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 6e8c367fd7 [Relax] Add torch exported program ingestion capability for Tensor.detach(), Tensor.copy_, and aten.lift_fresh_copy (#17723)
6e8c367fd7 is described below

commit 6e8c367fd72817f2f69701b5c6a71e41362e6ebc
Author: Hugo Latendresse <[email protected]>
AuthorDate: Mon Mar 10 20:02:08 2025 -0400

    [Relax] Add torch exported program ingestion capability for Tensor.detach(), Tensor.copy_, and aten.lift_fresh_copy (#17723)
    
    * detach and copy
    
    * copy_ implementation. Unit test passes
    
    * restore test_frontend
    
    * don't specify syspath
    
    * todo for _detach()
    
    * Black formatter
    
    * black formatting with version 22.12.0
    
    * cleanup unit tests and ran Black Formatter with version 22
    
    * restore unmodified frontend test
    
    * fix vm in assert_torch_output_vs_tvm_from_exported_to_cuda
    
    * lint with Python Black formatter
    
    * update todo
    
    * update explanation for _detach
---
 .../frontend/torch/base_fx_graph_translator.py     | 12 ++++++++
 .../frontend/torch/exported_program_translator.py  |  4 +++
 tests/python/relax/test_from_exported_to_cuda.py   | 33 +++++++++++++++++++++-
 3 files changed, 48 insertions(+), 1 deletion(-)

diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 8b771b5d2f..d5cad2381b 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -994,7 +994,19 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
 
     ########## Creation ##########
 
+    def _detach(self, node: fx.Node) -> relax.Var:
+        # There is no way to implement detach() such that the output shares
+        # the same memory as the input. In-place operations are not supported
+        # by the translator, so detach() is an identity: return the input as-is.
+        return self.env[node.args[0]]
+
+    def _copy_(self, node: fx.Node) -> relax.Var:
+        # Copies the source tensor's data into the destination tensor.
+        # In TVM, that means simply returning the source tensor.
+        return self.env[node.args[1]]
+
     def _to_copy(self, node: fx.Node) -> relax.Var:
+        # Returns a copy of the input tensor
         import torch  # type: ignore
 
         x = self.env[node.args[0]]
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index f3c0a64676..4ff31ea1d7 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -295,6 +295,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             # tensor manipulation
             "cat.default": self._cat,
             "concat.default": self._cat,
+            "copy_.default": self._copy_,
             "cumsum.default": self._cumsum,
             "expand.default": self._expand,
             "permute.default": self._permute,
@@ -313,6 +314,9 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             "reshape.default": self._reshape,
             # tensor creation
             "_to_copy.default": self._to_copy,
+            "lift_fresh_copy.default": self._to_copy,
+            "detach.default": self._detach,
+            "detach_.default": self._detach,
             "arange.start": self._arange,
             "contiguous.default": lambda node: self.env[node.args[0]],  # no-op
             "clone.default": lambda node: self.env[node.args[0]],
diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py
index 69daab36a5..bd4bdcf617 100644
--- a/tests/python/relax/test_from_exported_to_cuda.py
+++ b/tests/python/relax/test_from_exported_to_cuda.py
@@ -20,6 +20,7 @@ from tvm import relax
 import tvm.testing
 import numpy as np
 import torch
+from torch import nn
 from torch.export import export
 from tvm.relax.frontend.torch import from_exported_program
 from torch.nn import Softmax, Upsample
@@ -55,6 +56,24 @@ def assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, tar
     np.testing.assert_allclose(actual=actual, desired=desired, rtol=1e-5, atol=1e-5)
 
 
+@tvm.testing.parametrize_targets("cuda")
+def test_copy_(target, dev):
+    class CopyTester(nn.Module):
+        def __init__(self, size):
+            super().__init__()
+            self.register_buffer("buffer", torch.zeros(size))
+
+        def forward(self, x):
+            self.buffer.copy_(x)
+
+            return x * 3 + self.buffer * 5
+
+    size = (2, 2)
+    raw_data = np.random.rand(*size).astype(np.float32)
+    torch_module = CopyTester(size).eval()
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
 @tvm.testing.parametrize_targets("cuda")
 def test_upsample_with_size(target, dev):
     """
@@ -72,6 +91,19 @@ def test_upsample_with_size(target, dev):
     assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
 
 
+@tvm.testing.parametrize_targets("cuda")
+def test_detach_no_change(target, dev):
+    # In TVM, detach() is just identity
+    class DetachTester(nn.Module):
+        def forward(self, x):
+            detached = x.detach()
+            return detached
+
+    raw_data = np.ones((2, 2)).astype(np.float32)
+    torch_module = DetachTester().eval()
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
 @tvm.testing.parametrize_targets("cuda")
 def test_upsample_with_scale_factor(target, dev):
     """
@@ -87,7 +119,6 @@ def test_upsample_with_scale_factor(target, dev):
     )
 
     raw_data = np.random.rand(batch_size, channels, height, width).astype("float32")
-
     assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
 
 

Reply via email to