This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new e6538517b0 [Relax][TFLite] Add gather frontend expected IRModule tests (#19516)
e6538517b0 is described below

commit e6538517b016e5788bec375a36c83fa8c246f400
Author: Wei-Cheng Hsu <[email protected]>
AuthorDate: Fri May 8 16:08:43 2026 +0800

    [Relax][TFLite] Add gather frontend expected IRModule tests (#19516)
    
    This adds explicit Expected IRModule coverage for TFLite GATHER and
    GATHER_ND frontend conversion.
    
    GATHER_ND uses Relax gather_nd with int64 indices, so the frontend now
    casts int32 TFLite indices to int64 before emitting the Relax op. This
    keeps the generated module well-typed and matches the expected Relax IR.
    
    Testing:
    - `python -m pytest tests/python/relax/test_frontend_tflite.py -k "gather"`
    
    related to https://github.com/apache/tvm/issues/18971
---
 .../tvm/relax/frontend/tflite/tflite_frontend.py   |  3 ++
 tests/python/relax/test_frontend_tflite.py         | 58 ++++++++++++++++++++++
 2 files changed, 61 insertions(+)

diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py b/python/tvm/relax/frontend/tflite/tflite_frontend.py
index e66dff8356..f5b88b0c6a 100644
--- a/python/tvm/relax/frontend/tflite/tflite_frontend.py
+++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py
@@ -1630,6 +1630,9 @@ class OperatorConverter:
 
         indices_dims = len(self._infer_shape(indices))
         indices_t = relax.op.permute_dims(indices, axes=[-1] + 
list(range(indices_dims - 1)))
+        if indices_type == TensorType.INT32:
+            # Relax gather_nd requires int64 indices.
+            indices_t = relax.op.astype(indices_t, "int64")
 
         out = relax.op.gather_nd(data, indices_t)
         return out
diff --git a/tests/python/relax/test_frontend_tflite.py b/tests/python/relax/test_frontend_tflite.py
index 69e9b290fd..e4c237887e 100644
--- a/tests/python/relax/test_frontend_tflite.py
+++ b/tests/python/relax/test_frontend_tflite.py
@@ -1451,6 +1451,64 @@ def test_reverse_v2():
 
     verify(ReverseV2, Expected)
 
+
+def test_gather():
+    class Gather(tf.Module):
+        @tf.function(
+            input_signature=[
+                tf.TensorSpec(shape=(2, 3, 4), dtype=tf.float32),
+                tf.TensorSpec(shape=(2,), dtype=tf.int64),
+            ]
+        )
+        def func(self, x, indices):
+            return tf.gather(x, indices, axis=1)
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor((2, 3, 4), dtype="float32"),
+            indices: R.Tensor((2,), dtype="int64"),
+        ) -> R.Tensor((2, 2, 4), dtype="float32"):
+            R.func_attr({"num_input": 2})
+            with R.dataflow():
+                lv: R.Tensor((2,), dtype="int32") = R.astype(indices, 
dtype="int32")
+                gv: R.Tensor((2, 2, 4), dtype="float32") = R.take(x, lv, 
axis=1, mode="fast")
+                R.output(gv)
+            return gv
+
+    verify(Gather, Expected)
+
+
+def test_gather_nd():
+    class GatherND(tf.Module):
+        @tf.function(
+            input_signature=[
+                tf.TensorSpec(shape=(2, 3, 4), dtype=tf.float32),
+                tf.TensorSpec(shape=(2, 2), dtype=tf.int32),
+            ]
+        )
+        def func(self, x, indices):
+            return tf.gather_nd(x, indices)
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor((2, 3, 4), dtype="float32"),
+            indices: R.Tensor((2, 2), dtype="int32"),
+        ) -> R.Tensor((2, 4), dtype="float32"):
+            R.func_attr({"num_input": 2})
+            with R.dataflow():
+                lv: R.Tensor((2, 2), dtype="int32") = R.permute_dims(indices, 
axes=[-1, 0])
+                lv1: R.Tensor((2, 2), dtype="int64") = R.astype(lv, 
dtype="int64")
+                gv: R.Tensor((2, 4), dtype="float32") = R.gather_nd(x, lv1, 
batch_dims=0)
+                R.output(gv)
+            return gv
+
+    verify(GatherND, Expected)
+
+
 def _make_conv2d_module(data_shape, kernel_shape, data_format, strides, 
padding):
     class Conv2DModule(tf.Module):
         @tf.function(

Reply via email to