This is an automated email from the ASF dual-hosted git repository. masahi pushed a commit to branch unity in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
     new 1aecc9b8b6 [Unity][Frontend][Onnx] Simplify gemm (#15458)
1aecc9b8b6 is described below

commit 1aecc9b8b622dbd870411c60d96a4843d1b75f4e
Author: Josh Fromm <jwfr...@octoml.ai>
AuthorDate: Tue Aug 1 21:13:02 2023 -0700

    [Unity][Frontend][Onnx] Simplify gemm (#15458)

    Add simplification check to onnx gemm importer
---
 python/tvm/relax/frontend/onnx/onnx_frontend.py | 10 +++++-----
 tests/python/relax/test_frontend_onnx.py        |  4 ++--
 2 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index 74eb904c4f..9a93f395ec 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -374,18 +374,18 @@ class Gemm(OnnxOpConverter):
 
         # Compute Y = alpha * A X B + beta * C
 
-        if alpha is not None:
-            A = bb.normalize(relax.op.multiply(A, relax.const(alpha, dtype=dtype)))
+        if alpha is not None and alpha != 1.0:
+            A = relax.op.multiply(A, relax.const(alpha, dtype=dtype))
 
         if transA:
             A = relax.op.permute_dims(A, [1, 0])
         if transB:
             B = relax.op.permute_dims(B, [1, 0])
-        Y = bb.normalize(relax.op.matmul(A, B))
+        Y = relax.op.matmul(A, B)
 
         if C is not None:
-            if beta is not None:
-                C = bb.normalize(relax.op.multiply(C, relax.const(beta, dtype=dtype)))
+            if beta is not None and beta != 1.0:
+                C = relax.op.multiply(C, relax.const(beta, dtype=dtype))
             Y = relax.op.add(Y, C)
 
         return Y
diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py
index 3467e5bba2..647e72f04a 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -333,8 +333,8 @@ def test_gather():
     _verify_gather([3, 3], [[0, 2]], [3, 1, 2], 1)
 
 
-@pytest.mark.parametrize("alpha", [None, 0.25])
-@pytest.mark.parametrize("beta", [None, 0.35])
+@pytest.mark.parametrize("alpha", [None, 0.25, 1.0])
+@pytest.mark.parametrize("beta", [None, 0.35, 1.0])
 @pytest.mark.parametrize("useC", [False, True])
 def test_gemm(alpha, beta, useC):
     if useC: