This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 70c157d6ca [Analyzer] Enhance ConstIntBoundAnalyzer and IntervalSet with modular set analysis (#18330)
70c157d6ca is described below

commit 70c157d6cad0b76a9254fc644e6aefe043570b18
Author: Lei Wang <[email protected]>
AuthorDate: Sat Oct 18 20:06:56 2025 +0800

    [Analyzer] Enhance ConstIntBoundAnalyzer and IntervalSet with modular set analysis (#18330)
    
    * Enhance ConstIntBoundAnalyzer and IntervalSet with modular set analysis
    
    - Added modular set analysis to ConstIntBoundAnalyzer to get tighter modulo bounds when the divisor's bound is a single constant (b.min_value == b.max_value).
    - Introduced a ComputeGCD function to calculate the GCD of two integers (later replaced by the existing ZeroAwareGCD helper; see below).
    - Updated the Combine functions in IntervalSet to take the operation node instead of a DataType, so they can read the dtype and the operands (e.g. op->a for modular set analysis).
    - Added tests for modular set bounds in both const integer bounds and interval sets.
    
    * replace gcd compute with ZeroAwareGCD
    
    * document the op node parameter
    
    * replace ComputeGCD with ZeroAwareGCD
    
    * add example
    
    * test fix
    
    * test fix
    
    * lint fix
---
 src/arith/const_int_bound.cc                       | 59 ++++++++++++++-
 src/arith/int_set.cc                               | 76 ++++++++++++++-----
 tests/python/arith/test_arith_const_int_bound.py   | 12 +++
 tests/python/arith/test_arith_intset.py            | 10 +++
 ...schedule_feature_extractor_per_store_feature.py | 88 +++++++++++-----------
 tests/python/te/test_te_create_primfunc.py         |  4 +-
 6 files changed, 182 insertions(+), 67 deletions(-)

diff --git a/src/arith/const_int_bound.cc b/src/arith/const_int_bound.cc
index b8e5db483f..7e1d8fb3fb 100644
--- a/src/arith/const_int_bound.cc
+++ b/src/arith/const_int_bound.cc
@@ -102,6 +102,7 @@ struct ConstIntBoundAnalyzer::Entry {
 class ConstIntBoundAnalyzer::Impl
     : public ExprFunctor<ConstIntBoundAnalyzer::Entry(const PrimExpr&)> {
  public:
+  explicit Impl(Analyzer* parent) : parent_(parent) {}
   /*! \brief additional bound info about expr in bound */
   struct BoundInfo {
     /*! \brief The expr */
@@ -278,6 +279,33 @@ class ConstIntBoundAnalyzer::Impl
 
     if (b.min_value > 0) {
       int64_t b_max_cap = InfAwareAdd(b.max_value, -1);
+
+      // Try to get tighter bounds using modular set information
+      if (parent_ && b.min_value == b.max_value) {
+        ModularSet mod_a = parent_->modular_set(op->a);
+        int64_t modulus = b.min_value;
+        int64_t gcd_coeff_mod = ZeroAwareGCD(mod_a->coeff, modulus);
+
+        // If gcd_coeff_mod > 1, we can get tighter bounds
+        // The result will be of the form gcd_coeff_mod * k + (base % modulus)
+        // where k ranges to cover [0, modulus - gcd_coeff_mod]
+        //
+        // Example: expr = (bx * 2048 + tx * 16) % 7168
+        //          where bx in [0, 3584), tx in [0, 128)
+        //          ModularSet(expr) = 16*k (coeff=16, base=0)
+        //          GCD(16, 7168) = 16
+        //          Result can only be {0, 16, 32, ..., 7152}
+        //          Without this optimization: bound = [0, 7167]
+        //          With this optimization: bound = [0, 7152]
+        if (gcd_coeff_mod > 1) {
+          int64_t base_mod = mod_a->base % modulus;
+          if (base_mod < 0) base_mod += modulus;
+          int64_t tight_max = modulus - gcd_coeff_mod + base_mod;
+          if (tight_max >= modulus) tight_max -= modulus;
+          return MakeBound(base_mod, tight_max);
+        }
+      }
+
       if (a.min_value >= 0) {
         // 0 <= [a_min, a_max] < b_min
         if (a.max_value < b.min_value) return a;
@@ -324,6 +352,32 @@ class ConstIntBoundAnalyzer::Impl
 
     if (b.min_value > 0) {
       int64_t b_max_cap = InfAwareAdd(b.max_value, -1);
+      // Try to get tighter bounds using modular set information
+      if (parent_ && b.min_value == b.max_value) {
+        ModularSet mod_a = parent_->modular_set(op->a);
+        int64_t modulus = b.min_value;
+        int64_t gcd_coeff_mod = ZeroAwareGCD(mod_a->coeff, modulus);
+
+        // If gcd_coeff_mod > 1, we can get tighter bounds
+        // The result will be of the form gcd_coeff_mod * k + (base % modulus)
+        // where k ranges to cover [0, modulus - gcd_coeff_mod]
+        //
+        // Example: expr = (bx * 2048 + tx * 16) % 7168
+        //          where bx in [0, 3584), tx in [0, 128)
+        //          ModularSet(expr) = 16*k (coeff=16, base=0)
+        //          GCD(16, 7168) = 16
+        //          Result can only be {0, 16, 32, ..., 7152}
+        //          Without this optimization: bound = [0, 7167]
+        //          With this optimization: bound = [0, 7152]
+        if (gcd_coeff_mod > 1) {
+          int64_t base_mod = mod_a->base % modulus;
+          if (base_mod < 0) base_mod += modulus;
+          int64_t tight_max = modulus - gcd_coeff_mod + base_mod;
+          if (tight_max >= modulus) tight_max -= modulus;
+          return MakeBound(base_mod, tight_max);
+        }
+      }
+
       if (a.min_value >= 0) {
         // 0 <= [a_min, a_max] < b_min
         if (a.max_value < b.min_value) return a;
@@ -458,6 +512,8 @@ class ConstIntBoundAnalyzer::Impl
 
  private:
   friend class ConstIntBoundAnalyzer;
+  // parent analyzer
+  Analyzer* parent_;
   // internal variable map
   std::unordered_map<Var, Entry> var_map_;
   // additional bound info
@@ -525,6 +581,7 @@ class ConstIntBoundAnalyzer::Impl
     // If the range of b does not have 0, use BinaryOpBoundary.
     return BinaryOpBoundary(a, b, op);
   }
+
   /*!
    * \brief Compute x + y, aware of inf.
    * \param x The left operand.
@@ -805,7 +862,7 @@ std::function<void()> 
ConstIntBoundAnalyzer::EnterConstraint(const PrimExpr& con
   return impl_->EnterConstraint(constraint);
 }
 
-ConstIntBoundAnalyzer::ConstIntBoundAnalyzer(Analyzer* parent) : impl_(new 
Impl()) {}
+ConstIntBoundAnalyzer::ConstIntBoundAnalyzer(Analyzer* parent) : impl_(new 
Impl(parent)) {}
 
 ConstIntBoundAnalyzer::~ConstIntBoundAnalyzer() { delete impl_; }
 
diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc
index aa15284b3e..1433ceb70f 100644
--- a/src/arith/int_set.cc
+++ b/src/arith/int_set.cc
@@ -27,12 +27,14 @@
 #include <tvm/ffi/reflection/registry.h>
 #include <tvm/tir/expr.h>
 #include <tvm/tir/expr_functor.h>
+#include <tvm/tir/op.h>
 
 #include <algorithm>
 #include <unordered_map>
 #include <utility>
 
 #include "constraint_extract.h"
+#include "int_operator.h"
 #include "interval_set.h"
 #include "pattern_match.h"
 
@@ -109,10 +111,15 @@ TVM_DECLARE_LOGICAL_OP(Not);
 
 /*!
  * \brief Combine two interval set under arithmetic operations.
+ * \param analyzer The analyzer for simplification and proving
+ * \param a The first interval set
+ * \param b The second interval set
+ * \param op The operation node, used to extract dtype and other properties
  * \note this can possibly relax the set.
  */
-template <typename Op>
-inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b, 
DataType dtype) {
+template <typename Op, typename OpNode>
+inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b, 
const OpNode* op) {
+  DataType dtype = op->dtype;
   if (a->IsSinglePoint() && b->IsSinglePoint()) {
     PrimExpr expr;
     if (auto res = TryConstFold<Op>(a->min_value, b->min_value)) {
@@ -134,7 +141,7 @@ inline IntervalSet Combine(Analyzer* analyzer, IntervalSet 
a, IntervalSet b, Dat
 
 template <>
 inline IntervalSet Combine<tir::Add>(Analyzer* analyer, IntervalSet a, 
IntervalSet b,
-                                     DataType /* dtype */) {
+                                     const tir::AddNode* /* op */) {
   if (a->IsSinglePoint() && b->IsSinglePoint()) {
     return IntervalSet::SinglePoint(a->min_value + b->min_value);
   }
@@ -149,7 +156,7 @@ inline IntervalSet Combine<tir::Add>(Analyzer* analyer, 
IntervalSet a, IntervalS
 
 template <>
 inline IntervalSet Combine<tir::Sub>(Analyzer* analyer, IntervalSet a, 
IntervalSet b,
-                                     DataType /* dtype */) {
+                                     const tir::SubNode* /* op */) {
   if (a->IsSinglePoint() && b->IsSinglePoint()) {
     return IntervalSet::SinglePoint(a->min_value - b->min_value);
   }
@@ -164,7 +171,7 @@ inline IntervalSet Combine<tir::Sub>(Analyzer* analyer, 
IntervalSet a, IntervalS
 
 template <>
 inline IntervalSet Combine<tir::Mul>(Analyzer* analyzer, IntervalSet a, 
IntervalSet b,
-                                     DataType /* dtype */) {
+                                     const tir::MulNode* /* op */) {
   if (a->IsSinglePoint() && b->IsSinglePoint()) {
     return IntervalSet::SinglePoint(a->min_value * b->min_value);
   }
@@ -198,7 +205,7 @@ inline IntervalSet Combine<tir::Mul>(Analyzer* analyzer, 
IntervalSet a, Interval
 
 template <>
 inline IntervalSet Combine<tir::Div>(Analyzer* analyzer, IntervalSet a, 
IntervalSet b,
-                                     DataType /* dtype */) {
+                                     const tir::DivNode* /* op */) {
   if (a->IsSinglePoint() && b->IsSinglePoint()) {
     return IntervalSet::SinglePoint(a->min_value / b->min_value);
   }
@@ -232,7 +239,7 @@ inline IntervalSet Combine<tir::Div>(Analyzer* analyzer, 
IntervalSet a, Interval
 
 template <>
 inline IntervalSet Combine<tir::Mod>(Analyzer* analyzer, IntervalSet a, 
IntervalSet b,
-                                     DataType /* dtype */) {
+                                     const tir::ModNode* op) {
   if (a->IsSinglePoint() && b->IsSinglePoint()) {
     return IntervalSet::SinglePoint(truncmod(a->min_value, b->min_value));
   }
@@ -261,7 +268,7 @@ inline IntervalSet Combine<tir::Mod>(Analyzer* analyzer, 
IntervalSet a, Interval
 
 template <>
 inline IntervalSet Combine<tir::FloorDiv>(Analyzer* analyzer, IntervalSet a, 
IntervalSet b,
-                                          DataType /* dtype */) {
+                                          const tir::FloorDivNode* /* op */) {
   if (a->IsSinglePoint() && b->IsSinglePoint()) {
     return IntervalSet::SinglePoint(floordiv(a->min_value, b->min_value));
   }
@@ -295,7 +302,7 @@ inline IntervalSet Combine<tir::FloorDiv>(Analyzer* 
analyzer, IntervalSet a, Int
 
 template <>
 inline IntervalSet Combine<tir::FloorMod>(Analyzer* analyzer, IntervalSet a, 
IntervalSet b,
-                                          DataType /* dtype */) {
+                                          const tir::FloorModNode* op) {
   if (a->IsSinglePoint() && b->IsSinglePoint()) {
     return IntervalSet::SinglePoint(floormod(a->min_value, b->min_value));
   }
@@ -321,6 +328,29 @@ inline IntervalSet Combine<tir::FloorMod>(Analyzer* 
analyzer, IntervalSet a, Int
           return IntervalSet(tmin, tmax);
         }
       }
+      // Enhanced: Use ModularSet analysis for better bounds
+      if (auto* div_imm = divisor.as<tir::IntImmNode>()) {
+        int64_t div_val = div_imm->value;
+
+        // Analyze the modular properties of the dividend
+        ModularSet dividend_mod = analyzer->modular_set(op->a);
+
+        if (dividend_mod.defined() && dividend_mod->coeff > 0) {
+          // Calculate GCD of dividend coefficient and divisor
+          int64_t gcd = ZeroAwareGCD(dividend_mod->coeff, div_val);
+
+          if (gcd > 1 && div_val % gcd == 0) {
+            // The dividend is a multiple of gcd, and divisor is also a 
multiple of gcd
+            // So the result is also a multiple of gcd, with max value = 
(div_val/gcd - 1) * gcd
+            int64_t max_quotient = (div_val / gcd) - 1;
+            int64_t max_mod_result = max_quotient * gcd + (dividend_mod->base 
% gcd);
+
+            if (max_mod_result >= 0 && max_mod_result < div_val) {
+              return IntervalSet(make_zero(op->dtype), make_const(op->dtype, 
max_mod_result));
+            }
+          }
+        }
+      }
       return IntervalSet(make_zero(divisor.dtype()), divisor - 1);
     } else {
       PrimExpr bound = abs(divisor) - 1;
@@ -333,7 +363,7 @@ inline IntervalSet Combine<tir::FloorMod>(Analyzer* 
analyzer, IntervalSet a, Int
 
 template <>
 inline IntervalSet Combine<tir::Max>(Analyzer* analzyer, IntervalSet a, 
IntervalSet b,
-                                     DataType /* dtype */) {
+                                     const tir::MaxNode* /* op */) {
   if (a->IsSinglePoint() && b->IsSinglePoint()) {
     return IntervalSet::SinglePoint(max(a->min_value, b->min_value));
   }
@@ -344,7 +374,7 @@ inline IntervalSet Combine<tir::Max>(Analyzer* analzyer, 
IntervalSet a, Interval
 
 template <>
 inline IntervalSet Combine<tir::Min>(Analyzer* analzyer, IntervalSet a, 
IntervalSet b,
-                                     DataType /* dtype */) {
+                                     const tir::MinNode* /* op */) {
   if (a->IsSinglePoint() && b->IsSinglePoint()) {
     return IntervalSet::SinglePoint(min(a->min_value, b->min_value));
   }
@@ -475,19 +505,25 @@ class IntervalSetEvaluator : public 
ExprFunctor<IntervalSet(const PrimExpr&)> {
       if (op->lanes->IsInstance<IntImmNode>()) {
         int lanes = static_cast<int>(Downcast<IntImm>(op->lanes)->value);
         if (vstride > 0) {
-          return Combine<Add>(analyzer_, base,
-                              IntervalSet(make_zero(t), make_const(t, vstride 
* (lanes - 1))),
-                              op->dtype);
+          PrimExpr stride_expr = make_const(t, vstride * (lanes - 1));
+          auto add_op = tir::Add(op->base, stride_expr);
+          auto add_node = add_op.as<tir::AddNode>();
+          return Combine<Add>(analyzer_, base, IntervalSet(make_zero(t), 
stride_expr), add_node);
         } else {
-          return Combine<Add>(analyzer_, base,
-                              IntervalSet(make_const(t, vstride * (lanes - 
1)), make_zero(t)),
-                              op->dtype);
+          PrimExpr stride_expr = make_const(t, vstride * (lanes - 1));
+          auto add_op = tir::Add(op->base, stride_expr);
+          auto add_node = add_op.as<tir::AddNode>();
+          return Combine<Add>(analyzer_, base, IntervalSet(stride_expr, 
make_zero(t)), add_node);
         }
       } else { /* Scalable vector */
         if (vstride > 0) {
-          return Combine<Add>(analyzer_, base, IntervalSet(make_zero(t), 
pos_inf()), op->dtype);
+          auto add_op = tir::Add(op->base, make_zero(t));
+          auto add_node = add_op.as<tir::AddNode>();
+          return Combine<Add>(analyzer_, base, IntervalSet(make_zero(t), 
pos_inf()), add_node);
         } else {
-          return Combine<Add>(analyzer_, base, IntervalSet(neg_inf(), 
make_zero(t)), op->dtype);
+          auto add_op = tir::Add(op->base, make_zero(t));
+          auto add_node = add_op.as<tir::AddNode>();
+          return Combine<Add>(analyzer_, base, IntervalSet(neg_inf(), 
make_zero(t)), add_node);
         }
       }
     }
@@ -563,7 +599,7 @@ class IntervalSetEvaluator : public 
ExprFunctor<IntervalSet(const PrimExpr&)> {
     if (MatchPoint(a, op->a) && MatchPoint(b, op->b)) {
       return IntervalSet::SinglePoint(ffi::GetRef<PrimExpr>(op));
     }
-    return Combine<TOp>(analyzer_, a, b, op->dtype);
+    return Combine<TOp>(analyzer_, a, b, op);
   }
 
   // recursive depth
diff --git a/tests/python/arith/test_arith_const_int_bound.py 
b/tests/python/arith/test_arith_const_int_bound.py
index 14bfec2328..8728df7e3f 100644
--- a/tests/python/arith/test_arith_const_int_bound.py
+++ b/tests/python/arith/test_arith_const_int_bound.py
@@ -298,5 +298,17 @@ class TestRampBound(BaseCompare):
     )
 
 
+class TestModularSetBound(BaseCompare):
+    analyzer = tvm.arith.Analyzer()
+    tx = tvm.te.var("tx", dtype="int32")
+    bx = tvm.te.var("bx", dtype="int32")
+
+    expr = (bx * 2048 + tx * 16) % 7168
+
+    test_case = tvm.testing.parameter(
+        TestCase(expr, (0, 7152), {bx: (0, 3584), tx: (0, 128)}),
+    )
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git a/tests/python/arith/test_arith_intset.py 
b/tests/python/arith/test_arith_intset.py
index 18865a73df..04014ca300 100644
--- a/tests/python/arith/test_arith_intset.py
+++ b/tests/python/arith/test_arith_intset.py
@@ -387,5 +387,15 @@ def test_union_lower_bound():
     assert result.max_value.same_as(pos_inf)
 
 
+def test_modular_set():
+    ck = IntSetChecker()
+    x = tvm.te.var("x", dtype="int32")
+    y = tvm.te.var("y", dtype="int32")
+    expr = (x * 2048 + y * 16) % 7168
+    ck.verify(
+        expr, {x: tvm.arith.IntervalSet(0, 128), y: tvm.arith.IntervalSet(0, 
3584)}, (0, 7152)
+    )
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git 
a/tests/python/meta_schedule/test_meta_schedule_feature_extractor_per_store_feature.py
 
b/tests/python/meta_schedule/test_meta_schedule_feature_extractor_per_store_feature.py
index 057cd0e9f7..b901c3ce13 100644
--- 
a/tests/python/meta_schedule/test_meta_schedule_feature_extractor_per_store_feature.py
+++ 
b/tests/python/meta_schedule/test_meta_schedule_feature_extractor_per_store_feature.py
@@ -846,21 +846,21 @@ def test_gpu():
             1.0,
             0.0,
             0.0,
-            25.000000042995662,
-            20.000001375860553,
-            23.00000017198264,
-            14.000088052430122,
+            25.00000004,
+            19.99718086,
+            23.00000017,
+            13.99726771,
             1.0,
             0.0,
             0.0,
-            18.00000550343433,
-            20.00562591970089,
-            2.321928094887362,
-            23.00000017198264,
-            18.00000550343433,
-            21.000000687930438,
-            12.0003521774803,
-            12.0003521774803,
+            18.0000055,
+            20.00000138,
+            2.32192809,
+            23.00000017,
+            17.997185,
+            21.00000069,
+            11.99753235,
+            12.00035218,
         ],
         rtol=1e-5,
         atol=1e-5,
@@ -872,21 +872,21 @@ def test_gpu():
             0.0,
             1.0,
             0.0,
-            25.000000042995662,
-            12.0003521774803,
-            23.00000017198264,
-            9.002815015607053,
+            25.00000004,
+            11.00070427,
+            23.00000017,
+            5.04439412,
             1.0,
             0.0,
             0.0,
-            6.022367813028454,
-            11.98049663618346,
-            8.005624549193879,
-            17.000011006847668,
-            4.087462841250339,
-            15.000044026886828,
-            1.584962500721156,
-            4.087462841250339,
+            6.02236781,
+            11.98049664,
+            8.00562455,
+            17.00001101,
+            3.169925,
+            15.00004403,
+            0.169925,
+            4.08746284,
         ],
         rtol=1e-5,
         atol=1e-5,
@@ -1052,21 +1052,21 @@ def test_gpu():
             1.0,
             0.0,
             0.0,
-            22.00000034396526,
-            20.000001375860553,
-            20.000001375860553,
-            14.000088052430122,
+            22.00000034,
+            19.85798251,
+            20.00000138,
+            13.85807816,
             1.0,
             0.0,
             0.0,
-            15.000044026886828,
-            20.17555076886471,
-            2.321928094887362,
-            20.000001375860553,
-            18.00000550343433,
-            18.00000550343433,
-            12.0003521774803,
-            4.087462841250339,
+            15.00004403,
+            20.04456622,
+            2.32192809,
+            20.00000138,
+            17.85798707,
+            18.0000055,
+            11.8583696,
+            4.08746284,
         ],
         rtol=1e-5,
         atol=1e-5,
@@ -1078,20 +1078,20 @@ def test_gpu():
             0.0,
             1.0,
             0.0,
-            22.00000034396526,
-            9.002815015607053,
-            20.000001375860553,
-            3.169925001442312,
+            22.00000034,
+            7.01122726,
+            20.00000138,
+            4.08746284,
             1.0,
             0.0,
             0.0,
             3.169925001442312,
-            9.61654884377899,
+            4.08746284,
             8.005624549193879,
             14.000088052430122,
-            1.584962500721156,
-            12.0003521774803,
-            0.044394119358453436,
+            0.5849625,
+            12.00035218,
+            0.08746284,
             4.087462841250339,
         ],
         rtol=1e-5,
diff --git a/tests/python/te/test_te_create_primfunc.py 
b/tests/python/te/test_te_create_primfunc.py
index c8a0952802..426272584b 100644
--- a/tests/python/te/test_te_create_primfunc.py
+++ b/tests/python/te/test_te_create_primfunc.py
@@ -852,7 +852,7 @@ def test_adaptive_pooling_window():
                 v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, 
ax2, ax3])
                 T.reads(x[v_ax0, v_ax1, v_ax2 * 16 // 12:v_ax2 * 16 // 12 + 
((v_ax2 % 3 * 4 + 16) // 12 + 1), v_ax3 * 40 // 30:v_ax3 * 40 // 30 + ((v_ax3 % 
3 * 10 + 40) // 30 + 1)])
                 T.writes(adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3])
-                for rv0, rv1 in T.grid(T.Select((v_ax2 * 4 + 4) % 12 == 0, 
(v_ax2 * 16 + 16) // 12, (v_ax2 * 16 + 16) // 12 + 1) - v_ax2 * 16 // 12, 
T.Select((v_ax3 * 10 + 10) % 30 == 0, (v_ax3 * 40 + 40) // 30, (v_ax3 * 40 + 
40) // 30 + 1) - v_ax3 * 40 // 30):
+                for rv0, rv1 in T.grid((v_ax2 % 3 * 4 + 16) // 12 + 1, (v_ax3 
% 3 * 10 + 40) // 30 + 1):
                     with T.block("adaptive_pool_sum"):
                         v_ax0_1 = T.axis.spatial((v_ax0, v_ax0 + 1), v_ax0)
                         v_ax1_1 = T.axis.spatial((v_ax1, v_ax1 + 1), v_ax1)
@@ -870,7 +870,7 @@ def test_adaptive_pooling_window():
                 T.reads(adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3])
                 T.writes(adaptive_pool_avg[v_ax0, v_ax1, v_ax2, v_ax3])
                 T.block_attr({"schedule_rule": 
"meta_schedule.adaptive_pool_avg"})
-                adaptive_pool_avg[v_ax0, v_ax1, v_ax2, v_ax3] = 
adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] / (T.Cast("float32", 
T.Select((v_ax2 * 4 + 4) % 12 == 0, (v_ax2 * 16 + 16) // 12, (v_ax2 * 16 + 16) 
// 12 + 1) - v_ax2 * 16 // 12) * T.Cast("float32", T.Select((v_ax3 * 10 + 10) % 
30 == 0, (v_ax3 * 40 + 40) // 30, (v_ax3 * 40 + 40) // 30 + 1) - v_ax3 * 40 // 
30))
+                adaptive_pool_avg[v_ax0, v_ax1, v_ax2, v_ax3] = 
adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] / (T.Cast("float32", (v_ax2 % 3 * 
4 + 16) // 12 + 1) * T.Cast("float32", (v_ax3 % 3 * 10 + 40) // 30 + 1))
         # fmt: on
 
     def te_workload():

Reply via email to