This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 6e8c367fd7 [Relax] Add torch exported program ingestion capability for Tensor.detach(), Tensor.copy_, and aten.lift_fresh_copy (#17723)
6e8c367fd7 is described below
commit 6e8c367fd72817f2f69701b5c6a71e41362e6ebc
Author: Hugo Latendresse <[email protected]>
AuthorDate: Mon Mar 10 20:02:08 2025 -0400
[Relax] Add torch exported program ingestion capability for Tensor.detach(), Tensor.copy_, and aten.lift_fresh_copy (#17723)
* detach and copy
* copy_ implementation. Unit test passes
* restore test_frontend
* don't specify syspath
* todo for _detach()
* Black formatter
* black formatting with version 22.12.0
* cleanup unit tests and ran Black Formatter with version 22
* restore unmodified frontend test
* fix vm in assert_torch_output_vs_tvm_from_exported_to_cuda
* lint with Python Black formatter
* update todo
* update explanation for _detach
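For readers unfamiliar with the flow, a minimal end-to-end sketch of how these ops reach the importer (the module, shapes, and variable names below are illustrative, not part of this patch):

    import torch
    from torch.export import export
    from tvm.relax.frontend.torch import from_exported_program

    class Example(torch.nn.Module):  # hypothetical module, for illustration only
        def __init__(self):
            super().__init__()
            self.register_buffer("buf", torch.zeros(2, 2))

        def forward(self, x):
            self.buf.copy_(x)             # exported as aten.copy_.default
            return x.detach() + self.buf  # exported as aten.detach.default

    exported = export(Example(), (torch.rand(2, 2),))
    mod = from_exported_program(exported)  # Relax IRModule; both ops import as identity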
---
.../frontend/torch/base_fx_graph_translator.py | 12 ++++++++
.../frontend/torch/exported_program_translator.py | 4 +++
tests/python/relax/test_from_exported_to_cuda.py | 33 +++++++++++++++++++++-
3 files changed, 48 insertions(+), 1 deletion(-)
diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 8b771b5d2f..d5cad2381b 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -994,7 +994,19 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
 
     ########## Creation ##########
 
+    def _detach(self, node: fx.Node) -> relax.Var:
+        # There is no way to implement detach() such that the output shares
+        # the same memory as the input. In-place operations are not supported
+        # by the translator, and therefore we simply return the input unchanged.
+        return self.env[node.args[0]]
+
+    def _copy_(self, node: fx.Node) -> relax.Var:
+        # Copies the source tensor into the destination tensor.
+        # In TVM, that means simply returning the source tensor.
+        return self.env[node.args[1]]
+
     def _to_copy(self, node: fx.Node) -> relax.Var:
+        # Returns a copy of the input tensor
         import torch  # type: ignore
 
         x = self.env[node.args[0]]
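A note on argument order in _copy_: aten::copy_ has the schema copy_(Tensor self, Tensor src, ...), so for a call dst.copy_(src) the FX node carries args == (dst, src). A sketch of what the handler sees (node names hypothetical):

    # node.target == torch.ops.aten.copy_.default
    # node.args == (dst_node, src_node)
    src = self.env[node.args[1]]  # Relax expression for `src`, returned as the result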
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index f3c0a64676..4ff31ea1d7 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -295,6 +295,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             # tensor manipulation
             "cat.default": self._cat,
             "concat.default": self._cat,
+            "copy_.default": self._copy_,
             "cumsum.default": self._cumsum,
             "expand.default": self._expand,
             "permute.default": self._permute,
@@ -313,6 +314,9 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             "reshape.default": self._reshape,
             # tensor creation
             "_to_copy.default": self._to_copy,
+            "lift_fresh_copy.default": self._to_copy,
+            "detach.default": self._detach,
+            "detach_.default": self._detach,
             "arange.start": self._arange,
             "contiguous.default": lambda node: self.env[node.args[0]],  # no-op
             "clone.default": lambda node: self.env[node.args[0]],
diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py
index 69daab36a5..bd4bdcf617 100644
--- a/tests/python/relax/test_from_exported_to_cuda.py
+++ b/tests/python/relax/test_from_exported_to_cuda.py
@@ -20,6 +20,7 @@ from tvm import relax
 import tvm.testing
 import numpy as np
 import torch
+from torch import nn
 from torch.export import export
 from tvm.relax.frontend.torch import from_exported_program
 from torch.nn import Softmax, Upsample
@@ -55,6 +56,24 @@ def assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, tar
     np.testing.assert_allclose(actual=actual, desired=desired, rtol=1e-5, atol=1e-5)
 
 
+@tvm.testing.parametrize_targets("cuda")
+def test_copy_(target, dev):
+    class CopyTester(nn.Module):
+        def __init__(self, size):
+            super().__init__()
+            self.register_buffer("buffer", torch.zeros(size))
+
+        def forward(self, x):
+            self.buffer.copy_(x)
+
+            return x * 3 + self.buffer * 5
+
+    size = (2, 2)
+    raw_data = np.random.rand(*size).astype(np.float32)
+    torch_module = CopyTester(size).eval()
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
 @tvm.testing.parametrize_targets("cuda")
 def test_upsample_with_size(target, dev):
     """
@@ -72,6 +91,19 @@ def test_upsample_with_size(target, dev):
     assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
 
 
+@tvm.testing.parametrize_targets("cuda")
+def test_detach_no_change(target, dev):
+    # In TVM, detach() is just identity
+    class DetachTester(nn.Module):
+        def forward(self, x):
+            detached = x.detach()
+            return detached
+
+    raw_data = np.ones((2, 2)).astype(np.float32)
+    torch_module = DetachTester().eval()
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
 @tvm.testing.parametrize_targets("cuda")
 def test_upsample_with_scale_factor(target, dev):
     """
@@ -87,7 +119,6 @@ def test_upsample_with_scale_factor(target, dev):
     )
 
     raw_data = np.random.rand(batch_size, channels, height, width).astype("float32")
-
     assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
 
 
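To exercise the new tests locally (assuming a CUDA-enabled TVM build, since both tests are parametrized over the cuda target), an invocation along these lines should work:

    pytest tests/python/relax/test_from_exported_to_cuda.py -k "test_copy_ or test_detach_no_change"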