caiyaodeng opened a new issue, #18541:
URL: https://github.com/apache/tvm/issues/18541
## 问题描述
在使用TVM 0.22.0版本导入ONNX模型(特别是使用opset
14的BGE模型)时,遇到了Attention操作符转换失败的问题。错误显示mask
index形状不符合要求,即使已经尝试将所有mask相关输入设置为(batch_size, seq_len)的2D形状。
## 环境信息
- TVM版本:0.22.0
- ONNX模型:BGE模型,opset 14
- 操作系统:Linux
## 错误详情
### 完整错误堆栈
```
Error converting operator Attention, with inputs: [lv8, metadata["relax.expr.
Constant"][0]
# Metadata omitted. Use show_meta=True in script() method to show it.,
metadata
["relax.expr.Constant"][0]
# Metadata omitted. Use show_meta=True in script() method to show it., lv10]
模型处理错误: mask index should be in shape of (batch_size, seq_len),
or (batch_size, seq_len, seq_len)
Traceback (most recent call last):
File "/home/cyd/workspace/tvm/my_test/optimize_model_v3.py", line 85, in
load_and_optimize_onnx_model
relax_module, params = from_onnx(onnx_model, shape_dict)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File
"/home/cyd/workspace/tvm/python/tvm/relax/frontend/onnx/onnx_frontend.py",
line 4235, in from_onnx
return g.from_onnx(graph, opset)
^^^^^^^^^^^^^^^^^^^^^^^^^
File
"/home/cyd/workspace/tvm/python/tvm/relax/frontend/onnx/onnx_frontend.py",
line 3865, in from_onnx
self._construct_nodes(graph)
File
"/home/cyd/workspace/tvm/python/tvm/relax/frontend/onnx/onnx_frontend.py",
line 4046, in _construct_nodes
raise err
File
"/home/cyd/workspace/tvm/python/tvm/relax/frontend/onnx/onnx_frontend.py",
line 4041, in _construct_nodes
op = self._convert_operator(op_name, inputs, attr, self.opset)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File
"/home/cyd/workspace/tvm/python/tvm/relax/frontend/onnx/onnx_frontend.py",
line 4141, in _convert_operator
sym = op_function(self.bb, inputs, attrs, [self._nodes, self._params])
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File
"/home/cyd/workspace/tvm/python/tvm/relax/frontend/onnx/onnx_frontend.py",
line 2109, in _impl_v1
assert mask_index_shape in (
^^^^^^^^^^^^^^^^^^^^^
AssertionError: mask index should be in shape of (batch_size, seq_len),
or (batch_size, seq_len, seq_len)
```
### 关键错误点
1. 断言失败发生在 onnx_frontend.py 第2109行,错误消息为:"mask index should be in shape of
(batch_size, seq_len), or (batch_size, seq_len, seq_len)"
2. 即使已经将所有mask相关输入设置为2D形状 (batch_size, seq_len),仍然无法通过断言
## 复现步骤
1. 准备一个使用opset 14的ONNX模型(如BGE embedding模型)
2. 使用tvm.relax.frontend.onnx.from_onnx导入模型
3. 为模型输入设置形状字典,包括将所有mask相关输入设置为2D形状:
```
shape_dict = {
'input_ids': (1, 512),
'attention_mask': (1, 512),
'dummy_mask_index': (1, 512),
'EmbedLayerNormalization_0_dummy_mask_index': (1, 512)
}
```
4. 尝试转换模型时遇到断言错误
## 预期行为
TVM应该能够正确处理符合标准形状要求的mask索引输入,或者提供更清晰的错误信息和解决方案。
## 可能的解决方案
1. 检查 _impl_v1 函数中的断言逻辑,确保它正确处理模型中传递的mask索引形状
2. 考虑支持更多类型的mask索引形状,特别是对于opset 14及以上版本的ONNX模型
3. 提供更明确的文档说明TVM支持的mask索引形状格式
## 附加信息
该问题在使用ONNX Runtime作为替代方案时不会出现,ONNX Runtime能够完全支持此类模型。
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]