This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new c5759738db [Relax][Onnx][BatchNorm] Pass momentum and training_mode into BatchNorm Operator (#18704)
c5759738db is described below

commit c5759738dba668509cce383d5c112f593c0806ed
Author: Nguyen Duy Loc <[email protected]>
AuthorDate: Mon Feb 2 23:49:13 2026 +0700

    [Relax][Onnx][BatchNorm] Pass momentum and training_mode into BatchNorm Operator (#18704)
    
    ### Description
    - The ONNX model has the attribute training_mode = False, but the Relax
    model after conversion has training = True
    - The momentum value in the Relax module does not match the ONNX model
    
    ### Steps to Reproduce
    <img width="600" height="400" alt="BatchNorm"
    src="https://github.com/user-attachments/assets/2f0ca26b-e83b-4ab8-ab06-a537802af6de"
    />
    
    - Relax model:
    ```
    class Module:
        def main(X: R.Tensor((2, 3, 4, 4), dtype="float32")) -> R.Tensor((2, 3, 4, 4), dtype="float32"):
            R.func_attr({"num_input": 1})
            with R.dataflow():
                lv: R.Tuple(R.Tensor((2, 3, 4, 4), dtype="float32"), R.Tensor((3,), dtype="float32"), R.Tensor((3,), dtype="float32")) = R.nn.batch_norm(X, metadata["relax.expr.Constant"][0], metadata["relax.expr.Constant"][1], metadata["relax.expr.Constant"][2], metadata["relax.expr.Constant"][3], axis=1, epsilon=9.9999997473787516e-06, center=True, scale=True, momentum=0.10000000000000001, training=True)
                lv1: R.Tensor((2, 3, 4, 4), dtype="float32") = lv[0]
                lv2: R.Tensor((3,), dtype="float32") = lv[1]
                lv3: R.Tensor((3,), dtype="float32") = lv[2]
                gv: R.Tensor((2, 3, 4, 4), dtype="float32") = lv1
                R.output(gv)
            return gv
    ```
    
    ### Resolved
    - Read the momentum and training_mode attributes (falling back to their
    ONNX default values) and pass them into the BatchNorm operator
    - Fixed: #18703
---
 python/tvm/relax/frontend/onnx/onnx_frontend.py | 12 +++++++++++-
 1 file changed, 11 insertions(+), 1 deletion(-)

diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index 784be639dd..61ab45d308 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -2435,8 +2435,18 @@ class BatchNormalization(OnnxOpConverter):
         mean = inputs[3]
         var = inputs[4]
         epsilon = attr.get("epsilon", 1e-05)
+        momentum = attr.get("momentum", 0.9)
+        training_mode = attr.get("training_mode", 0)
         return relax.op.nn.batch_norm(
-            data, gamma=scale, beta=bias, moving_mean=mean, moving_var=var, epsilon=epsilon, axis=1
+            data,
+            gamma=scale,
+            beta=bias,
+            moving_mean=mean,
+            moving_var=var,
+            axis=1,
+            epsilon=epsilon,
+            momentum=momentum,
+            training=training_mode,
         )
 
 
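The core of the patch is reading the two ONNX attributes with their spec defaults before calling relax.op.nn.batch_norm. A minimal standalone sketch of that attribute handling, using a plain dict in place of the converter's parsed ONNX attribute map (the dict and the helper name are illustrative assumptions, not TVM API):

```python
# Sketch of the attribute handling added in this patch. `attr` stands in
# for the converter's parsed ONNX attribute map; a plain dict is used
# here purely for illustration.

def batch_norm_kwargs(attr):
    """Extract BatchNormalization attributes with ONNX-spec defaults."""
    return {
        "epsilon": attr.get("epsilon", 1e-05),
        # ONNX default momentum is 0.9; before the fix it was not
        # forwarded, so the Relax module kept its own default.
        "momentum": attr.get("momentum", 0.9),
        # ONNX default training_mode is 0 (inference); before the fix the
        # converted module reportedly ended up with training=True.
        "training": bool(attr.get("training_mode", 0)),
    }

# An inference-mode node with a non-default momentum, as in the report:
print(batch_norm_kwargs({"momentum": 0.1, "training_mode": 0}))
```

With the attributes forwarded this way, an inference-mode ONNX node yields training=False and the original momentum in the converted Relax call.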
