This is an automated email from the ASF dual-hosted git repository.
echuraev pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 0d5baacc02 [ONNX] Support SequenceErase op (#13865)
0d5baacc02 is described below
commit 0d5baacc0241253547fe5235f95416a6467e712c
Author: Valery Chernov <[email protected]>
AuthorDate: Tue Jan 31 11:06:02 2023 +0400
[ONNX] Support SequenceErase op (#13865)
* Implement the SequenceErase op in the ONNX frontend
* Add a SequenceErase node to the sequence test
* Address reviewer remark: fix recalculation of negative positions
* Add a bounds-check assert
---------
Co-authored-by: Valery Chernov <[email protected]>
---
python/tvm/relay/frontend/onnx.py | 42 ++++++++++++++++++++++++++----
tests/python/frontend/onnx/test_forward.py | 10 ++++++-
2 files changed, 46 insertions(+), 6 deletions(-)
diff --git a/python/tvm/relay/frontend/onnx.py
b/python/tvm/relay/frontend/onnx.py
index 6e0c7cc2dd..93429a8638 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -6148,13 +6148,35 @@ class SequenceConstruct(OnnxOpConverter):
return _expr.Tuple(inputs)
class SequenceErase(OnnxOpConverter):
    """Operator converter for sequence erase op."""

    @classmethod
    def _impl_v11(cls, inputs, attr, params):
        """Remove the tensor at ``position`` from the input sequence.

        Parameters
        ----------
        inputs : list
            ``inputs[0]`` is the input sequence (indexed and measured with
            ``len``, so it must be a tuple expression). ``inputs[1]`` is the
            optional position of the tensor to erase. It may be absent or
            ``None``, and otherwise must be a constant or a named param.
        attr : dict
            Node attributes (unused).
        params : dict
            Graph parameters, used to resolve a position given as a named input.

        Returns
        -------
        _expr.Tuple
            A new tuple with every tensor of the input sequence except the
            erased one.

        Raises
        ------
        NotImplementedError
            If the position cannot be resolved to a constant at conversion time.
        """
        input_sequence = inputs[0]
        seq_len = len(input_sequence)

        # The position input is optional. It may be missing from ``inputs``
        # entirely, or be None when the ONNX node lists an empty input name.
        # Per ONNX, the default is to erase the last tensor.
        position = inputs[1] if len(inputs) > 1 else None
        if position is None:
            position = -1
        elif isinstance(position, _expr.Constant):
            position = int(position.data.numpy().item())
        elif isinstance(position, _expr.Var) and position.name_hint in params:
            position = int(params[position.name_hint].numpy().item())
        else:
            # Non-constant positions cannot be resolved at conversion time.
            raise NotImplementedError("Position must be a constant.")

        assert -seq_len <= position < seq_len, "Position is out of bounds"

        # Normalize a negative position to its non-negative equivalent.
        if position < 0:
            position += seq_len

        # Keep every tensor except the erased one and repackage as a Tuple.
        tensor_list = [input_sequence[i] for i in range(seq_len) if i != position]
        return _expr.Tuple(tensor_list)
class SequenceInsert(OnnxOpConverter):
@@ -6188,6 +6210,15 @@ class SequenceInsert(OnnxOpConverter):
return _expr.Tuple(tensor_list)
class SequenceLength(OnnxOpConverter):
    """Converter for the ONNX SequenceLength operator."""

    @classmethod
    def _impl_v11(cls, inputs, attr, params):
        """Return the number of tensors in the input sequence.

        ``len`` is taken of the sequence expression itself, so the length is
        known at conversion time and is emitted as an int64 constant.
        """
        sequence = inputs[0]
        num_tensors = len(sequence)
        return _expr.const(num_tensors, dtype="int64")
+
+
class ConcatFromSequence(OnnxOpConverter):
"""Operator converter for sequence concatenation op."""
@@ -6492,8 +6523,9 @@ def _get_convert_map(opset):
"LinearRegressor": LinearRegressor.get_converter(opset),
# Sequence operators
"SequenceConstruct": SequenceConstruct.get_converter(opset),
- "SequenceLength": SequenceLength.get_converter(opset),
+ "SequenceErase": SequenceErase.get_converter(opset),
"SequenceInsert": SequenceInsert.get_converter(opset),
+ "SequenceLength": SequenceLength.get_converter(opset),
"ConcatFromSequence": ConcatFromSequence.get_converter(opset),
"SplitToSequence": SplitToSequence.get_converter(opset),
"SequenceAt": SequenceAt.get_converter(opset),
diff --git a/tests/python/frontend/onnx/test_forward.py
b/tests/python/frontend/onnx/test_forward.py
index 6a780a632f..3e1af40867 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -7747,10 +7747,17 @@ def test_sequence(target, dev):
outputs=["inserted_sequence"],
)
+ # Test sequence erase.
+ erase_node = helper.make_node(
+ "SequenceErase",
+ inputs=["inserted_sequence", "position"],
+ outputs=["erased_sequence"],
+ )
+
# Test sequence concatenation.
concat_node = helper.make_node(
"ConcatFromSequence",
- inputs=["inserted_sequence"],
+ inputs=["erased_sequence"],
outputs=["concat_sequence"],
axis=axis,
)
@@ -7796,6 +7803,7 @@ def test_sequence(target, dev):
position_node,
construct_node,
insert_node,
+ erase_node,
concat_node,
split_node,
at_node,