Lunderberg commented on code in PR #12720:
URL: https://github.com/apache/tvm/pull/12720#discussion_r972246269


##########
python/tvm/tir/function.py:
##########
@@ -389,17 +389,27 @@ def from_func_with_separators(mapping_function: Callable, ndim: Optional[int] =
 
         final_indices = []
         axis_separators = []
-        for val in mapping:
-            if isinstance(val, tvm.ir.PrimExpr):
-                final_indices.append(val)
-            elif val is IndexMap.AXIS_SEPARATOR:
-                axis_separators.append(len(final_indices))
-            else:
-                raise TypeError(
-                    "Expected mapping function to return list of "
-                    "either tvm.ir.PrimExpr or IndexMap.AXIS_SEPARATOR.  "
-                    f"Instead received {val} of type {type(val)}."
-                )
+
+        try:
+            iter(mapping)

Review Comment:
   This allows the mapping function to return a single `PrimExpr`, or 
something that the FFI can convert into a `PrimExpr`.  Since it wouldn't make 
sense for the pad value to provide multiple outputs, I found myself frequently 
writing `lambda i,j: i+j` instead of `lambda i,j: [i+j]`.  I figured that 
if I kept making that mistake, later users likely would too, so it would be 
best to support the single-value form.
   
   Good call on the documentation.  I'll update the documentation for 
`from_func` and `from_func_with_separators` accordingly.
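
   As a rough illustration of the behavior described above, here is a minimal 
sketch in plain Python.  It is not the code in this PR: the helper name 
`call_and_normalize` is hypothetical, and plain integers stand in for 
`PrimExpr` so that the example runs without TVM.  It only shows the idea of 
probing the return value with `iter()` and wrapping a single value in a list.

```python
from typing import Any, Callable, List


def call_and_normalize(mapping_function: Callable[..., Any], *indices: Any) -> List[Any]:
    """Call mapping_function and always return its outputs as a list.

    Hypothetical helper for illustration only; integers stand in for PrimExpr.
    """
    mapping = mapping_function(*indices)

    try:
        # Probe for iterability without consuming or converting the value.
        iter(mapping)
    except TypeError:
        # A single, non-iterable return value becomes a one-element list.
        mapping = [mapping]

    return list(mapping)


# Both spellings of the mapping function give the same normalized result.
single = call_and_normalize(lambda i, j: i + j, 1, 2)
wrapped = call_and_normalize(lambda i, j: [i + j], 1, 2)
```

   With this normalization, the loop that sorts values into `final_indices` 
and `axis_separators` can keep assuming it receives a sequence.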


