gemini-code-assist[bot] commented on code in PR #18421:
URL: https://github.com/apache/tvm/pull/18421#discussion_r2493249865
##########
src/target/source/codegen_webgpu.cc:
##########
@@ -668,12 +668,19 @@ void CodeGenWebGPU::VisitStmt_(const AllocateNode* op) {
void CodeGenWebGPU::VisitStmt_(const ForNode* op) {
std::string extent = PrintExpr(op->extent);
+ std::string step = op->step.has_value()? PrintExpr(*op->step) : "";
std::string vid = AllocVarID(op->loop_var.get());
ICHECK(is_zero(op->min));
PrintIndent();
stream << "for (var " << vid << " : ";
PrintType(op->loop_var.dtype(), stream);
- stream << " = 0; " << vid << " < " << extent << "; " << vid << "++) {\n";
+ stream << " = 0; " << vid << " < " << extent << "; " << vid;
+ if (step.empty()) {
+ stream << "++";
+ } else {
+ stream << " += " << step;
+ }
+ stream << ") {\n";
int for_scope = BeginScope();
PrintStmt(op->body);
this->EndScope(for_scope);
Review Comment:

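Similar to the C codegen, the WebGPU codegen for `ForNode` still assumes that
`min` is zero: it keeps the `ICHECK(is_zero(op->min))` and hardcodes `= 0` as
the initial value of the loop variable. It should be updated to support a
non-zero `min`, for consistency with other backends such as LLVM and SPIR-V, by
removing the `ICHECK` and using `op->min` for both the loop initialization and
the loop condition: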
```cpp
std::string min_str = PrintExpr(op->min);
std::string extent_str = PrintExpr(op->extent);
std::string step_str = op->step.has_value() ? PrintExpr(*op->step) : "";
std::string vid = AllocVarID(op->loop_var.get());
PrintIndent();
stream << "for (var " << vid << " : ";
PrintType(op->loop_var.dtype(), stream);
stream << " = " << min_str << "; " << vid << " < (" << min_str << " + " <<
extent_str << "); " << vid;
if (step_str.empty()) {
stream << "++";
} else {
stream << " += " << step_str;
}
stream << ") {\n";
int for_scope = BeginScope();
PrintStmt(op->body);
this->EndScope(for_scope);
```
##########
src/target/spirv/codegen_spirv.cc:
##########
@@ -672,10 +672,19 @@ void CodeGenSPIRV::VisitStmt_(const BufferStoreNode* op) {
}
void CodeGenSPIRV::VisitStmt_(const ForNode* op) {
- ICHECK(is_zero(op->min));
analyzer_->Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent));
spirv::Value init_value = MakeValue(op->min);
spirv::Value extent_value = MakeValue(op->extent);
+
+ // loop step
+ spirv::Value step;
+ if (op->HasTrivialStep()) {
+ step = op->loop_var.dtype().is_int() ? builder_->IntImm(loop_var.stype, 1)
+ : builder_->UIntImm(loop_var.stype, 1);
+ } else {
+ step = MakeValue(tir::cast(op->extent->dtype, *op->step));
+ }
Review Comment:

There seems to be a bug in the step calculation. The code
`builder_->IntImm(loop_var.stype, 1)` uses `loop_var`, but the local `loop_var`
of `spirv::PhiValue` type is defined later in the function. This means it's
likely using the member variable `this->loop_var` from an outer loop, which is
incorrect. The type of the current loop variable `op->loop_var` should be used
instead.
```suggestion
if (op->HasTrivialStep()) {
spirv::SType stype = builder_->GetSType(op->loop_var.dtype());
step = op->loop_var.dtype().is_int() ? builder_->IntImm(stype, 1)
: builder_->UIntImm(stype, 1);
} else {
step = MakeValue(tir::cast(op->extent->dtype, *op->step));
}
```
##########
src/target/source/codegen_c.cc:
##########
@@ -1121,12 +1121,19 @@ void CodeGenC::VisitStmt_(const AssertStmtNode* op) {
void CodeGenC::VisitStmt_(const ForNode* op) {
std::string extent = PrintExpr(op->extent);
+ std::string step = op->step.has_value()? PrintExpr(*op->step) : "";
PrintIndent();
std::string vid = AllocVarID(op->loop_var.get());
ICHECK(is_zero(op->min));
stream << "for (";
PrintType(op->loop_var.dtype(), stream);
- stream << ' ' << vid << " = 0; " << vid << " < " << extent << "; ++" << vid << ") {\n";
+ stream << ' ' << vid << " = 0; " << vid << " < " << extent << "; ";
+ if (step.empty()) {
+ stream << "++" << vid;
+ } else {
+ stream << vid << " += " << step;
+ }
+ stream << ") {\n";
int for_scope = BeginScope();
PrintStmt(op->body);
this->EndScope(for_scope);
Review Comment:

The C codegen for `ForNode` still assumes that the loop starts from 0, as
indicated by the `ICHECK(is_zero(op->min))` and the hardcoded `vid = 0`. Other
backends like LLVM and SPIR-V have been updated to support non-zero `min`. For
consistency and to fully support generalized loops, this backend should also be
updated to handle non-zero `min`. This would involve removing the `ICHECK` and
using `op->min` for loop initialization and in the loop condition.
```cpp
std::string min_str = PrintExpr(op->min);
std::string extent_str = PrintExpr(op->extent);
std::string step_str = op->step.has_value() ? PrintExpr(*op->step) : "";
PrintIndent();
std::string vid = AllocVarID(op->loop_var.get());
stream << "for (";
PrintType(op->loop_var.dtype(), stream);
stream << ' ' << vid << " = " << min_str << "; " << vid << " < (" <<
min_str << " + " << extent_str << "); ";
if (step_str.empty()) {
stream << "++" << vid;
} else {
stream << vid << " += " << step_str;
}
stream << ") {\n";
int for_scope = BeginScope();
PrintStmt(op->body);
this->EndScope(for_scope);
```
##########
src/te/operation/create_primfunc.cc:
##########
@@ -725,6 +725,8 @@ void RewriteStageToBlock(const te::Operation& op, CreateFuncInfo* info,
} else if (const auto extern_op = op.as<te::ExternOp>()) {
// Case 3. ExternOp (te.extern)
root_stmts->push_back(GenerateStmtFromExternOp(extern_op.value(), info));
+ } else if (const auto scan_op = op.as<te::ScanOp>()) {
+
Review Comment:

This `else if` branch for `te::ScanOp` is empty. If it is a placeholder for
future work, add a `// TODO` comment explaining the intent; otherwise, remove
it for now to avoid dead code.
##########
src/relax/distributed/transform/lower_global_view_to_local_view.cc:
##########
@@ -330,8 +330,8 @@ class DistributedBufferCompactor : StmtExprMutator {
if (shard > 1) {
arith::Analyzer analyzer;
ICHECK(analyzer.CanProve(floormod(new_loop->extent, shard) == 0));
- return For(new_loop->loop_var, new_loop->min, floordiv(new_loop->extent, shard),
- new_loop->kind, new_loop->body, new_loop->thread_binding, new_loop->annotations);
+ new_loop.CopyOnWrite()->extent = floordiv(new_loop->extent, shard);
+ return new_loop;
Review Comment:

Using `CopyOnWrite` here is a great way to preserve the new `step` field and
any other future fields on the `ForNode`. This pattern has been applied
consistently across the PR, which is excellent.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]