This is an automated email from the ASF dual-hosted git repository.

mshr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 6d46fa7d1e [Relax][PyTorch] Add randn.default and randn_like.default support (#18815)
6d46fa7d1e is described below

commit 6d46fa7d1ee992e46b60421cff7b3e49db389bb7
Author: Guan-Ming (Wesley) Chiu <[email protected]>
AuthorDate: Tue Feb 24 21:28:52 2026 +0800

    [Relax][PyTorch] Add randn.default and randn_like.default support (#18815)
    
    Why
    PyTorch models using torch.randn() or torch.randn_like() fail to convert
    via from_exported_program (part of #18476).
    How
    
    - Add _randn and _randn_like handlers that emit constant tensors sampled
      from a standard normal distribution at conversion time
    - Add tests verifying conversion produces correct output shape and dtype
    
    Signed-off-by: Guan-Ming Chiu <[email protected]>
---
 .../frontend/torch/exported_program_translator.py  | 23 +++++++++++++++++
 .../relax/test_frontend_from_exported_program.py   | 30 ++++++++++++++++++++++
 2 files changed, 53 insertions(+)

diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 233cba8df9..39595a9f00 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -991,6 +991,27 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             relax.op.hamming_window(window_size, periodic, alpha, beta, dtype)
         )
 
+    def _randn(self, node: fx.Node) -> relax.Var:
+        """Convert ``aten.randn`` into a constant tensor sampled at conversion time.
+
+        The values are drawn once from a standard normal distribution using
+        NumPy's global RNG and embedded with ``relax.const``. Every execution of
+        the compiled module therefore sees the same tensor, whereas PyTorch
+        re-samples on each call.
+
+        ``args[0]`` is either a sequence of dimensions or a single dimension.
+        The dtype comes from the ``dtype`` kwarg, falling back to
+        ``torch.get_default_dtype()``.
+        """
+        import numpy as np
+
+        args = self.retrieve_args(node)
+        size = args[0] if isinstance(args[0], (list, tuple)) else (args[0],)
+        # `or` (rather than a .get default) also covers an explicit ``dtype=None`` kwarg.
+        dtype = self._convert_data_type(
+            node.kwargs.get("dtype") or torch.get_default_dtype(), self.env
+        )
+        # np.random.randn() with no dimensions returns a Python float, not an
+        # ndarray; np.asarray keeps the 0-d (scalar) case working.
+        # NOTE(review): dtypes NumPy lacks (e.g. bfloat16) will fail here.
+        data = np.asarray(np.random.randn(*size), dtype=dtype)
+        return self.block_builder.emit(relax.const(data, dtype))
+
+    def _randn_like(self, node: fx.Node) -> relax.Var:
+        """Convert ``aten.randn_like`` into a constant tensor sampled at conversion time.
+
+        The constant takes the input tensor's shape. Its dtype is the input's
+        dtype unless a ``dtype`` kwarg overrides it. The values are drawn once
+        from a standard normal distribution using NumPy's global RNG and
+        embedded with ``relax.const``, so they do not change between runs of the
+        compiled module.
+        """
+        import numpy as np
+
+        x = self.env[node.args[0]]
+        x_sinfo = x.struct_info
+        # int() assumes every dimension is static; symbolic dims are presumably
+        # rejected here -- TODO confirm.
+        shape = [int(s) for s in x_sinfo.shape]
+        dtype = self._convert_data_type(node.kwargs.get("dtype") or x_sinfo.dtype, self.env)
+        # np.random.randn() with no dimensions returns a Python float, not an
+        # ndarray; np.asarray keeps the 0-d (scalar input) case working.
+        data = np.asarray(np.random.randn(*shape), dtype=dtype)
+        return self.block_builder.emit(relax.const(data, dtype))
+
     def _zeros(self, node: fx.Node) -> relax.Var:
         args = self.retrieve_args(node)
         size = relax.ShapeExpr(args[0] if isinstance(args[0], (list, tuple)) 
else (args[0],))
@@ -1484,6 +1505,8 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             "new_zeros.default": self._new_zeros,
             "one_hot.default": self._one_hot,
             "ones.default": self._ones,
+            "randn.default": self._randn,
+            "randn_like.default": self._randn_like,
             "ones_like.default": lambda node: self.block_builder.emit(
                 relax.op.ones_like(self.env[node.args[0]])
             ),
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 0b5cb1f777..6bab158c08 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -6796,6 +6796,36 @@ def test_zeros_like():
     verify_model(ZerosLike(), example_args, {}, Expected)
 
 
+def test_randn():
+    """Converting a model that adds torch.randn(5, 3) yields a (5, 3) float32 output."""
+
+    class Randn(Module):
+        def forward(self, input):
+            return input + torch.randn(5, 3)
+
+    program = export(Randn(), args=(torch.rand(5, 3, dtype=torch.float32),))
+    out_sinfo = from_exported_program(program)["main"].ret_struct_info.fields[0]
+    assert out_sinfo.shape[0] == 5
+    assert out_sinfo.shape[1] == 3
+    assert out_sinfo.dtype == "float32"
+
+
+def test_randn_like():
+    """Converting a model that adds torch.randn_like(input) keeps the input's shape and dtype."""
+
+    class RandnLike(Module):
+        def forward(self, input):
+            return input + torch.randn_like(input)
+
+    program = export(RandnLike(), args=(torch.rand(4, 6, dtype=torch.float32),))
+    out_sinfo = from_exported_program(program)["main"].ret_struct_info.fields[0]
+    assert out_sinfo.shape[0] == 4
+    assert out_sinfo.shape[1] == 6
+    assert out_sinfo.dtype == "float32"
+
+
 def test_type_as():
     class TypeAs(Module):
         def forward(self, input, other):

Reply via email to