This is an automated email from the ASF dual-hosted git repository.

echuraev pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 0d5baacc02 [ONNX] Support SequenceErase op (#13865)
0d5baacc02 is described below

commit 0d5baacc0241253547fe5235f95416a6467e712c
Author: Valery Chernov <[email protected]>
AuthorDate: Tue Jan 31 11:06:02 2023 +0400

    [ONNX] Support SequenceErase op (#13865)
    
    * SequenceErase was implemented in ONNX front-end
    
    * add SequenceErase node to Sequence test
    
    * remark from reviewer. fix negative position recalculation
    
    * add assert
    
    ---------
    
    Co-authored-by: Valery Chernov <[email protected]>
---
 python/tvm/relay/frontend/onnx.py          | 42 ++++++++++++++++++++++++++----
 tests/python/frontend/onnx/test_forward.py | 10 ++++++-
 2 files changed, 46 insertions(+), 6 deletions(-)

diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index 6e0c7cc2dd..93429a8638 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -6148,13 +6148,35 @@ class SequenceConstruct(OnnxOpConverter):
         return _expr.Tuple(inputs)
 
 
-class SequenceLength(OnnxOpConverter):
-    """Operator converter for sequence length op."""
+class SequenceErase(OnnxOpConverter):
+    """Operator converter for sequence erase op."""
 
     @classmethod
     def _impl_v11(cls, inputs, attr, params):
-        # Get length of input sequence
-        return _expr.const(len(inputs[0]), dtype="int64")
+        # Erase tensor from sequence on specified position
+        input_sequence = inputs[0]
+
+        if len(inputs) == 2:
+            position = inputs[1]
+            # Non constant position is not supported.
+            if isinstance(position, _expr.Constant):
+                position = position.data.numpy()
+            elif position.name_hint in params:
+                position = params[position.name_hint].numpy()
+            else:
+                raise NotImplementedError("Position must be a constant.")
+        else:
+            position = -1
+
+        seq_len = len(input_sequence)
+        assert -seq_len <= position < seq_len, "Position is out of bounds"
+
+        if position < 0:
+            position = seq_len + position
+        # Convert sequence to a list, insert tensors before erased, and repackage as Tuple.
+        tensor_list = [input_sequence[i] for i in range(seq_len) if i != position]
+        # Create new tuple and return.
+        return _expr.Tuple(tensor_list)
 
 
 class SequenceInsert(OnnxOpConverter):
@@ -6188,6 +6210,15 @@ class SequenceInsert(OnnxOpConverter):
         return _expr.Tuple(tensor_list)
 
 
+class SequenceLength(OnnxOpConverter):
+    """Operator converter for sequence length op."""
+
+    @classmethod
+    def _impl_v11(cls, inputs, attr, params):
+        # Get length of input sequence
+        return _expr.const(len(inputs[0]), dtype="int64")
+
+
 class ConcatFromSequence(OnnxOpConverter):
     """Operator converter for sequence concatenation op."""
 
@@ -6492,8 +6523,9 @@ def _get_convert_map(opset):
         "LinearRegressor": LinearRegressor.get_converter(opset),
         # Sequence operators
         "SequenceConstruct": SequenceConstruct.get_converter(opset),
-        "SequenceLength": SequenceLength.get_converter(opset),
+        "SequenceErase": SequenceErase.get_converter(opset),
         "SequenceInsert": SequenceInsert.get_converter(opset),
+        "SequenceLength": SequenceLength.get_converter(opset),
         "ConcatFromSequence": ConcatFromSequence.get_converter(opset),
         "SplitToSequence": SplitToSequence.get_converter(opset),
         "SequenceAt": SequenceAt.get_converter(opset),
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index 6a780a632f..3e1af40867 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -7747,10 +7747,17 @@ def test_sequence(target, dev):
             outputs=["inserted_sequence"],
         )
 
+        # Test sequence erase.
+        erase_node = helper.make_node(
+            "SequenceErase",
+            inputs=["inserted_sequence", "position"],
+            outputs=["erased_sequence"],
+        )
+
         # Test sequence concatenation.
         concat_node = helper.make_node(
             "ConcatFromSequence",
-            inputs=["inserted_sequence"],
+            inputs=["erased_sequence"],
             outputs=["concat_sequence"],
             axis=axis,
         )
@@ -7796,6 +7803,7 @@ def test_sequence(target, dev):
             position_node,
             construct_node,
             insert_node,
+            erase_node,
             concat_node,
             split_node,
             at_node,

Reply via email to