masahi commented on code in PR #13710:
URL: https://github.com/apache/tvm/pull/13710#discussion_r1063081290


##########
src/tir/ir/data_type_rewriter.cc:
##########
@@ -397,6 +417,9 @@ Stmt IndexDataTypeRewriter::VisitStmt_(const BufferStoreNode* op) {
 
   Buffer new_buffer = GetRemappedBuffer(op->buffer);
   auto value = this->VisitExpr(op->value);
+  if (new_buffer->dtype != value->dtype && value->dtype.lanes() == 1) {
+    value = cast(new_buffer->dtype, value);

Review Comment:
   `new_buffer->dtype != value->dtype` is also true when `new_buffer->dtype == float32` and `value->dtype == float32x16`, which happens when the store indices are a vector (a vectorized store). The `value->dtype.lanes() == 1` condition is added so that no cast is inserted in such cases.
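   To make the guard concrete, here is a minimal sketch using a hypothetical, simplified `DType` struct in place of TVM's real `DataType` (the struct, `NeedsCast`, and the `k*` constants are illustrative names, not TVM API). It assumes, as the comment describes, that dtype equality also compares the lane count, so `float32 != float32x16`. The checks are `static_assert`s, so they fire at compile time.

```cpp
// Hypothetical, simplified stand-in for tvm::DataType (not the real class):
// a type code, a bit width, and a lane count (lanes > 1 means a vector type).
enum class TypeCode { kInt, kFloat };

struct DType {
  TypeCode code;
  int bits;
  int lanes;
};

// Equality compares all three fields, so float32 and float32x16 differ.
constexpr bool operator==(const DType& a, const DType& b) {
  return a.code == b.code && a.bits == b.bits && a.lanes == b.lanes;
}
constexpr bool operator!=(const DType& a, const DType& b) { return !(a == b); }

// Mirrors the guard in the diff: only cast a scalar value (lanes == 1)
// whose dtype differs from the buffer's element dtype.
constexpr bool NeedsCast(const DType& buffer_dtype, const DType& value_dtype) {
  return buffer_dtype != value_dtype && value_dtype.lanes == 1;
}

constexpr DType kFloat32{TypeCode::kFloat, 32, 1};
constexpr DType kFloat32x16{TypeCode::kFloat, 32, 16};
constexpr DType kInt32{TypeCode::kInt, 32, 1};
constexpr DType kInt64{TypeCode::kInt, 64, 1};

// Vectorized store: the dtypes differ only in lanes, and the guard skips the cast.
static_assert(kFloat32 != kFloat32x16, "lanes participate in equality");
static_assert(!NeedsCast(kFloat32, kFloat32x16), "no cast for vector values");
// Scalar store whose value dtype does not match the buffer: cast it.
static_assert(NeedsCast(kInt32, kInt64), "scalar mismatch gets a cast");
// Matching dtypes: nothing to do.
static_assert(!NeedsCast(kInt32, kInt32), "no cast when dtypes match");
```

   Without the `lanes() == 1` check, the first case would wrap a `float32x16` value in a scalar `float32` cast.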
   
   For argmax, this cast can produce a redundant chain of casts of the form `int32(select(cond, int64(true), int64(false)))`, but hopefully LLVM or other backends can clean them up.
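   Why the chain is redundant, sketched with plain C++ integer casts standing in for TIR `cast`/`select` (this is not TVM's simplifier, and `CastChain`/`Folded` are hypothetical names). The sketch assumes the inner `int64(...)` casts wrap int32 operands: widening int32 to int64 is lossless, so narrowing the selected value back to int32 returns the original operand, and the whole expression equals a plain int32 select.

```cpp
#include <cstdint>
#include <limits>

// Models int32(select(cond, int64(a), int64(b))) for int32 operands a and b.
constexpr std::int32_t CastChain(bool cond, std::int32_t a, std::int32_t b) {
  return static_cast<std::int32_t>(cond ? static_cast<std::int64_t>(a)
                                        : static_cast<std::int64_t>(b));
}

// The folded form a backend could reduce the chain to: select(cond, a, b).
constexpr std::int32_t Folded(bool cond, std::int32_t a, std::int32_t b) {
  return cond ? a : b;
}

// The round trip int32 -> int64 -> int32 preserves every int32 value,
// including the extremes.
static_assert(CastChain(true, 7, -3) == Folded(true, 7, -3), "true branch");
static_assert(CastChain(false, 7, -3) == Folded(false, 7, -3), "false branch");
static_assert(CastChain(false, 0, std::numeric_limits<std::int32_t>::min()) ==
                  std::numeric_limits<std::int32_t>::min(),
              "INT32_MIN survives the round trip");
static_assert(CastChain(true, std::numeric_limits<std::int32_t>::max(), 0) ==
                  std::numeric_limits<std::int32_t>::max(),
              "INT32_MAX survives the round trip");
```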



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
