MasterJH5574 commented on code in PR #17836:
URL: https://github.com/apache/tvm/pull/17836#discussion_r2051100755


##########
python/tvm/relax/transform/legalize_ops/manipulate.py:
##########
@@ -162,6 +163,23 @@ def te_gather_nd(data, indices, batch_dims):
     return bb.call_te(te_gather_nd, call.args[0], call.args[1], int(call.attrs.batch_dims))
 
 
+@register_legalize("relax.index_tensor")
+def _index_tensor(bb: BlockBuilder, call: Call) -> Expr:
+    t = call.args[1]
+    n_field = len(t.struct_info.fields)
+    while isinstance(t, Var):
+        binding = bb.lookup_binding(t)
+        if not isinstance(binding, (Tuple, Var)):
+            break
+        t = binding
+
+    assert isinstance(t, (Tuple, Var))

Review Comment:
   It might help readers if we briefly explained why backtracking `t` through its bindings is needed.
   My understanding is that the legalization would still be correct without it, but it may produce
   trivial `TupleGetItem`s that can easily be removed. Correct me if I'm wrong.
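   To illustrate what I mean, here is a toy sketch in plain Python of the two paths. It is not the TVM API:
   `Var`, `Tuple`, `TupleGetItem`, the `bindings` dict (standing in for `bb.lookup_binding`) and the helper
   `index_operands` are all hypothetical stand-ins, assuming the legalization only needs one expression per
   index tensor in the tuple.

```python
from dataclasses import dataclass
from typing import Dict, List, Union


# Hypothetical stand-ins for Relax IR nodes; these are NOT the tvm.relax classes.
@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class Tuple:
    fields: tuple  # tuple of Expr


@dataclass(frozen=True)
class TupleGetItem:
    tuple_value: "Expr"
    index: int


Expr = Union[Var, Tuple, TupleGetItem]


def index_operands(
    t: Expr, n_field: int, bindings: Dict[Var, Expr], backtrack: bool
) -> List[Expr]:
    """Hypothetical helper: return one expression per index tensor held by `t`."""
    if backtrack:
        # Follow Var -> Var and Var -> Tuple bindings, mirroring the loop in the diff.
        # dict.get returns None for unbound vars (e.g. function params), which stops the walk.
        while isinstance(t, Var):
            bound = bindings.get(t)
            if not isinstance(bound, (Tuple, Var)):
                break
            t = bound
    if isinstance(t, Tuple):
        # The tuple's fields can be used directly; no new nodes are created.
        return list(t.fields)
    # Otherwise each field must be projected out of the tuple variable.
    return [TupleGetItem(t, i) for i in range(n_field)]


if __name__ == "__main__":
    i0, i1, tup, alias = Var("i0"), Var("i1"), Var("tup"), Var("alias")
    bindings = {tup: Tuple((i0, i1)), alias: tup}
    print(index_operands(alias, 2, bindings, backtrack=True))
    # [Var(name='i0'), Var(name='i1')]
    print(index_operands(alias, 2, bindings, backtrack=False))
    # [TupleGetItem(tuple_value=Var(name='alias'), index=0), TupleGetItem(tuple_value=Var(name='alias'), index=1)]
```

   Both results denote the same values. The second one just carries `TupleGetItem(alias, i)` nodes that, assuming a
   later simplification folds projections out of known tuples, would reduce to `i0` and `i1`. If `t` is an unbound
   parameter, both paths end up emitting `TupleGetItem`s anyway.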


