kimm240 commented on code in PR #18636:
URL: https://github.com/apache/tvm/pull/18636#discussion_r2703900896


##########
src/tir/schedule/primitive/compute_inline.cc:
##########
@@ -1308,53 +1179,163 @@ Block 
ReductionEpilogueFuser::CreateFusedReductionBlock(const BlockNode* reducti
     var_map[epilogue_data_vars[i]] = reduction_data_vars[i];
   }
 
-  // 2. Change init to epilogue value based on epilogue type
-  BufferStore new_init_store;
-  if (epilogue_type_ == EpilogueType::BiasReLU) {
-    // For ReLU, init should be max(C[vi, vj], 0) to match per-iteration ReLU 
semantics
-    PrimExpr init_value = Substitute(epilogue_addend_, var_map);
-    PrimExpr zero = tir::make_zero(init_value.dtype());
-    new_init_store = BufferStore(epilogue_output_buffer_, Max(init_value, 
zero),
-                                 Substitute(epilogue_output_indices_, 
var_map));
-  } else if (epilogue_type_ == EpilogueType::Clipping) {
-    // For Clipping, init should be min(max(init_value, lower), upper)
-    // Since init is typically 0, this becomes min(max(0, lower), upper)
-    PrimExpr init_value = tir::make_zero(epilogue_output_buffer_->dtype);
-    PrimExpr clipped_init = Min(Max(init_value, Substitute(clipping_lower_, 
var_map)),
-                                Substitute(clipping_upper_, var_map));
-    new_init_store = BufferStore(epilogue_output_buffer_, clipped_init,
-                                 Substitute(epilogue_output_indices_, 
var_map));
-  } else {
-    // Bias: D[vi, vj] = C[vi, vj]
-    new_init_store = BufferStore(epilogue_output_buffer_, 
Substitute(epilogue_addend_, var_map),
-                                 Substitute(epilogue_output_indices_, 
var_map));
-  }
+  // 2. Generalized init transformation: substitute reduction buffer load with 
identity element (0)
+  // Create a substituter to replace reduction_buffer_load_ with identity 
element
+  class InitSubstituter : public ExprMutator {
+   public:
+    InitSubstituter(const Buffer& target_buffer, PrimExpr identity_elem)
+        : target_buffer_(target_buffer), identity_elem_(identity_elem) {}
+
+    PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+      BufferLoad load = Downcast<BufferLoad>(ExprMutator::VisitExpr_(op));
+      if (load->buffer.same_as(target_buffer_)) {
+        return identity_elem_;
+      }
+      return load;
+    }
+
+   private:
+    Buffer target_buffer_;
+    PrimExpr identity_elem_;
+  };
+
+  // Identity element for reduction (assumed to be 0 for addition-based 
reductions)
+  PrimExpr identity_elem = tir::make_zero(epilogue_output_buffer_->dtype);

Review Comment:
   For `BiasAdd`, the transformation $epilogue(\sum x) = \sum x + Bias$ holds 
regardless of whether the bias is handled at the initialization stage or the 
end. 
   
   However, for non-linear operations like `ReLU` and `Clipping`, this identity 
does not hold if we simply substitute the identity element in the `Init` block, 
since $f(\sum x) \neq f(0) + \sum x$ in general.
   
   I initially aimed to generalize the logic by processing these operations 
through a unified substitution mechanism (essentially treating them in a 
'per-iteration' manner) in order to preserve the structure of the existing 
framework — i.e., the behavior that was merged in #18515.
   
   Should we strictly move these non-linear transformations to the final Store 
stage, so that they are applied once to the final sum?
   Or, since the previous 'per-iteration' behavior was already merged and is in 
use, should we keep it available as an option or as a dedicated 'fused-update' 
mode? I'm eager to hear your thoughts. 



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to