This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 6f3f691a05 [Relax][PyTorch] Simplify tensor args conversion in Dynamo (#18726)
6f3f691a05 is described below

commit 6f3f691a05ab230e0f7674bf9f0522d11f269140
Author: Masahiro Hiramori <[email protected]>
AuthorDate: Sun Feb 8 21:29:48 2026 +0900

    [Relax][PyTorch] Simplify tensor args conversion in Dynamo (#18726)
    
    As per title.
---
 python/tvm/relax/frontend/torch/dynamo.py | 13 ++++---------
 1 file changed, 4 insertions(+), 9 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/dynamo.py b/python/tvm/relax/frontend/torch/dynamo.py
index dea08256de..21388dbef7 100644
--- a/python/tvm/relax/frontend/torch/dynamo.py
+++ b/python/tvm/relax/frontend/torch/dynamo.py
@@ -64,14 +64,6 @@ def relax_dynamo(pipeline: Optional[tvm.transform.Pass] = None):
             else:
                 raise ValueError(f"Unsupported type {type(nd_tensor)}")
 
-        def to_tvm_tensor(torch_tensor):
-            """A helper function to transfer a torch.tensor to Tensor."""
-            if not isinstance(torch_tensor, torch._subclasses.fake_tensor.FakeTensor):
-                return tvm.runtime.tensor(torch_tensor.numpy())
-            # Fake Tensor
-            real_tensor = torch.randn(torch_tensor.shape, dtype=torch_tensor.dtype)
-            return tvm.runtime.tensor(real_tensor.numpy())
-
         graph_module.graph.eliminate_dead_code()
 
         device = device_from_inputs(example_inputs)
@@ -139,7 +131,10 @@ def relax_dynamo(pipeline: Optional[tvm.transform.Pass] = None):
             for arg in args:
                 if arg.requires_grad:
                     arg = arg.detach()
-                vm_args.append(to_tvm_tensor(arg))
+                if isinstance(arg, torch._subclasses.fake_tensor.FakeTensor):
+                    # Materialize a real (eager) Tensor
+                    arg = torch.randn(arg.shape, dtype=arg.dtype, device=device)
+                vm_args.append(arg)
             outputs = vm["main"](*vm_args)
             return to_torch_tensor(outputs)
 

Reply via email to