masahi commented on code in PR #13710:
URL: https://github.com/apache/tvm/pull/13710#discussion_r1063081290
##########
src/tir/ir/data_type_rewriter.cc:
##########
@@ -397,6 +417,9 @@ Stmt IndexDataTypeRewriter::VisitStmt_(const BufferStoreNode* op) {
   Buffer new_buffer = GetRemappedBuffer(op->buffer);
   auto value = this->VisitExpr(op->value);
+  if (new_buffer->dtype != value->dtype && value->dtype.lanes() == 1) {
+    value = cast(new_buffer->dtype, value);

Review Comment:
`new_buffer->dtype != value->dtype` is also true when `new_buffer->dtype` is `float32`, `value->dtype` is `float32x16`, and the indices are a vector (a vectorized store). In that case the two dtypes differ only in the number of lanes, and no cast should be inserted. The `value->dtype.lanes() == 1` condition is added to prevent adding a cast in such cases.

For argmax, this cast can result in a redundant chain of casts of the form `int32(select(cond, int64(true), int64(false)))`, but hopefully LLVM or another backend can clean them up.
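To make the guard concrete, here is a minimal, self-contained C++ sketch of the condition. `DType` and `NeedsCast` are hypothetical names introduced only for illustration; they model the type code, bit width, and lane count that the check inspects and are not TVM's `DataType` API.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-in for TVM's DataType (NOT the real API):
// a type code, a bit width, and a lane count.
struct DType {
  enum Code : uint8_t { kInt, kFloat };
  Code code;
  int bits;
  int lanes;  // 1 for scalars, >1 for vectors such as float32x16
};

// Equality requires code, bits, AND lanes to match,
// so float32 != float32x16 even though the element type is the same.
bool operator==(const DType& a, const DType& b) {
  return a.code == b.code && a.bits == b.bits && a.lanes == b.lanes;
}
bool operator!=(const DType& a, const DType& b) { return !(a == b); }

// Mirrors the condition in the diff: cast the stored value to the buffer
// dtype only when the dtypes differ AND the value is a scalar (lanes == 1).
// A vector value stored through vector indices is left untouched.
bool NeedsCast(const DType& buffer_dtype, const DType& value_dtype) {
  return buffer_dtype != value_dtype && value_dtype.lanes == 1;
}
```

With this model, `NeedsCast` returns false for a `float32` buffer receiving a `float32x16` value (a lanes-only mismatch), and true for an `int32` buffer receiving a scalar `int64` value, which is the kind of scalar mismatch that produces the argmax cast chain mentioned above.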