This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 8521c2f7f3 [TIR][Schedule] FuseReductionEpilogue: Add Clipping pattern 
support (#18515)
8521c2f7f3 is described below

commit 8521c2f7f3a9db5393ddf37761264f1657d15206
Author: kimm240 <[email protected]>
AuthorDate: Thu Dec 25 21:27:23 2025 +0900

    [TIR][Schedule] FuseReductionEpilogue: Add Clipping pattern support (#18515)
    
    Currently, the FuseReductionEpilogue primitive only supports Bias
    (addition) and BiasReLU (addition + ReLU) epilogue patterns. However,
    clipping operations (min(max(x, lower), upper)) are commonly used in
    deep learning models and would benefit from the same fusion
    optimization.
    
    This commit extends FuseReductionEpilogue to support Clipping patterns
    by:
    
    1. Adding EpilogueType::Clipping to the enum to distinguish clipping
       patterns from other epilogue types.
    
    2. Adding clipping_lower_ and clipping_upper_ members to
       ReductionEpilogueFuser to store clipping bounds extracted from the
       epilogue pattern.
    
    3. Extending AnalyzeEpiloguePattern to detect clipping patterns:
       - min(max(temp, lower), upper)
       - max(min(temp, upper), lower)
       - All commutative variants of min/max at each level
    
    4. Updating BiasReLU pattern matching to handle max(0, x) form in
       addition to max(x, 0) for better commutativity support.
    
    5. Modifying CreateFusedReductionBlock to apply clipping to the init
       value: init = min(max(0, lower), upper)
    
    6. Updating BufferReplacer to apply clipping per-iteration:
       value = min(max(value, lower), upper)
    
    7. Adding validation in BodyPatternAllowFusion to ensure temp appears
       exactly once in the epilogue expression (this check applies to all
       supported patterns, not only clipping).
    
    8. Creating comprehensive test coverage with 8 test cases:
       - Basic fusion test
       - Numerical correctness verification
       - Multiple epilogue blocks test
       - 5 commutative variant tests
    
    This implementation follows the same per-iteration semantics as
    BiasReLU,
    where clipping is applied at each reduction step rather than
    post-reduction. This semantic change is documented in the docstring with
    a warning about potential numerical differences.
    
    The test suite verifies that all commutative forms of clipping patterns
    are correctly recognized and that the fused implementation produces
    numerically identical results to the per-iteration reference
    implementation.
    
    ---------
    
    Co-authored-by: hyun gyu kim <[email protected]>
---
 .gitignore                                         |   3 +
 python/tvm/tir/schedule/schedule.py                |  31 ++-
 src/tir/schedule/primitive/compute_inline.cc       | 224 +++++++++++++++--
 ...ir_schedule_fuse_reduction_epilogue_clipping.py | 271 +++++++++++++++++++++
 ...st_tir_schedule_fuse_reduction_epilogue_relu.py | 229 +++++++++++++++++
 5 files changed, 741 insertions(+), 17 deletions(-)

diff --git a/.gitignore b/.gitignore
index 5bcbd5e373..6fa10a5e76 100644
--- a/.gitignore
+++ b/.gitignore
@@ -274,3 +274,6 @@ tvm-site/
 
 # GDB history file
 .gdb_history
+
+# Less command history file
+.lesshst
diff --git a/python/tvm/tir/schedule/schedule.py 
b/python/tvm/tir/schedule/schedule.py
index 0d41ffe943..b1e1a3f5d5 100644
--- a/python/tvm/tir/schedule/schedule.py
+++ b/python/tvm/tir/schedule/schedule.py
@@ -2356,14 +2356,41 @@ class Schedule(Object):
         It requires:
         1) The reduction block is a complete reduction block
         2) The epilogue block only reads from the reduction block's output
-        3) The epilogue performs a simple addition: output = reduction_result 
+ bias
+        3) The epilogue matches one of the supported patterns:
+           - Bias: ``output = reduction_result + bias``
+           - BiasReLU: ``output = max(reduction_result + bias, 0)``
+           - Clipping: ``output = min(max(reduction_result, lower), upper)``
+             or their commutative variants
+
+        .. warning::
+
+            **Semantic Change for Non-Linear Epilogues (BiasReLU, Clipping):**
+
+            For non-linear epilogues (BiasReLU and Clipping), fusion changes 
the
+            computation semantics from post-reduction application to 
per-iteration
+            application. This can lead to different numerical results.
+
+            **Example with Clipping to [-5, 5] and inputs [6, -2]:**
+
+            - **Post-reduction clipping** (original): ``clip(sum([6, -2])) = 
clip(4) = 4``
+            - **Per-iteration clipping** (fused): ``acc=0 → clip(0+6)=5 → 
clip(5+(-2))=3``
+
+            The fused version applies clipping at each reduction iteration, 
which
+            may be an intended optimization for some models but can cause 
unexpected
+            correctness issues if users are not aware of this behavior.
+
+            For linear epilogues (Bias), fusion preserves exact numerical 
equivalence.
 
         Parameters
         ----------
         reduction_block : Union[BlockRV, str]
             The reduction block (e.g., matmul)
         epilogue_block : Union[BlockRV, str]
