gemini-code-assist[bot] commented on code in PR #18418:
URL: https://github.com/apache/tvm/pull/18418#discussion_r2489513727


##########
src/tir/schedule/primitive/compute_inline.cc:
##########
@@ -984,6 +984,393 @@ void ReverseComputeInline(ScheduleState self, const StmtSRef& consumer_block_sre
   ReverseComputeInlineImpl(self, consumer_block_sref);
 }
 
+/*!
+ * \brief Helper to fuse epilogue block into reduction block
+ * Analyzes epilogue pattern and transforms reduction init/update
+ */
+class ReductionEpilogueFuser : public BaseInliner {
+ public:
+  explicit ReductionEpilogueFuser(const Buffer& reduction_buffer, const BlockNode* reduction_block,
+                                  const BlockRealize& epilogue_block_realize,
+                                  const StmtSRef& scope_root_sref, const IRModule& mod)
+      : BaseInliner(reduction_buffer, epilogue_block_realize->block, scope_root_sref),
+        reduction_block_(reduction_block),
+        epilogue_block_(epilogue_block_realize->block.get()),
+        mod_(mod) {}
+
+  bool BodyPatternAllowFusion(const BlockRealize& epilogue_block_realize);
+
+  // Step 2: Create single fused reduction block
+  Block CreateFusedReductionBlock(const BlockNode* reduction_block,
+                                  const BlockRealizeNode* reduction_realize);
+
+ private:
+  bool AnalyzeEpiloguePattern(const PrimExpr& value);
+  bool IsReductionBlock(const BlockNode* block);
+  void ExtractEpilogueInfo();
+  // Helper function to extract BufferLoad nodes from BufferStore
+  static std::vector<const BufferLoadNode*> ExtractBufferLoad(const Buffer& buffer,
+                                                              const BufferStoreNode* from) {
+    struct Extractor : public ExprVisitor {
+      void VisitExpr_(const BufferLoadNode* load) final {
+        if (load->buffer.get() == buffer) {
+          result.push_back(load);
+        }
+        ExprVisitor::VisitExpr_(load);
+      }
+      const BufferNode* buffer;
+      std::vector<const BufferLoadNode*> result;
+    } extractor;
+    extractor.buffer = buffer.get();
+    for (const PrimExpr& expr : from->indices) {
+      extractor(expr);
+    }
+    extractor(from->value);
+    return std::move(extractor.result);
+  }
+
+  const BlockNode* reduction_block_;
+  const BlockNode* epilogue_block_;
+  const IRModule& mod_;
+  PrimExpr epilogue_addend_{nullptr};                      // C[vi, vj] in D = temp + C
+  Buffer epilogue_output_buffer_{nullptr};                 // Output buffer D
+  ffi::Array<PrimExpr> epilogue_output_indices_{nullptr};  // Indices of D[vi, vj]
+  BufferRegion epilogue_output_region_{nullptr};           // Write region of D
+  Buffer epilogue_addend_buffer_{nullptr};                 // Addend buffer C
+  BufferRegion epilogue_addend_region_{nullptr};           // Read region of C
+};
+
+bool ReductionEpilogueFuser::BodyPatternAllowFusion(const BlockRealize& epilogue_block_realize) {
+  const Block& epilogue_block = epilogue_block_realize->block;
+
+  // 1. Validate predicate
+  if (!is_one(epilogue_block_realize->predicate)) {
+    // Failure: Predicate in epilogue block is not supported
+    return false;
+  }
+
+  // 2. Check if epilogue body is BufferStore
+  if (inlined_store_ == nullptr) {
+    // Failure: epilogue block body is not BufferStore
+    return false;
+  }
+
+  // 3. Check if epilogue reads from reduction buffer
+  std::vector<const BufferLoadNode*> loads = ExtractBufferLoad(inlined_buffer_, inlined_store_);
+  if (loads.size() == 0) {
+    // Failure: no BufferLoad from the reduction buffer
+    return false;
+  }
+
+  // 4. Analyze epilogue pattern: D[i,j] = temp[i,j] + C[i,j]
+  if (!AnalyzeEpiloguePattern(inlined_store_->value)) {
+    // Failure: epilogue is not a simple addition pattern
+    return false;
+  }
+
+  // 5. Check if producer is a reduction block
+  if (!IsReductionBlock(reduction_block_)) {
+    // Failure: producer is not a reduction block
+    return false;
+  }
+
+  // 6. Extract epilogue information (output buffer, indices, regions, etc.)
+  ExtractEpilogueInfo();
+
+  return true;
+}
+
+bool ReductionEpilogueFuser::AnalyzeEpiloguePattern(const PrimExpr& value) {
+  // Pattern: temp[i,j] + C[i,j] or C[i,j] + temp[i,j]
+  if (const auto* add = value.as<AddNode>()) {
+    // Check if one operand is BufferLoad from reduction buffer
+    const auto* load_a = add->a.as<BufferLoadNode>();
+    const auto* load_b = add->b.as<BufferLoadNode>();
+
+    if (load_a && load_a->buffer.same_as(inlined_buffer_)) {
+      // Pattern: temp[...] + C[...]
+      epilogue_addend_ = add->b;
+      return true;
+    } else if (load_b && load_b->buffer.same_as(inlined_buffer_)) {
+      // Pattern: C[...] + temp[...]
+      epilogue_addend_ = add->a;
+      return true;
+    }
+  }
+
+  return false;
+}
+
+bool ReductionEpilogueFuser::IsReductionBlock(const BlockNode* block) {
+  // Check if block has reduction iter vars
+  for (const IterVar& iter : block->iter_vars) {
+    if (iter->iter_type == kCommReduce) {
+      return true;
+    }
+  }
+  return false;
+}
+
+void ReductionEpilogueFuser::ExtractEpilogueInfo() {
+  // Extract epilogue output buffer and indices
+  epilogue_output_buffer_ = inlined_store_->buffer;
+  epilogue_output_indices_ = inlined_store_->indices;
+
+  // Extract epilogue output region from epilogue block writes
+  for (const BufferRegion& write : epilogue_block_->writes) {
+    if (write->buffer.same_as(epilogue_output_buffer_)) {
+      epilogue_output_region_ = write;
+      break;
+    }
+  }
+
+  // Extract epilogue addend buffer and region from epilogue_addend_
+  if (const auto* load = epilogue_addend_.as<BufferLoadNode>()) {
+    epilogue_addend_buffer_ = load->buffer;
+    // Find the read region from epilogue block reads
+    for (const BufferRegion& read : epilogue_block_->reads) {
+      if (read->buffer.same_as(epilogue_addend_buffer_)) {
+        epilogue_addend_region_ = read;
+        break;
+      }
+    }
+  }
+}
+
+Block ReductionEpilogueFuser::CreateFusedReductionBlock(const BlockNode* reduction_block,
+                                                        const BlockRealizeNode* reduction_realize) {
+  ObjectPtr<BlockNode> new_block = ffi::make_object<BlockNode>(*reduction_block);
+
+  // 1. Keep all iter vars (data parallel + reduction)
+  new_block->iter_vars = reduction_block->iter_vars;
+
+  // 2. Map epilogue block vars to reduction block vars
+  std::unordered_map<Var, Var> var_map;
+  int reduction_data_par_idx = 0;
+  for (int i = 0; i < static_cast<int>(reduction_block->iter_vars.size()); ++i) {
+    const IterVar& iter_var = reduction_block->iter_vars[i];
+    if (iter_var->iter_type == IterVarType::kDataPar) {
+      // Map corresponding data parallel var from epilogue block
+      int epilogue_data_par_idx = 0;
+      for (const IterVar& epilogue_iter_var : epilogue_block_->iter_vars) {
+        if (epilogue_iter_var->iter_type == IterVarType::kDataPar) {
+          if (epilogue_data_par_idx == reduction_data_par_idx) {
+            var_map[epilogue_iter_var->var] = iter_var->var;
+            break;
+          }
+          epilogue_data_par_idx++;
+        }
+      }
+      reduction_data_par_idx++;
+    }
+  }
+
+  // 3. Change init to epilogue value: D[vi, vj] = C[vi, vj]
+  BufferStore new_init_store(epilogue_output_buffer_, Substitute(epilogue_addend_, var_map),
+                             Substitute(epilogue_output_indices_, var_map));
+  new_block->init = new_init_store;
+
+  // 4. Replace output buffer from temp to D in body
+  class BufferReplacer : public StmtExprMutator {
+   public:
+    BufferReplacer(Buffer old_buf, Buffer new_buf) : old_buffer_(old_buf), new_buffer_(new_buf) {}
+
+    Stmt VisitStmt_(const BufferStoreNode* op) final {
+      BufferStore store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
+      if (store->buffer.same_as(old_buffer_)) {
+        return BufferStore(new_buffer_, store->value, store->indices);
+      }
+      return store;
+    }
+
+    PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+      BufferLoad load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
+      if (load->buffer.same_as(old_buffer_)) {
+        return BufferLoad(new_buffer_, load->indices);
+      }
+      return load;
+    }
+
+   private:
+    Buffer old_buffer_;
+    Buffer new_buffer_;
+  };
+
+  BufferReplacer replacer(inlined_buffer_, epilogue_output_buffer_);
+  new_block->body = replacer(reduction_block->body);
+
+  // 5. Update write regions
+  ffi::Array<BufferRegion> new_writes;
+  for (const BufferRegion& write : reduction_block->writes) {
+    if (write->buffer.same_as(inlined_buffer_)) {
+      new_writes.push_back(
+          BufferRegion(epilogue_output_buffer_, Substitute(write->region, var_map)));
+    } else {
+      new_writes.push_back(write);
+    }
+  }
+  new_block->writes = new_writes;
+
+  // 6. Update read regions (C first, then A, B)
+  ffi::Array<BufferRegion> new_reads;
+  std::unordered_set<const BufferNode*> read_bufs;
+
+  // Add C buffer read first (used in init)
+  if (epilogue_addend_buffer_.defined()) {
+    new_reads.push_back(BufferRegion(epilogue_addend_buffer_,
+                                     Substitute(epilogue_addend_region_->region, var_map)));
+    read_bufs.insert(epilogue_addend_buffer_.get());
+  }
+
+  // Add existing read regions (A, B, etc.)
+  for (const BufferRegion& read : reduction_block->reads) {
+    if (!read->buffer.same_as(inlined_buffer_)) {
+      // Only add non-temp buffers
+      if (read_bufs.find(read->buffer.get()) == read_bufs.end()) {
+        new_reads.push_back(read);
+        read_bufs.insert(read->buffer.get());
+      }
+    }
+  }
+
+  new_block->reads = new_reads;
+
+  return Block(new_block);
+}
+
+/*!
+ * \brief Helper class to replace reduction and epilogue blocks with a single fused block
+ */
+class SingleBlockFusionReplacer : public StmtMutator {
+ public:
+  static Block Replace(Block old_scope_root, Block new_fused_block, Block old_reduction_block,
+                       Block old_epilogue_block, Buffer reduction_buffer) {
+    SingleBlockFusionReplacer replacer(std::move(new_fused_block), std::move(old_reduction_block),
+                                       std::move(old_epilogue_block), std::move(reduction_buffer));
+    Block result = Downcast<Block>(replacer(std::move(old_scope_root)));
+
+    // Remove intermediate temp buffer
+    BlockNode* p = result.CopyOnWrite();
+    ffi::Array<Buffer> new_alloc_buffers;
+    for (const Buffer& buf : p->alloc_buffers) {
+      if (!buf.same_as(replacer.reduction_buffer_)) {
+        new_alloc_buffers.push_back(buf);
+      }
+    }
+    p->alloc_buffers = new_alloc_buffers;
+
+    return result;
+  }
+
+ private:
+  explicit SingleBlockFusionReplacer(Block new_fused_block, Block old_reduction_block,
+                                     Block old_epilogue_block, Buffer reduction_buffer)
+      : new_fused_block_(std::move(new_fused_block)),
+        old_reduction_block_(std::move(old_reduction_block)),
+        old_epilogue_block_(std::move(old_epilogue_block)),
+        reduction_buffer_(std::move(reduction_buffer)) {}
+
+  Stmt VisitStmt_(const ForNode* loop) final {
+    Stmt mutated_body = StmtMutator::VisitStmt(loop->body);
+    // Remove empty loops (containing only Evaluate(0))
+    if (auto eval = mutated_body.as<EvaluateNode>()) {
+      return mutated_body;  // Return Evaluate(0) to be removed by SeqStmt
+    }

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   The variable `eval` is declared but never used. This can be simplified to improve code clarity.
   
   ```c
       if (mutated_body.as<EvaluateNode>()) {
         return mutated_body;  // Return Evaluate(0) to be removed by SeqStmt
       }
   ```
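
   Alternatively, if the intent is to drop only loops whose body is literally `Evaluate(0)` (as the
   trailing comment suggests), the binding could be kept and checked explicitly. A rough sketch of
   that variant, assuming the `is_zero` helper from `tir/op.h` and the same
   `SingleBlockFusionReplacer` context; the rebuilt-loop tail is illustrative only:

   ```cpp
   Stmt VisitStmt_(const ForNode* loop) final {
     Stmt mutated_body = StmtMutator::VisitStmt(loop->body);
     // Drop the loop only when its body is exactly Evaluate(0); any other Evaluate
     // payload falls through and the loop is rebuilt with the mutated body.
     if (const auto* eval = mutated_body.as<EvaluateNode>()) {
       if (is_zero(eval->value)) {
         return mutated_body;
       }
     }
     ObjectPtr<ForNode> new_loop = ffi::make_object<ForNode>(*loop);
     new_loop->body = std::move(mutated_body);
     return For(new_loop);
   }
   ```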



##########
src/tir/schedule/primitive/compute_inline.cc:
##########
@@ -984,6 +984,393 @@ void ReverseComputeInline(ScheduleState self, const StmtSRef& consumer_block_sre
   ReverseComputeInlineImpl(self, consumer_block_sref);
 }
 
+Block ReductionEpilogueFuser::CreateFusedReductionBlock(const BlockNode* reduction_block,
+                                                        const BlockRealizeNode* reduction_realize) {
+  ObjectPtr<BlockNode> new_block = ffi::make_object<BlockNode>(*reduction_block);
+
+  // 1. Keep all iter vars (data parallel + reduction)
+  new_block->iter_vars = reduction_block->iter_vars;
+
+  // 2. Map epilogue block vars to reduction block vars
+  std::unordered_map<Var, Var> var_map;
+  int reduction_data_par_idx = 0;
+  for (int i = 0; i < static_cast<int>(reduction_block->iter_vars.size()); ++i) {
+    const IterVar& iter_var = reduction_block->iter_vars[i];
+    if (iter_var->iter_type == IterVarType::kDataPar) {
+      // Map corresponding data parallel var from epilogue block
+      int epilogue_data_par_idx = 0;
+      for (const IterVar& epilogue_iter_var : epilogue_block_->iter_vars) {
+        if (epilogue_iter_var->iter_type == IterVarType::kDataPar) {
+          if (epilogue_data_par_idx == reduction_data_par_idx) {
+            var_map[epilogue_iter_var->var] = iter_var->var;
+            break;
+          }
+          epilogue_data_par_idx++;
+        }
+      }
+      reduction_data_par_idx++;
+    }
+  }

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   The nested loop for mapping data-parallel `IterVar`s is functionally correct but could be
   simplified for better readability and maintainability. A clearer approach would be to extract
   the data-parallel variables from both blocks into separate lists and then map them by index.
   This also provides a natural place to check that the number of data-parallel variables
   matches, improving robustness.
   
   ```c
     std::vector<Var> reduction_data_vars;
     for (const IterVar& iter_var : reduction_block->iter_vars) {
       if (iter_var->iter_type == IterVarType::kDataPar) {
         reduction_data_vars.push_back(iter_var->var);
       }
     }
     std::vector<Var> epilogue_data_vars;
     for (const IterVar& iter_var : epilogue_block_->iter_vars) {
       if (iter_var->iter_type == IterVarType::kDataPar) {
         epilogue_data_vars.push_back(iter_var->var);
       }
     }
   
     ICHECK_EQ(reduction_data_vars.size(), epilogue_data_vars.size())
         << "ValueError: The number of data parallel iter vars must be the same 
in the reduction "
            "and epilogue blocks.";
   
     std::unordered_map<Var, Var> var_map;
     for (size_t i = 0; i < reduction_data_vars.size(); ++i) {
       var_map[epilogue_data_vars[i]] = reduction_data_vars[i];
     }
   ```
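
   If aborting via `ICHECK` inside block construction is undesirable, the same arity check could
   instead live in `BodyPatternAllowFusion` so that an unsupported pattern simply rejects fusion.
   A minimal sketch, reusing the `reduction_block_` / `epilogue_block_` members from this patch:

   ```cpp
   // Count data-parallel iter vars on both sides and bail out on a mismatch,
   // instead of failing later while building the fused block.
   size_t reduction_data_pars = 0;
   for (const IterVar& iter_var : reduction_block_->iter_vars) {
     if (iter_var->iter_type == IterVarType::kDataPar) ++reduction_data_pars;
   }
   size_t epilogue_data_pars = 0;
   for (const IterVar& iter_var : epilogue_block_->iter_vars) {
     if (iter_var->iter_type == IterVarType::kDataPar) ++epilogue_data_pars;
   }
   if (reduction_data_pars != epilogue_data_pars) {
     // Failure: data-parallel iter var counts differ between the two blocks
     return false;
   }
   ```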



##########
src/tir/schedule/primitive/compute_inline.cc:
##########
@@ -984,6 +984,393 @@ void ReverseComputeInline(ScheduleState self, const StmtSRef& consumer_block_sre
   ReverseComputeInlineImpl(self, consumer_block_sref);
 }
 
+bool ReductionEpilogueFuser::AnalyzeEpiloguePattern(const PrimExpr& value) {
+  // Pattern: temp[i,j] + C[i,j] or C[i,j] + temp[i,j]
+  if (const auto* add = value.as<AddNode>()) {
+    // Check if one operand is BufferLoad from reduction buffer
+    const auto* load_a = add->a.as<BufferLoadNode>();
+    const auto* load_b = add->b.as<BufferLoadNode>();
+
+    if (load_a && load_a->buffer.same_as(inlined_buffer_)) {
+      // Pattern: temp[...] + C[...]
+      epilogue_addend_ = add->b;
+      return true;
+    } else if (load_b && load_b->buffer.same_as(inlined_buffer_)) {
+      // Pattern: C[...] + temp[...]
+      epilogue_addend_ = add->a;
+      return true;
+    }
+  }
+
+  return false;
+}

Review Comment:
   ![high](https://www.gstatic.com/codereviewagent/high-priority.svg)
   
   The current logic in `AnalyzeEpiloguePattern` can incorrectly succeed for expressions like
   `temp + temp`, where `temp` is the output of the reduction. The `if`/`else if` structure will
   match the first operand and return `true`, setting `epilogue_addend_` to the second `temp`
   operand. This leads to an incorrect transformation. The pattern should only be matched if
   exactly one of the addition operands is a load from the reduction buffer.
   
   ```c
   bool ReductionEpilogueFuser::AnalyzeEpiloguePattern(const PrimExpr& value) {
     // Pattern: temp[i,j] + C[i,j] or C[i,j] + temp[i,j]
     if (const auto* add = value.as<AddNode>()) {
       const auto* load_a = add->a.as<BufferLoadNode>();
       const auto* load_b = add->b.as<BufferLoadNode>();
   
       bool a_is_target = load_a && load_a->buffer.same_as(inlined_buffer_);
       bool b_is_target = load_b && load_b->buffer.same_as(inlined_buffer_);
   
       if (a_is_target != b_is_target) {
         epilogue_addend_ = a_is_target ? add->b : add->a;
         return true;
       }
     }
   
     return false;
   }
   ```
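
   Note that checking only the direct operands still admits addends that reference the reduction
   buffer through a nested expression, e.g. `temp[i, j] + temp[i, j] * 2`. A sketch of an extra
   guard using `PostOrderVisit` (from `tir/stmt_functor.h`); `UsesBuffer` is a hypothetical helper,
   not part of this patch:

   ```cpp
   // Returns true if any BufferLoad in `expr` reads from `buffer`.
   static bool UsesBuffer(const PrimExpr& expr, const Buffer& buffer) {
     bool found = false;
     PostOrderVisit(expr, [&](const ObjectRef& node) {
       if (const auto* load = node.as<BufferLoadNode>()) {
         if (load->buffer.same_as(buffer)) {
           found = true;
         }
       }
     });
     return found;
   }

   // Inside AnalyzeEpiloguePattern, after choosing the addend:
   //   if (UsesBuffer(epilogue_addend_, inlined_buffer_)) return false;
   ```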



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]
