This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new fd4a08d5c5 [Relax][Frontend][ONNX] Fix `FastGelu` when bias is not
set (#18358)
fd4a08d5c5 is described below
commit fd4a08d5c5b78c03d5363734b7540ef3ffdcb8fe
Author: Neo Chien <[email protected]>
AuthorDate: Tue Oct 7 23:20:41 2025 +0800
[Relax][Frontend][ONNX] Fix `FastGelu` when bias is not set (#18358)
* [#17877][Relax][Frontend][ONNX] Fix when bias is not set
* [#17877][FRONTEND][ONNX] Fix Error converting operator FastGelu, with
inputs: [x, bias]
* [#17877][FRONTEND][ONNX] Fix Warning: Detected pow(x, y) where y >= 3, it
is recommended to avoid
* [#17877][FRONTEND][ONNX] Fix tvm.error.InternalError: Check failed: (ptr)
is false: The struct_info is not populated, check if you have normalized the
expr
---
python/tvm/relax/frontend/onnx/onnx_frontend.py | 17 ++++++++------
tests/python/relax/test_frontend_onnx.py | 30 +++++++++++++++++++++++++
2 files changed, 40 insertions(+), 7 deletions(-)
diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py
b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index 7432967c29..3b94ba1d66 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -1155,11 +1155,12 @@ class FastGelu(OnnxOpConverter):
@classmethod
def _impl_v1(cls, bb, inputs, attr, params):
- if inputs[1]:
+ x = inputs[0]
+ if len(inputs) > 1 and inputs[1] is not None:
bias = inputs[1]
bias_shape = bias.struct_info.shape
assert len(bias_shape) == 1, "bias term must be a 1D tensor"
- x += bias
+ x = bb.emit(relax.op.add(x, bias))
# Declare consts
const_dtype = x.struct_info.dtype
@@ -1169,11 +1170,13 @@ class FastGelu(OnnxOpConverter):
const2 = relax.const(0.044715 * math.sqrt(2 / math.pi),
dtype=const_dtype)
# Compute FastGelu
- term1 = relax.op.multiply(half, x)
- term2 = relax.op.multiply(const1, x)
- term3 = relax.op.multiply(const2, relax.op.power(x, relax.const(3,
const_dtype)))
- tanh = relax.op.tanh(relax.op.add(term2, term3))
- return relax.op.multiply(term1, relax.op.add(one, tanh))
+ term1 = bb.emit(relax.op.multiply(half, x))
+ term2 = bb.emit(relax.op.multiply(const1, x))
+ # use x^3 = x * x * x instead of pow(x, 3) for better performance
+ x_cubed = bb.emit(relax.op.multiply(relax.op.multiply(x, x), x))
+ term3 = bb.emit(relax.op.multiply(const2, x_cubed))
+ tanh = bb.emit(relax.op.tanh(relax.op.add(term2, term3)))
+ return bb.emit(relax.op.multiply(term1, relax.op.add(one, tanh)))
class BiasGelu(OnnxOpConverter):
diff --git a/tests/python/relax/test_frontend_onnx.py
b/tests/python/relax/test_frontend_onnx.py
index e4960e5b1a..a8d434e894 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -828,6 +828,36 @@ def test_bias_gelu():
verify_binary("BiasGelu", [32, 32], [32], [32, 32], domain="com.microsoft")
+def test_fast_gelu():
+ """Test FastGelu with and without bias"""
+ # Test FastGelu without bias
+ fast_gelu_node = helper.make_node("FastGelu", ["x"], ["y"],
domain="com.microsoft")
+ graph = helper.make_graph(
+ [fast_gelu_node],
+ "fast_gelu_test",
+ inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, [32,
32])],
+ outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [32,
32])],
+ )
+ model = helper.make_model(graph, producer_name="fast_gelu_test")
+ check_correctness(model)
+
+ # Test FastGelu with bias
+ fast_gelu_with_bias_node = helper.make_node(
+ "FastGelu", ["x", "bias"], ["y"], domain="com.microsoft"
+ )
+ graph_with_bias = helper.make_graph(
+ [fast_gelu_with_bias_node],
+ "fast_gelu_with_bias_test",
+ inputs=[
+ helper.make_tensor_value_info("x", TensorProto.FLOAT, [32, 32]),
+ helper.make_tensor_value_info("bias", TensorProto.FLOAT, [32]),
+ ],
+ outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [32,
32])],
+ )
+ model_with_bias = helper.make_model(graph_with_bias,
producer_name="fast_gelu_with_bias_test")
+ check_correctness(model_with_bias)
+
+
def test_where():
where_node = helper.make_node("Where", ["a", "b", "c"], ["d"])