This is an automated email from the ASF dual-hosted git repository.
mshr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 6d46fa7d1e [Relax][PyTorch] Add randn.default and randn_like.default
support (#18815)
6d46fa7d1e is described below
commit 6d46fa7d1ee992e46b60421cff7b3e49db389bb7
Author: Guan-Ming (Wesley) Chiu <[email protected]>
AuthorDate: Tue Feb 24 21:28:52 2026 +0800
[Relax][PyTorch] Add randn.default and randn_like.default support (#18815)
Why
PyTorch models using torch.randn() or torch.randn_like() fail to convert
via from_exported_program (part of #18476).
How
- Add _randn and _randn_like handlers that emit constant tensors sampled
  from a standard normal distribution (np.random.randn) at conversion time
- Add tests verifying conversion produces correct output shape and dtype
Signed-off-by: Guan-Ming Chiu <[email protected]>
---
.../frontend/torch/exported_program_translator.py | 23 +++++++++++++++++
.../relax/test_frontend_from_exported_program.py | 30 ++++++++++++++++++++++
2 files changed, 53 insertions(+)
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py
b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 233cba8df9..39595a9f00 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -991,6 +991,27 @@ class ExportedProgramImporter(BaseFXGraphImporter):
relax.op.hamming_window(window_size, periodic, alpha, beta, dtype)
)
+ def _randn(self, node: fx.Node) -> relax.Var:
+ import numpy as np
+
+ args = self.retrieve_args(node)
+ size = args[0] if isinstance(args[0], (list, tuple)) else (args[0],)
+ dtype = self._convert_data_type(
+ node.kwargs.get("dtype", torch.get_default_dtype()), self.env
+ )
+ data = np.random.randn(*size).astype(dtype)
+ return self.block_builder.emit(relax.const(data, dtype))
+
+ def _randn_like(self, node: fx.Node) -> relax.Var:
+ import numpy as np
+
+ x = self.env[node.args[0]]
+ x_sinfo = x.struct_info
+ shape = [int(s) for s in x_sinfo.shape]
+ dtype = self._convert_data_type(node.kwargs.get("dtype", None) or
x_sinfo.dtype, self.env)
+ data = np.random.randn(*shape).astype(dtype)
+ return self.block_builder.emit(relax.const(data, dtype))
+
def _zeros(self, node: fx.Node) -> relax.Var:
args = self.retrieve_args(node)
size = relax.ShapeExpr(args[0] if isinstance(args[0], (list, tuple))
else (args[0],))
@@ -1484,6 +1505,8 @@ class ExportedProgramImporter(BaseFXGraphImporter):
"new_zeros.default": self._new_zeros,
"one_hot.default": self._one_hot,
"ones.default": self._ones,
+ "randn.default": self._randn,
+ "randn_like.default": self._randn_like,
"ones_like.default": lambda node: self.block_builder.emit(
relax.op.ones_like(self.env[node.args[0]])
),
diff --git a/tests/python/relax/test_frontend_from_exported_program.py
b/tests/python/relax/test_frontend_from_exported_program.py
index 0b5cb1f777..6bab158c08 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -6796,6 +6796,36 @@ def test_zeros_like():
verify_model(ZerosLike(), example_args, {}, Expected)
+def test_randn():
+ class Randn(Module):
+ def forward(self, input):
+ return input + torch.randn(5, 3)
+
+ example_args = (torch.rand(5, 3, dtype=torch.float32),)
+ exported_program = export(Randn(), args=example_args)
+ mod = from_exported_program(exported_program)
+ func = mod["main"]
+ ret_sinfo = func.ret_struct_info
+ assert ret_sinfo.fields[0].shape[0] == 5
+ assert ret_sinfo.fields[0].shape[1] == 3
+ assert ret_sinfo.fields[0].dtype == "float32"
+
+
+def test_randn_like():
+ class RandnLike(Module):
+ def forward(self, input):
+ return input + torch.randn_like(input)
+
+ example_args = (torch.rand(4, 6, dtype=torch.float32),)
+ exported_program = export(RandnLike(), args=example_args)
+ mod = from_exported_program(exported_program)
+ func = mod["main"]
+ ret_sinfo = func.ret_struct_info
+ assert ret_sinfo.fields[0].shape[0] == 4
+ assert ret_sinfo.fields[0].shape[1] == 6
+ assert ret_sinfo.fields[0].dtype == "float32"
+
+
def test_type_as():
class TypeAs(Module):
def forward(self, input, other):