-            The epilogue block to be fused (e.g., bias add)
+            The epilogue block to be fused (e.g., bias add, ReLU, clipping)
+
+        Examples
+        --------
+        See :py:func:`test_tir_schedule_fuse_reduction_epilogue` for examples.
         """
         reduction_block = self._normalize_block_arg(reduction_block)
         epilogue_block = self._normalize_block_arg(epilogue_block)
diff --git a/src/tir/schedule/primitive/compute_inline.cc 
b/src/tir/schedule/primitive/compute_inline.cc
index cc3785d5c1..0ab6d7e2b6 100644
--- a/src/tir/schedule/primitive/compute_inline.cc
+++ b/src/tir/schedule/primitive/compute_inline.cc
@@ -988,6 +988,13 @@ void ReverseComputeInline(ScheduleState self, const 
StmtSRef& consumer_block_sre
  * \brief Helper to fuse epilogue block into reduction block
  * Analyzes epilogue pattern and transforms reduction init/update
  */
+// Epilogue type enumeration
+enum class EpilogueType {
+  Bias,      // temp + C
+  BiasReLU,  // max(temp + C, 0)
+  Clipping,  // min(max(temp, lower), upper)
+};
+
 class ReductionEpilogueFuser : public BaseInliner {
  public:
   explicit ReductionEpilogueFuser(const Buffer& reduction_buffer, const 
BlockNode* reduction_block,
@@ -995,7 +1002,19 @@ class ReductionEpilogueFuser : public BaseInliner {
                                   const StmtSRef& scope_root_sref)
       : BaseInliner(reduction_buffer, epilogue_block_realize->block, 
scope_root_sref),
         reduction_block_(reduction_block),
-        epilogue_block_(epilogue_block_realize->block.get()) {}
+        epilogue_block_(epilogue_block_realize->block.get()),
+        epilogue_type_(EpilogueType::Bias) {
+    // Disable opaque access check for epilogue fusion
+    // Epilogue blocks can read multiple buffers (temp + bias), which is 
allowed
+    has_opaque_access = false;
+  }
+
+  // Override CheckOpaqueAccess to allow multiple buffer reads
+  void CheckOpaqueAccess(const VarNode* buffer_var) {
+    // For epilogue fusion, we allow multiple buffer reads (temp + bias)
+    // So we don't check for opaque access
+    // BaseInliner::CheckOpaqueAccess(buffer_var);  // Don't call base class
+  }
 
   bool BodyPatternAllowFusion(const BlockRealize& epilogue_block_realize);
 
@@ -1012,18 +1031,21 @@ class ReductionEpilogueFuser : public BaseInliner {
                                                               const 
BufferStoreNode* from) {
     struct Extractor : public ExprVisitor {
       void VisitExpr_(const BufferLoadNode* load) final {
-        if (load->buffer.get() == buffer) {
+        if (load->buffer.same_as(buffer)) {
           result.push_back(load);
         }
+        // Continue visiting child nodes (indices)
         ExprVisitor::VisitExpr_(load);
       }
-      const BufferNode* buffer;
+      Buffer buffer;
       std::vector<const BufferLoadNode*> result;
     } extractor;
-    extractor.buffer = buffer.get();
+    extractor.buffer = buffer;
+    // Visit indices first (though they typically don't contain BufferLoad)
     for (const PrimExpr& expr : from->indices) {
       extractor(expr);
     }
+    // Visit the value expression (e.g., max(temp + C, 0) for ReLU)
     extractor(from->value);
     return std::move(extractor.result);
   }
@@ -1036,6 +1058,9 @@ class ReductionEpilogueFuser : public BaseInliner {
   BufferRegion epilogue_output_region_{nullptr};           // Write region of D
   Buffer epilogue_addend_buffer_{nullptr};                 // Addend buffer C
   BufferRegion epilogue_addend_region_{nullptr};           // Read region of C
+  EpilogueType epilogue_type_;                             // Type of epilogue 
operation
+  PrimExpr clipping_lower_{nullptr};                       // Lower bound for 
clipping
+  PrimExpr clipping_upper_{nullptr};                       // Upper bound for 
clipping
 };
 
 bool ReductionEpilogueFuser::BodyPatternAllowFusion(const BlockRealize& 
epilogue_block_realize) {
@@ -1058,26 +1083,36 @@ bool 
ReductionEpilogueFuser::BodyPatternAllowFusion(const BlockRealize& epilogue
     return false;
   }
 
-  // 4. Analyze epilogue pattern: D[i,j] = temp[i,j] + C[i,j]
+  // 4. Analyze epilogue pattern: D[i,j] = temp[i,j] + C[i,j] or
+  //    D[i,j] = min(max(temp[i,j], lower), upper)
   if (!AnalyzeEpiloguePattern(inlined_store_->value)) {
-    // Failure: epilogue is not a simple addition pattern
+    // Failure: epilogue is not a supported pattern (Bias, BiasReLU, or 
Clipping)
+    return false;
+  }
+
+  // 5. Verify temp appears exactly once in the epilogue pattern
+  // This ensures correctness for all supported patterns (Bias, BiasReLU, 
Clipping)
+  // The reduction result buffer must be used exactly once in the epilogue 
expression
+  if (loads.size() != 1) {
+    // Failure: The reduction result (temp) must be used exactly once in the
+    // epilogue expression for fusion.
     return false;
   }
 
-  // 5. Check if producer is a reduction block
+  // 6. Check if producer is a reduction block
   if (!IsReductionBlock(reduction_block_)) {
     // Failure: producer is not a reduction block
     return false;
   }
 
-  // 6. Extract epilogue information (output buffer, indices, regions, etc.)
+  // 7. Extract epilogue information (output buffer, indices, regions, etc.)
   ExtractEpilogueInfo();
 
   return true;
 }
 
 bool ReductionEpilogueFuser::AnalyzeEpiloguePattern(const PrimExpr& value) {
-  // Pattern: temp[i,j] + C[i,j] or C[i,j] + temp[i,j]
+  // Pattern 1: temp[i,j] + C[i,j] or C[i,j] + temp[i,j] (Bias)
   if (const auto* add = value.as<AddNode>()) {
     const auto* load_a = add->a.as<BufferLoadNode>();
     const auto* load_b = add->b.as<BufferLoadNode>();
@@ -1088,10 +1123,125 @@ bool 
ReductionEpilogueFuser::AnalyzeEpiloguePattern(const PrimExpr& value) {
     // Ensure exactly one operand is from the reduction buffer
     if (a_is_target != b_is_target) {
       epilogue_addend_ = a_is_target ? add->b : add->a;
+      epilogue_type_ = EpilogueType::Bias;
       return true;
     }
   }
 
+  // Pattern 2: min(max(temp[i,j], lower), upper) or max(min(temp[i,j], 
upper), lower) (Clipping)
+  // Handle all commutative variants of min/max at each level.
+
+  // Helper to check if an expression is a load from the reduction buffer, and
+  // return the other operand as `other` if so.
+  auto match_buffer_in_commutative_op = [this](const PrimExpr& a, const 
PrimExpr& b,
+                                               PrimExpr* other) -> bool {
+    if (const auto* load_a = a.as<BufferLoadNode>()) {
+      if (load_a->buffer.same_as(inlined_buffer_)) {
+        *other = b;
+        return true;
+      }
+    }
+    if (const auto* load_b = b.as<BufferLoadNode>()) {
+      if (load_b->buffer.same_as(inlined_buffer_)) {
+        *other = a;
+        return true;
+      }
+    }
+    return false;
+  };
+
+  // Check for min(max(temp, lower), upper) and commutative variants
+  if (const auto* min_node = value.as<MinNode>()) {
+    const MaxNode* max_node = nullptr;
+    PrimExpr upper;
+    // Try both (a, b) as possible positions of the inner max
+    if ((max_node = min_node->a.as<MaxNode>())) {
+      upper = min_node->b;
+    } else if ((max_node = min_node->b.as<MaxNode>())) {
+      upper = min_node->a;
+    }
+    if (max_node != nullptr) {
+      PrimExpr lower;
+      if (match_buffer_in_commutative_op(max_node->a, max_node->b, &lower)) {
+        clipping_lower_ = lower;
+        clipping_upper_ = upper;
+        epilogue_type_ = EpilogueType::Clipping;
+        return true;
+      }
+    }
+  }
+
+  // Check for max(min(temp[i,j], upper), lower) and commutative variants
+  if (const auto* max_node = value.as<MaxNode>()) {
+    const MinNode* min_node = nullptr;
+    PrimExpr lower;
+    // Try both (a, b) as possible positions of the inner min
+    if ((min_node = max_node->a.as<MinNode>())) {
+      lower = max_node->b;
+    } else if ((min_node = max_node->b.as<MinNode>())) {
+      lower = max_node->a;
+    }
+    if (min_node != nullptr) {
+      PrimExpr upper;
+      if (match_buffer_in_commutative_op(min_node->a, min_node->b, &upper)) {
+        clipping_lower_ = lower;
+        clipping_upper_ = upper;
+        epilogue_type_ = EpilogueType::Clipping;
+        return true;
+      }
+    }
+  }
+
+  // Pattern 3: max(temp[i,j] + C[i,j], 0) or max(C[i,j] + temp[i,j], 0) 
(BiasReLU)
+  // Also handle max(0, temp[i,j] + C[i,j]) or max(0, C[i,j] + temp[i,j])
+  if (const auto* max_node = value.as<MaxNode>()) {
+    // Check if either operand is zero (ReLU: max(x, 0) or max(0, x))
+    // Support both integer and float zero constants.
+    const PrimExpr* add_candidate = nullptr;
+    bool is_zero_const = false;
+    auto is_zero_expr = [](const PrimExpr& expr) -> bool {
+      if (tir::is_zero(expr)) {
+        return true;
+      }
+      if (const auto* float_imm = expr.as<FloatImmNode>()) {
+        return float_imm->value == 0.0;
+      }
+      return false;
+    };
+
+    if (is_zero_expr(max_node->a)) {
+      is_zero_const = true;
+      add_candidate = &max_node->b;
+    } else if (is_zero_expr(max_node->b)) {
+      is_zero_const = true;
+      add_candidate = &max_node->a;
+    }
+
+    if (is_zero_const && add_candidate != nullptr) {
+      if (const auto* add = add_candidate->as<AddNode>()) {
+        const auto* load_a = add->a.as<BufferLoadNode>();
+        const auto* load_b = add->b.as<BufferLoadNode>();
+
+        bool a_is_target = load_a && load_a->buffer.same_as(inlined_buffer_);
+        bool b_is_target = load_b && load_b->buffer.same_as(inlined_buffer_);
+
+        // Ensure exactly one operand is from the reduction buffer
+        if (a_is_target != b_is_target) {
+          epilogue_addend_ = a_is_target ? add->b : add->a;
+          epilogue_type_ = EpilogueType::BiasReLU;
+          return true;
+        }
+      } else if (const auto* load = add_candidate->as<BufferLoadNode>()) {
+        // Handle bias-free ReLU: max(temp, 0) or max(0, temp)
+        if (load->buffer.same_as(inlined_buffer_)) {
+          epilogue_addend_ = tir::make_zero(load->dtype);
+          epilogue_type_ = EpilogueType::BiasReLU;
+          return true;
+        }
+      }
+    }
+  }
+
   return false;
 }
 
@@ -1158,20 +1308,54 @@ Block 
ReductionEpilogueFuser::CreateFusedReductionBlock(const BlockNode* reducti
     var_map[epilogue_data_vars[i]] = reduction_data_vars[i];
   }
 
-  // 2. Change init to epilogue value: D[vi, vj] = C[vi, vj]
-  BufferStore new_init_store(epilogue_output_buffer_, 
Substitute(epilogue_addend_, var_map),
-                             Substitute(epilogue_output_indices_, var_map));
+  // 2. Change init to epilogue value based on epilogue type
+  BufferStore new_init_store;
+  if (epilogue_type_ == EpilogueType::BiasReLU) {
+    // For ReLU, init should be max(C[vi, vj], 0) to match per-iteration ReLU 
semantics
+    PrimExpr init_value = Substitute(epilogue_addend_, var_map);
+    PrimExpr zero = tir::make_zero(init_value.dtype());
+    new_init_store = BufferStore(epilogue_output_buffer_, Max(init_value, 
zero),
+                                 Substitute(epilogue_output_indices_, 
var_map));
+  } else if (epilogue_type_ == EpilogueType::Clipping) {
+    // For Clipping, init should be min(max(init_value, lower), upper)
+    // Since init is typically 0, this becomes min(max(0, lower), upper)
+    PrimExpr init_value = tir::make_zero(epilogue_output_buffer_->dtype);
+    PrimExpr clipped_init = Min(Max(init_value, Substitute(clipping_lower_, 
var_map)),
+                                Substitute(clipping_upper_, var_map));
+    new_init_store = BufferStore(epilogue_output_buffer_, clipped_init,
+                                 Substitute(epilogue_output_indices_, 
var_map));
+  } else {
+    // Bias: D[vi, vj] = C[vi, vj]
+    new_init_store = BufferStore(epilogue_output_buffer_, 
Substitute(epilogue_addend_, var_map),
+                                 Substitute(epilogue_output_indices_, 
var_map));
+  }
   new_block->init = new_init_store;
 
   // 3. Replace output buffer from temp to D in body
   class BufferReplacer : public StmtExprMutator {
    public:
-    BufferReplacer(Buffer old_buf, Buffer new_buf) : old_buffer_(old_buf), 
new_buffer_(new_buf) {}
+    BufferReplacer(Buffer old_buf, Buffer new_buf, EpilogueType epilogue_type, 
DataType dtype,
+                   PrimExpr clipping_lower = PrimExpr(), PrimExpr 
clipping_upper = PrimExpr())
+        : old_buffer_(old_buf),
+          new_buffer_(new_buf),
+          epilogue_type_(epilogue_type),
+          dtype_(dtype),
+          clipping_lower_(clipping_lower),
+          clipping_upper_(clipping_upper) {}
 
     Stmt VisitStmt_(const BufferStoreNode* op) final {
       BufferStore store = 
Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
       if (store->buffer.same_as(old_buffer_)) {
-        return BufferStore(new_buffer_, store->value, store->indices);
+        PrimExpr new_value = store->value;
+        // For ReLU, apply max per iteration to match per-iteration ReLU 
semantics
+        if (epilogue_type_ == EpilogueType::BiasReLU) {
+          PrimExpr zero = tir::make_zero(dtype_);
+          new_value = Max(new_value, zero);
+        } else if (epilogue_type_ == EpilogueType::Clipping) {
+          // For Clipping, apply min(max(value, lower), upper) per iteration
+          new_value = Min(Max(new_value, clipping_lower_), clipping_upper_);
+        }
+        return BufferStore(new_buffer_, new_value, store->indices);
       }
       return store;
     }
@@ -1187,9 +1371,19 @@ Block 
ReductionEpilogueFuser::CreateFusedReductionBlock(const BlockNode* reducti
    private:
     Buffer old_buffer_;
     Buffer new_buffer_;
+    EpilogueType epilogue_type_;
+    DataType dtype_;
+    PrimExpr clipping_lower_;
+    PrimExpr clipping_upper_;
   };
 
-  BufferReplacer replacer(inlined_buffer_, epilogue_output_buffer_);
+  DataType dtype = epilogue_output_buffer_->dtype;
+  PrimExpr clipping_lower_subst =
+      epilogue_type_ == EpilogueType::Clipping ? Substitute(clipping_lower_, 
var_map) : PrimExpr();
+  PrimExpr clipping_upper_subst =
+      epilogue_type_ == EpilogueType::Clipping ? Substitute(clipping_upper_, 
var_map) : PrimExpr();
+  BufferReplacer replacer(inlined_buffer_, epilogue_output_buffer_, 
epilogue_type_, dtype,
+                          clipping_lower_subst, clipping_upper_subst);
   new_block->body = replacer(reduction_block->body);
 
   // 4. Update write regions
diff --git 
a/tests/python/tir-schedule/test_tir_schedule_fuse_reduction_epilogue_clipping.py
 
b/tests/python/tir-schedule/test_tir_schedule_fuse_reduction_epilogue_clipping.py
new file mode 100644
index 0000000000..6b3338b9a1
--- /dev/null
+++ 
b/tests/python/tir-schedule/test_tir_schedule_fuse_reduction_epilogue_clipping.py
@@ -0,0 +1,271 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-function-docstring,missing-module-docstring
+import pytest
+import tvm
+import tvm.testing
+from tvm import tir
+from tvm.script import tir as T
+from tvm.tir.schedule.testing import (
+    verify_trace_roundtrip,
+    assert_structural_equal_ignore_global_symbol,
+)
+import numpy as np
+
+# pylint: disable=no-member,invalid-name,unused-variable
+
+
[email protected]_func
+def matmul_clipping_before(
+    A: T.Buffer((16, 16), "float32"),
+    B: T.Buffer((16, 16), "float32"),
+    D: T.Buffer((16, 16), "float32"),
+    lower: T.float32,
+    upper: T.float32,
+) -> None:
+    """Original function with separate reduction and clipping epilogue 
blocks."""
+    temp = T.alloc_buffer((16, 16), dtype="float32")
+    for i, j, k in T.grid(16, 16, 16):
+        with T.block("matmul"):
+            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+            with T.init():
+                temp[vi, vj] = T.float32(0)
+            temp[vi, vj] = temp[vi, vj] + A[vi, vk] * B[vj, vk]
+
+    for i, j in T.grid(16, 16):
+        with T.block("clipping"):
+            vi, vj = T.axis.remap("SS", [i, j])
+            D[vi, vj] = T.min(T.max(temp[vi, vj], lower), upper)
+
+
[email protected]_func
+def matmul_clipping_expected(
+    A: T.Buffer((16, 16), "float32"),
+    B: T.Buffer((16, 16), "float32"),
+    D: T.Buffer((16, 16), "float32"),
+    lower: T.float32,
+    upper: T.float32,
+) -> None:
+    """Expected function after fusion (Clipping)."""
+    temp = T.alloc_buffer((16, 16), dtype="float32")
+    for i, j, k in T.grid(16, 16, 16):
+        with T.block("matmul"):
+            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+            T.reads(A[vi, vk], B[vj, vk])
+            T.writes(D[vi, vj])
+            with T.init():
+                D[vi, vj] = T.min(T.max(T.float32(0), lower), upper)
+            D[vi, vj] = T.min(T.max(D[vi, vj] + A[vi, vk] * B[vj, vk], lower), 
upper)
+
+
+def test_matmul_clipping():
+    """Test fusion of matmul with clipping epilogue."""
+    sch = tir.Schedule(matmul_clipping_before, debug_mask="all")
+    sch.fuse_reduction_epilogue("matmul", "clipping")
+    assert_structural_equal_ignore_global_symbol(sch.mod["main"], 
matmul_clipping_expected)
+    verify_trace_roundtrip(sch=sch, mod=matmul_clipping_before)
+
+
[email protected]_func
+def matmul_clipping_before_per_iteration(
+    A: T.Buffer((16, 16), "float32"),
+    B: T.Buffer((16, 16), "float32"),
+    D: T.Buffer((16, 16), "float32"),
+) -> None:
+    """Original function with per-iteration clipping (same semantics as 
fused)."""
+    temp = T.alloc_buffer((16, 16), dtype="float32")
+    lower = T.float32(-5.0)
+    upper = T.float32(5.0)
+    for i, j in T.grid(16, 16):
+        with T.block("init"):
+            vi, vj = T.axis.remap("SS", [i, j])
+            temp[vi, vj] = T.min(T.max(T.float32(0), lower), upper)  # Clip 
init
+
+    for i, j, k in T.grid(16, 16, 16):
+        with T.block("matmul"):
+            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+            # Per-iteration clipping
+            temp[vi, vj] = T.min(T.max(temp[vi, vj] + A[vi, vk] * B[vj, vk], 
lower), upper)
+
+    for i, j in T.grid(16, 16):
+        with T.block("copy"):
+            vi, vj = T.axis.remap("SS", [i, j])
+            D[vi, vj] = temp[vi, vj]
+
+
+def test_matmul_clipping_correctness_unified():
+    """Test that original and fused produce identical results with 
per-iteration clipping."""
+    A_np = np.random.randn(16, 16).astype("float32")
+    B_np = np.random.randn(16, 16).astype("float32")
+    lower = -5.0
+    upper = 5.0
+
+    # NumPy reference for per-iteration clipping
+    D_ref = np.clip(0.0, lower, upper)  # init with clipping
+    for k in range(16):
+        D_ref = np.clip(D_ref + np.outer(A_np[:, k], B_np[:, k]), lower, upper)
+
+    # TVM execution (original with per-iteration clipping)
+    mod_original = tvm.compile(matmul_clipping_before_per_iteration, 
target="llvm")
+    D_original_tvm = tvm.runtime.tensor(np.zeros((16, 16), dtype="float32"))
+    mod_original(
+        tvm.runtime.tensor(A_np),
+        tvm.runtime.tensor(B_np),
+        D_original_tvm,
+    )
+
+    # TVM execution (fused)
+    sch = tir.Schedule(matmul_clipping_before)
+    sch.fuse_reduction_epilogue("matmul", "clipping")
+    mod_fused = tvm.compile(sch.mod["main"], target="llvm")
+    D_fused_tvm = tvm.runtime.tensor(np.zeros((16, 16), dtype="float32"))
+    # Pass scalar values directly as Python floats
+    mod_fused(
+        tvm.runtime.tensor(A_np),
+        tvm.runtime.tensor(B_np),
+        D_fused_tvm,
+        lower,
+        upper,
+    )
+
+    D_original = D_original_tvm.numpy()
+    D_fused = D_fused_tvm.numpy()
+
+    # Now both should match exactly
+    np.testing.assert_allclose(D_original, D_ref, rtol=1e-5, atol=1e-6)
+    np.testing.assert_allclose(D_fused, D_ref, rtol=1e-5, atol=1e-6)
+    np.testing.assert_allclose(D_original, D_fused, rtol=1e-5, atol=1e-6)
+
+
[email protected]_func
+def matmul_clipping_multiple_epilogue_before(
+    A: T.Buffer((16, 16), "float32"),
+    B: T.Buffer((16, 16), "float32"),
+    D: T.Buffer((16, 16), "float32"),
+    E: T.Buffer((16, 16), "float32"),
+    lower: T.float32,
+    upper: T.float32,
+) -> None:
+    """Original function with separate reduction and multiple epilogue blocks 
(one with clipping, one without)."""
+    temp = T.alloc_buffer((16, 16), dtype="float32")
+    for i, j, k in T.grid(16, 16, 16):
+        with T.block("matmul"):
+            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+            with T.init():
+                temp[vi, vj] = T.float32(0)
+            temp[vi, vj] = temp[vi, vj] + A[vi, vk] * B[vj, vk]
+
+    for i, j in T.grid(16, 16):
+        with T.block("clipping"):
+            vi, vj = T.axis.remap("SS", [i, j])
+            D[vi, vj] = T.min(T.max(temp[vi, vj], lower), upper)
+
+    for i, j in T.grid(16, 16):
+        with T.block("copy"):
+            vi, vj = T.axis.remap("SS", [i, j])
+            E[vi, vj] = temp[vi, vj]
+
+
[email protected]_func
+def matmul_clipping_multiple_epilogue_expected(
+    A: T.Buffer((16, 16), "float32"),
+    B: T.Buffer((16, 16), "float32"),
+    D: T.Buffer((16, 16), "float32"),
+    E: T.Buffer((16, 16), "float32"),
+    lower: T.float32,
+    upper: T.float32,
+) -> None:
+    """Expected function after fusion (Clipping) with multiple epilogue 
blocks."""
+    temp = T.alloc_buffer((16, 16), dtype="float32")
+    for i, j, k in T.grid(16, 16, 16):
+        with T.block("matmul"):
+            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+            T.reads(A[vi, vk], B[vj, vk])
+            T.writes(D[vi, vj])
+            with T.init():
+                D[vi, vj] = T.min(T.max(T.float32(0), lower), upper)
+            D[vi, vj] = T.min(T.max(D[vi, vj] + A[vi, vk] * B[vj, vk], lower), 
upper)
+    for i, j in T.grid(16, 16):
+        with T.block("copy"):
+            vi, vj = T.axis.remap("SS", [i, j])
+            T.reads(temp[vi, vj])
+            T.writes(E[vi, vj])
+            E[vi, vj] = temp[vi, vj]
+
+
+def test_matmul_clipping_multiple_epilogue():
+    """Test fusion with multiple epilogue blocks - one with clipping, one 
without.
+
+    Following the same pattern as 
test_fuse_reduction_epilogue_multiple_epilogue,
+    this test verifies that fusion works correctly when there are multiple
+    epilogue blocks. The temp buffer is kept because the second epilogue block
+    still needs it.
+    """
+    sch = tir.Schedule(matmul_clipping_multiple_epilogue_before, 
debug_mask="all")
+    sch.fuse_reduction_epilogue("matmul", "clipping")
+    assert_structural_equal_ignore_global_symbol(
+        sch.mod["main"], matmul_clipping_multiple_epilogue_expected
+    )
+    verify_trace_roundtrip(sch=sch, 
mod=matmul_clipping_multiple_epilogue_before)
+
+    mod = tvm.compile(sch.mod["main"], target="llvm")
+    assert mod is not None
+
+
+# Test commutative variants of clipping patterns
[email protected](
+    "pattern_func",
+    [
+        lambda temp, lower, upper: T.min(T.max(temp, lower), upper),  # 
min(max(temp, lower), upper)
+        lambda temp, lower, upper: T.min(upper, T.max(temp, lower)),  # 
min(upper, max(temp, lower))
+        lambda temp, lower, upper: T.min(T.max(lower, temp), upper),  # 
min(max(lower, temp), upper)
+        lambda temp, lower, upper: T.max(T.min(temp, upper), lower),  # 
max(min(temp, upper), lower)
+        lambda temp, lower, upper: T.max(lower, T.min(temp, upper)),  # 
max(lower, min(temp, upper))
+    ],
+)
+def test_matmul_clipping_commutative_variants(pattern_func):
+    """Test that all commutative variants of clipping patterns are 
recognized."""
+    lower = -5.0
+    upper = 5.0
+
+    @T.prim_func
+    def test_func(
+        A: T.Buffer((8, 8), "float32"),
+        B: T.Buffer((8, 8), "float32"),
+        D: T.Buffer((8, 8), "float32"),
+    ) -> None:
+        temp = T.alloc_buffer((8, 8), dtype="float32")
+        for i, j, k in T.grid(8, 8, 8):
+            with T.block("matmul"):
+                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+                with T.init():
+                    temp[vi, vj] = T.float32(0)
+                temp[vi, vj] = temp[vi, vj] + A[vi, vk] * B[vj, vk]
+
+        for i, j in T.grid(8, 8):
+            with T.block("clipping"):
+                vi, vj = T.axis.remap("SS", [i, j])
+                D[vi, vj] = pattern_func(temp[vi, vj], T.float32(lower), 
T.float32(upper))
+
+    sch = tir.Schedule(test_func, debug_mask="all")
+    # Should not raise an error - all variants should be recognized
+    sch.fuse_reduction_epilogue("matmul", "clipping")
+    verify_trace_roundtrip(sch=sch, mod=test_func)
+
+
+if __name__ == "__main__":
+    tvm.testing.main()
diff --git 
a/tests/python/tir-schedule/test_tir_schedule_fuse_reduction_epilogue_relu.py 
b/tests/python/tir-schedule/test_tir_schedule_fuse_reduction_epilogue_relu.py
new file mode 100644
index 0000000000..66e5e52e43
--- /dev/null
+++ 
b/tests/python/tir-schedule/test_tir_schedule_fuse_reduction_epilogue_relu.py
@@ -0,0 +1,229 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-function-docstring,missing-module-docstring
+import pytest
+import tvm
+import tvm.testing
+from tvm import tir
+from tvm.script import tir as T
+from tvm.tir.schedule.testing import (
+    verify_trace_roundtrip,
+    assert_structural_equal_ignore_global_symbol,
+)
+import numpy as np
+
+# pylint: disable=no-member,invalid-name,unused-variable
+
+
@T.prim_func
def matmul_bias_relu_before(
    A: T.Buffer((16, 16), "float32"),
    B: T.Buffer((16, 16), "float32"),
    C: T.Buffer((16, 16), "float32"),
    D: T.Buffer((16, 16), "float32"),
) -> None:
    """Original function with separate reduction and epilogue blocks (Bias + ReLU)."""
    # Intermediate accumulator written by the reduction block and read by the
    # epilogue block; fusion is expected to eliminate it.
    temp = T.alloc_buffer((16, 16), dtype="float32")
    for i, j, k in T.grid(16, 16, 16):
        with T.block("matmul"):
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            with T.init():
                temp[vi, vj] = T.float32(0)
            # NOTE: B is indexed as B[vj, vk], so this computes A @ B^T.
            temp[vi, vj] = temp[vi, vj] + A[vi, vk] * B[vj, vk]

    for i, j in T.grid(16, 16):
        with T.block("bias_relu"):
            vi, vj = T.axis.remap("SS", [i, j])
            # Epilogue: bias add followed by ReLU (clamp at zero).
            D[vi, vj] = T.max(temp[vi, vj] + C[vi, vj], T.float32(0))
+
+
@T.prim_func
def matmul_bias_relu_before_per_iteration(
    A: T.Buffer((16, 16), "float32"),
    B: T.Buffer((16, 16), "float32"),
    C: T.Buffer((16, 16), "float32"),
    D: T.Buffer((16, 16), "float32"),
) -> None:
    """Original function with per-iteration ReLU (same semantics as fused)."""
    temp = T.alloc_buffer((16, 16), dtype="float32")
    for i, j in T.grid(16, 16):
        with T.block("init"):
            vi, vj = T.axis.remap("SS", [i, j])
            temp[vi, vj] = T.max(C[vi, vj], T.float32(0))  # ReLU on bias

    for i, j, k in T.grid(16, 16, 16):
        with T.block("matmul"):
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            # Per-iteration ReLU: clamp after every reduction step, matching
            # what the fused schedule produces.
            temp[vi, vj] = T.max(temp[vi, vj] + A[vi, vk] * B[vj, vk], T.float32(0))

    for i, j in T.grid(16, 16):
        with T.block("copy"):
            vi, vj = T.axis.remap("SS", [i, j])
            D[vi, vj] = temp[vi, vj]
+
+
@T.prim_func
def matmul_bias_relu_expected(
    A: T.Buffer((16, 16), "float32"),
    B: T.Buffer((16, 16), "float32"),
    C: T.Buffer((16, 16), "float32"),
    D: T.Buffer((16, 16), "float32"),
) -> None:
    """Expected function after fusion (Bias + ReLU)."""
    # temp is still allocated but the fused block writes D directly.
    temp = T.alloc_buffer((16, 16), dtype="float32")
    for i, j, k in T.grid(16, 16, 16):
        with T.block("matmul"):
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            T.reads(C[vi, vj], A[vi, vk], B[vj, vk])
            T.writes(D[vi, vj])
            with T.init():
                # Init absorbs the epilogue: ReLU applied to the bias.
                D[vi, vj] = T.max(C[vi, vj], T.float32(0))
            # ReLU re-applied after every reduction step.
            D[vi, vj] = T.max(D[vi, vj] + A[vi, vk] * B[vj, vk], T.float32(0))
+
+
def test_matmul_bias_relu():
    """Fuse the bias+ReLU epilogue into the matmul reduction and verify the IR."""
    schedule = tir.Schedule(matmul_bias_relu_before, debug_mask="all")
    schedule.fuse_reduction_epilogue("matmul", "bias_relu")
    assert_structural_equal_ignore_global_symbol(
        schedule.mod["main"], matmul_bias_relu_expected
    )
    verify_trace_roundtrip(sch=schedule, mod=matmul_bias_relu_before)
+
+
def test_matmul_bias_relu_correctness_unified():
    """Check that original (per-iteration ReLU) and fused schedules agree numerically."""
    a = np.random.randn(16, 16).astype("float32")
    b = np.random.randn(16, 16).astype("float32")
    c = np.random.randn(16, 16).astype("float32")

    # NumPy reference for per-iteration ReLU: start from ReLU(bias), then for
    # each reduction step k accumulate the rank-1 update A[:, k] x B[:, k]
    # (i.e. A[i, k] * B[j, k] for all i, j) and clamp at zero again.
    expected = np.maximum(c, 0)
    for k in range(16):
        expected = np.maximum(expected + np.outer(a[:, k], b[:, k]), 0)

    def run(prim_func):
        # Compile a prim_func for CPU and execute it on the shared inputs.
        compiled = tvm.compile(prim_func, target="llvm")
        out = tvm.runtime.tensor(np.zeros((16, 16), dtype="float32"))
        compiled(
            tvm.runtime.tensor(a),
            tvm.runtime.tensor(b),
            tvm.runtime.tensor(c),
            out,
        )
        return out.numpy()

    # Original schedule with explicit per-iteration ReLU.
    baseline = run(matmul_bias_relu_before_per_iteration)

    # Fused schedule.
    schedule = tir.Schedule(matmul_bias_relu_before)
    schedule.fuse_reduction_epilogue("matmul", "bias_relu")
    fused = run(schedule.mod["main"])

    # All three must match exactly (up to float tolerance).
    np.testing.assert_allclose(baseline, expected, rtol=1e-5, atol=1e-6)
    np.testing.assert_allclose(fused, expected, rtol=1e-5, atol=1e-6)
    np.testing.assert_allclose(baseline, fused, rtol=1e-5, atol=1e-6)
+
+
@T.prim_func
def matmul_bias_relu_multiple_epilogue_before(
    A: T.Buffer((16, 16), "float32"),
    B: T.Buffer((16, 16), "float32"),
    C: T.Buffer((16, 16), "float32"),
    D: T.Buffer((16, 16), "float32"),
    E: T.Buffer((16, 16), "float32"),
) -> None:
    """Original function with separate reduction and multiple epilogue blocks (one with ReLU, one without)."""
    # temp has two consumers below; only the "bias_relu" one is fused.
    temp = T.alloc_buffer((16, 16), dtype="float32")
    for i, j, k in T.grid(16, 16, 16):
        with T.block("matmul"):
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            with T.init():
                temp[vi, vj] = T.float32(0)
            temp[vi, vj] = temp[vi, vj] + A[vi, vk] * B[vj, vk]

    # First epilogue consumer: bias add + ReLU (the fusion target).
    for i, j in T.grid(16, 16):
        with T.block("bias_relu"):
            vi, vj = T.axis.remap("SS", [i, j])
            D[vi, vj] = T.max(temp[vi, vj] + C[vi, vj], T.float32(0))

    # Second epilogue consumer: plain bias add, keeps temp alive after fusion.
    for i, j in T.grid(16, 16):
        with T.block("bias"):
            vi, vj = T.axis.remap("SS", [i, j])
            E[vi, vj] = temp[vi, vj] + C[vi, vj]
+
+
@T.prim_func
def matmul_bias_relu_multiple_epilogue_expected(
    A: T.Buffer((16, 16), "float32"),
    B: T.Buffer((16, 16), "float32"),
    C: T.Buffer((16, 16), "float32"),
    D: T.Buffer((16, 16), "float32"),
    E: T.Buffer((16, 16), "float32"),
) -> None:
    """Expected function after fusion (Bias + ReLU) with multiple epilogue blocks."""
    # temp must be preserved: the unfused "bias" block below still reads it.
    temp = T.alloc_buffer((16, 16), dtype="float32")
    for i, j, k in T.grid(16, 16, 16):
        with T.block("matmul"):
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            T.reads(C[vi, vj], A[vi, vk], B[vj, vk])
            T.writes(D[vi, vj])
            with T.init():
                # Fused init: ReLU applied to the bias.
                D[vi, vj] = T.max(C[vi, vj], T.float32(0))
            D[vi, vj] = T.max(D[vi, vj] + A[vi, vk] * B[vj, vk], T.float32(0))
    for i, j in T.grid(16, 16):
        with T.block("bias"):
            vi, vj = T.axis.remap("SS", [i, j])
            T.reads(temp[vi, vj], C[vi, vj])
            T.writes(E[vi, vj])
            E[vi, vj] = temp[vi, vj] + C[vi, vj]
+
+
def test_matmul_bias_relu_multiple_epilogue():
    """Test fusion with multiple epilogue blocks - one with ReLU, one without.

    Mirrors test_fuse_reduction_epilogue_multiple_epilogue: the "bias_relu"
    consumer is fused into the reduction, while the remaining "bias" block
    still reads temp, so the temp buffer must survive the transformation.
    """
    schedule = tir.Schedule(matmul_bias_relu_multiple_epilogue_before, debug_mask="all")
    schedule.fuse_reduction_epilogue("matmul", "bias_relu")
    assert_structural_equal_ignore_global_symbol(
        schedule.mod["main"], matmul_bias_relu_multiple_epilogue_expected
    )
    verify_trace_roundtrip(sch=schedule, mod=matmul_bias_relu_multiple_epilogue_before)

    # The fused module must still be compilable end-to-end.
    compiled = tvm.compile(schedule.mod["main"], target="llvm")
    assert compiled is not None
+
+
# Allow invoking this test file directly as a standalone script.
if __name__ == "__main__":
    tvm.testing.main()


Reply via email to