This is an automated email from the ASF dual-hosted git repository.

lukhut pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new d83cd217a5 [microNPU][ETHOSU] Fix ConcatRewriter args processing (#16003)
d83cd217a5 is described below

commit d83cd217a59e0dfbb83012138de0b29d2c66fafa
Author: Aleksei-grovety <113356454+aleksei-grov...@users.noreply.github.com>
AuthorDate: Tue Oct 31 17:15:47 2023 +0400

    [microNPU][ETHOSU] Fix ConcatRewriter args processing (#16003)
    
    ConcatRewriter did not handle the case where a concatenation argument
    is a TupleGetItem.
---
 python/tvm/relay/backend/contrib/ethosu/legalize.py |  2 +-
 tests/python/contrib/test_ethosu/test_codegen.py    | 16 ++++++++++++++++
 2 files changed, 17 insertions(+), 1 deletion(-)

diff --git a/python/tvm/relay/backend/contrib/ethosu/legalize.py b/python/tvm/relay/backend/contrib/ethosu/legalize.py
index 2806ef8a46..7ed69e1e9b 100644
--- a/python/tvm/relay/backend/contrib/ethosu/legalize.py
+++ b/python/tvm/relay/backend/contrib/ethosu/legalize.py
@@ -1184,7 +1184,7 @@ class ConcatRewriter(DFPatternCallback):
         # Find the tensors that are inputs to the concat and the scales and 
zero points
         concat_args = list()
         for arg in post.args:
-            if isinstance(arg, tvm.relay.expr.Call):
+            if isinstance(arg, (tvm.relay.expr.Call, 
tvm.relay.expr.TupleGetItem)):
                 concat_args.append(arg)
 
         axis = post.op.body.attrs.axis
diff --git a/tests/python/contrib/test_ethosu/test_codegen.py b/tests/python/contrib/test_ethosu/test_codegen.py
index 66809a775f..f69c114cab 100644
--- a/tests/python/contrib/test_ethosu/test_codegen.py
+++ b/tests/python/contrib/test_ethosu/test_codegen.py
@@ -1170,6 +1170,22 @@ def test_tflite_concat(shapes, axis, accel_type):
     infra.compare_tvm_with_tflite(concat_func, shapes, accel_type, 
enable_cascader=False)
 
 
+def test_tflite_unstack_concat():
+    np.random.seed(0)
+    shapes = [(2, 4, 16)]
+    axis = 1
+    accel_type = "ethos-u55-256"
+
+    @tf.function
+    def concat_func(input):
+        inputs = tf.unstack(input)
+        inputs.reverse()
+        op = tf.concat(inputs, axis)
+        return op
+
+    infra.compare_tvm_with_tflite(concat_func, shapes, accel_type, 
enable_cascader=False)
+
+
 def test_tflite_concat_with_reused_args():
     np.random.seed(0)
     shapes = [(1, 1, 24, 1), (1, 1, 24, 1), (1, 1, 10, 1), (1, 1, 68, 1)]

Reply via email to