This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 60ed926  [ARITH] Introduce iterator (quasi)affine map detection. (#6667)
60ed926 is described below

commit 60ed9261058c0f1faa2632cb97319f29b26b9573
Author: Tianqi Chen <tqc...@users.noreply.github.com>
AuthorDate: Wed Oct 14 08:33:15 2020 -0400

    [ARITH] Introduce iterator (quasi)affine map detection. (#6667)
    
    * [ARITH] Introduce iterator (quasi)affine map detection.
    
    The loop transformations (split, fuse) create bijective
    maps from a collection of source iterators to target iterators.
    
    DetectIterMap is a function that detects such bijective mappings
    from the lowered index expression.
    
    We choose the term quasi affine to be consistent with the
    terminology used in polyhedral compilation.
    DetectIterMap can handle symbolic integers (in split/fuse) to some extent.
    
    The utility can be useful in detecting loop transformation
    patterns and data layout change patterns in TIR.
    
    * Update per feedback
---
 include/tvm/arith/iter_affine_map.h                | 277 ++++++++
 python/tvm/arith/__init__.py                       |   2 +
 python/tvm/arith/iter_affine_map.py                | 108 ++++
 src/arith/iter_affine_map.cc                       | 717 +++++++++++++++++++++
 src/arith/rewrite_simplify.cc                      |   3 +
 src/node/structural_hash.cc                        |  18 +-
 src/support/util.h                                 |  10 +
 .../python/unittest/test_arith_iter_affine_map.py  | 176 +++++
 8 files changed, 1298 insertions(+), 13 deletions(-)

diff --git a/include/tvm/arith/iter_affine_map.h b/include/tvm/arith/iter_affine_map.h
new file mode 100644
index 0000000..00f8cf6
--- /dev/null
+++ b/include/tvm/arith/iter_affine_map.h
@@ -0,0 +1,277 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/arith/iter_affine_map.h
+ * \brief Iterator quasi-affine mapping patterns.
+ *
+ *  This file defines a collection of mapping patterns
+ *  that map a collection of independent iterators to another
+ *  collection of independent iterators.
+ *
+ *  There are two main kinds of mapping patterns:
+ *
+ *  - Fuse: fuse a collection of iterators into a single one
+ *
+ *    domain(x0) = [0, 4), domain(x1) = [0, 3), domain(x2) = [0, 2)
+ *    fuse(x0, x1, x2): y = x2 * 12 + x1 * 4 + x0
+ *    domain(y) = [0, 24)
+ *
+ *  - Split: split an iterator into multiple ones
+ *
+ *    domain(x) = [0, 24)
+ *    split(x, 3, 12): [y0, y1, y2] = [x % 3, (x % 12) / 3, x / 12]
+ *    domain(y0) = [0, 3), domain(y1) = [0, 4), domain(y2) = [0, 2)
+ *
+ *  We use the name "(quasi)affine" to be consistent with
+ *  the terminology used in the polyhedral compilation.
+ *  Notably, fuse is an affine transformation,
+ *  while split corresponds to additional floordiv/mod operations
+ *  that can appear in quasi-affine transformations.
+ */
+#ifndef TVM_ARITH_ITER_AFFINE_MAP_H_
+#define TVM_ARITH_ITER_AFFINE_MAP_H_
+
+#include <tvm/arith/analyzer.h>
+#include <tvm/ir/expr.h>
+
+namespace tvm {
+namespace arith {
+
+/*!
+ * \brief Common base node for every iterator-map expression.
+ *
+ *  Nodes deriving from IterMapExprNode hold the (intermediate and final)
+ *  results produced by DetectIterMap. They are analysis artifacts only
+ *  and must never appear in a legal TIR PrimFunc.
+ */
+class IterMapExprNode : public PrimExprNode {
+ public:
+  /*! \brief The base node has no attributes of its own to visit. */
+  void VisitAttrs(tvm::AttrVisitor* v) {}
+
+  static constexpr const char* _type_key = "arith.IterMapExpr";
+  /*! \brief Type-index slots reserved for the derived iter map node types. */
+  static constexpr const uint32_t _type_child_slots = 3;
+  TVM_DECLARE_BASE_OBJECT_INFO(IterMapExprNode, PrimExprNode);
+};
+
+/*!
+ * \brief Reference type wrapping an IterMapExprNode.
+ *
+ *  Serves as the common handle type for IterSplitExpr and IterSumExpr.
+ * \sa IterMapExprNode
+ */
+class IterMapExpr : public PrimExpr {
+ public:
+  TVM_DEFINE_OBJECT_REF_METHODS(IterMapExpr, PrimExpr, IterMapExprNode);
+};
+
+/*!
+ * \brief Marks a source expression as an iterator ranging over [0, extent).
+ *
+ *  Wrapping a source in an IterMark lets later analysis treat it as one
+ *  iterator, regardless of how the source expression itself is built.
+ */
+class IterMarkNode : public Object {
+ public:
+  /*!
+   * \brief The expression being iterated: an input Var (possibly shifted by
+   *  its range minimum) or a fused IterSumExpr.
+   */
+  PrimExpr source;
+  /*! \brief Number of values the iterator takes, i.e. the bound of [0, extent). */
+  PrimExpr extent;
+
+  // overrides
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("source", &source);
+    v->Visit("extent", &extent);
+  }
+
+  // Marks are compared/hashed as graph nodes because several splits
+  // may refer to the very same mark.
+  bool SEqualReduce(const IterMarkNode* other, SEqualReducer equal) const {
+    equal->MarkGraphNode();
+    if (!equal(source, other->source)) return false;
+    return equal(extent, other->extent);
+  }
+
+  void SHashReduce(SHashReducer hash_reduce) const {
+    hash_reduce->MarkGraphNode();
+    hash_reduce(source);
+    hash_reduce(extent);
+  }
+
+  static constexpr const bool _type_has_method_sequal_reduce = true;
+  static constexpr const bool _type_has_method_shash_reduce = true;
+  static constexpr const char* _type_key = "arith.IterMark";
+  TVM_DECLARE_FINAL_OBJECT_INFO(IterMarkNode, Object);
+};
+
+/*!
+ * \brief Managed reference to IterMarkNode.
+ * \sa IterMarkNode
+ */
+class IterMark : public ObjectRef {
+ public:
+  /*!
+   * \brief constructor.
+   * \param source The source expression.
+   * \param extent The extent of the iterator.
+   */
+  TVM_DLL IterMark(PrimExpr source, PrimExpr extent);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(IterMark, ObjectRef, IterMarkNode);
+};
+
+/*!
+ * \brief Split of an iterator.
+ *
+ *  result = floormod(floordiv(source, lower_factor), extent) * scale
+ */
+class IterSplitExprNode : public IterMapExprNode {
+ public:
+  /*! \brief The source marked iterator. */
+  IterMark source;
+  /*! \brief Divisor applied to the source before taking the remainder. */
+  PrimExpr lower_factor;
+  /*! \brief Modulus of the split; the split value ranges over [0, extent). */
+  PrimExpr extent;
+  /*! \brief Multiplier applied to the split value. */
+  PrimExpr scale;
+
+  // overrides
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("source", &source);
+    v->Visit("lower_factor", &lower_factor);
+    v->Visit("extent", &extent);
+    v->Visit("scale", &scale);
+  }
+
+  bool SEqualReduce(const IterSplitExprNode* other, SEqualReducer equal) const {
+    return equal(source, other->source) && equal(lower_factor, other->lower_factor) &&
+           equal(extent, other->extent) && equal(scale, other->scale);
+  }
+
+  void SHashReduce(SHashReducer hash_reduce) const {
+    hash_reduce(source);
+    hash_reduce(lower_factor);
+    hash_reduce(extent);
+    hash_reduce(scale);
+  }
+
+  static constexpr const char* _type_key = "arith.IterSplitExpr";
+  TVM_DECLARE_FINAL_OBJECT_INFO(IterSplitExprNode, IterMapExprNode);
+};
+
+/*!
+ * \brief Managed reference to IterSplitExprNode.
+ * \sa IterSplitExprNode
+ */
+class IterSplitExpr : public IterMapExpr {
+ public:
+  /*!
+   * \brief Construct the identity split covering the whole mark:
+   *  lower_factor = 1, extent = source->extent, scale = 1.
+   * \param source The source expression.
+   */
+  TVM_DLL explicit IterSplitExpr(IterMark source);
+  /*!
+   * \brief constructor
+   * \param source The source expression.
+   * \param lower_factor The lower factor to split the source.
+   * \param extent The extent of the split.
+   * \param scale The additional scaling factor.
+   */
+  TVM_DLL explicit IterSplitExpr(IterMark source, PrimExpr lower_factor, PrimExpr extent,
+                                 PrimExpr scale);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(IterSplitExpr, IterMapExpr, IterSplitExprNode);
+  TVM_DEFINE_OBJECT_REF_COW_METHOD(IterSplitExprNode);
+};
+
+/*!
+ * \brief Fusion of several iterators expressed as a scaled sum.
+ *
+ *  result = sum(args) + base
+ */
+class IterSumExprNode : public IterMapExprNode {
+ public:
+  /*! \brief The summed splits; each split carries its own scale. */
+  Array<IterSplitExpr> args;
+  /*! \brief The (possibly symbolic) offset added to the sum. */
+  PrimExpr base;
+
+  // overrides
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("args", &args);
+    v->Visit("base", &base);
+  }
+
+  bool SEqualReduce(const IterSumExprNode* other, SEqualReducer equal) const {
+    if (!equal(args, other->args)) return false;
+    return equal(base, other->base);
+  }
+
+  void SHashReduce(SHashReducer hash_reduce) const {
+    hash_reduce(args);
+    hash_reduce(base);
+  }
+
+  static constexpr const char* _type_key = "arith.IterSumExpr";
+  TVM_DECLARE_FINAL_OBJECT_INFO(IterSumExprNode, IterMapExprNode);
+};
+
+/*!
+ * \brief Reference type wrapping an IterSumExprNode.
+ * \sa IterSumExprNode
+ */
+class IterSumExpr : public IterMapExpr {
+ public:
+  /*!
+   * \brief Build a sum node.
+   * \param args The splits to add together.
+   * \param base The offset added to the sum.
+   */
+  TVM_DLL IterSumExpr(Array<IterSplitExpr> args, PrimExpr base);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(IterSumExpr, IterMapExpr, IterSumExprNode);
+  TVM_DEFINE_OBJECT_REF_COW_METHOD(IterSumExprNode);
+};
+
+/*!
+ * \brief Detect if indices can be written as
+ *
+ *  [y_0 + c_0, y_1 + c_1, ..., y_n + c_n]
+ *
+ *  Here y = some-quasi-affine-iter-map(input_iters)
+ *  and c are symbolic constants.
+ *
+ *  We also require that y_i and y_j be independent for i != j,
+ *  and that together they use every iterator in input_iters.
+ *
+ *  For returned value rv, the following is always true:
+ *  - rv[i]->args.size() <= 1: at most one iterator per element.
+ *
+ * \param indices The indices to detect pattern for.
+ * \param input_iters Map from variable to iterator's range.
+ * \param analyzer Analyzer used to get context information.
+ *
+ * \return The detected pattern if a match exists,
+ *         otherwise return an empty array.
+ */
+Array<IterSumExpr> DetectIterMap(const Array<PrimExpr>& indices, const Map<Var, Range>& input_iters,
+                                 arith::Analyzer* analyzer);
+
+}  // namespace arith
+}  // namespace tvm
+#endif  // TVM_ARITH_ITER_AFFINE_MAP_H_
diff --git a/python/tvm/arith/__init__.py b/python/tvm/arith/__init__.py
index e5af529..77ec869 100644
--- a/python/tvm/arith/__init__.py
+++ b/python/tvm/arith/__init__.py
@@ -21,3 +21,5 @@ from .analyzer import ModularSet, ConstIntBound, Analyzer
 from .bound import deduce_bound
 from .pattern import detect_linear_equation, detect_clip_bound
 from .int_solver import solve_linear_equations, solve_linear_inequalities
+from .iter_affine_map import IterMapExpr, IterMark, IterSplitExpr, IterSumExpr
+from .iter_affine_map import detect_iter_map
diff --git a/python/tvm/arith/iter_affine_map.py b/python/tvm/arith/iter_affine_map.py
new file mode 100644
index 0000000..123d9b8
--- /dev/null
+++ b/python/tvm/arith/iter_affine_map.py
@@ -0,0 +1,108 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+""" Iterator (quasi)affine mapping patterns."""
+import tvm._ffi
+from tvm.runtime import Object
+from tvm.ir import PrimExpr
+from . import _ffi_api
+
+
+class IterMapExpr(PrimExpr):
+    """Base class of all IterMap expressions."""
+
+
+@tvm._ffi.register_object("arith.IterMark")
+class IterMark(Object):
+    """Mark the source as an iterator in [0, extent).
+
+    Parameters
+    ----------
+    source : PrimExpr.
+        The source expression.
+
+    extent : PrimExpr
+        The extent of the iterator.
+    """
+
+    def __init__(self, source, extent):
+        self.__init_handle_by_constructor__(_ffi_api.IterMark, source, extent)
+
+
+@tvm._ffi.register_object("arith.IterSplitExpr")
+class IterSplitExpr(IterMapExpr):
+    """Split of an iterator.
+
+    result = floormod(floordiv(source, lower_factor), extent) * scale
+
+    Parameters
+    ----------
+    source : IterMark
+        The source marked iterator.
+
+    lower_factor : PrimExpr
+        The lower factor to split the domain.
+
+    extent : PrimExpr
+        The extent of the split.
+
+    scale : PrimExpr
+        Additional scale to the split.
+    """
+
+    def __init__(self, source, lower_factor, extent, scale):
+        self.__init_handle_by_constructor__(
+            _ffi_api.IterSplitExpr, source, lower_factor, extent, scale
+        )
+
+
+@tvm._ffi.register_object("arith.IterSumExpr")
+class IterSumExpr(IterMapExpr):
+    """Fuse multiple iterators by summing them with scaling.
+
+    result = sum(args) + base
+
+    Parameters
+    ----------
+    args : List[IterSplitExpr]
+        The input to the sum expression.
+
+    base : PrimExpr
+        The base offset.
+    """
+
+    def __init__(self, args, base):
+        self.__init_handle_by_constructor__(_ffi_api.IterSumExpr, args, base)
+
+
+def detect_iter_map(indices, input_iters):
+    """Detect if indices can be written mapped iters from input_iters.
+
+    Parameters
+    ----------
+    indices : List[PrimExpr]
+        The input indices.
+
+    input_iters : Map[Var, Range]
+        The domain of each input iterators.
+
+    Returns
+    -------
+    results : List[IterSumExpr]
+        The iter map matching result.
+        Empty array if no match can be found.
+    """
+    return _ffi_api.DetectIterMap(indices, input_iters)
diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc
new file mode 100644
index 0000000..7afa75a
--- /dev/null
+++ b/src/arith/iter_affine_map.cc
@@ -0,0 +1,717 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/arith/iter_affine_map.cc
+ */
+#include <tvm/arith/analyzer.h>
+#include <tvm/arith/iter_affine_map.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/expr_functor.h>
+#include <tvm/tir/op.h>
+
+#include "../support/util.h"
+#include "const_fold.h"
+
+namespace tvm {
+namespace arith {
+
+using namespace tir;
+
+// Build an IterMark node that owns the given source and extent.
+IterMark::IterMark(PrimExpr source, PrimExpr extent) {
+  auto node = make_object<IterMarkNode>();
+  node->extent = std::move(extent);
+  node->source = std::move(source);
+  data_ = std::move(node);
+}
+
+// FFI constructor exposed to Python as arith.IterMark.
+TVM_REGISTER_GLOBAL("arith.IterMark").set_body_typed([](PrimExpr source, PrimExpr extent) {
+  return IterMark(source, extent);
+});
+
+TVM_REGISTER_NODE_TYPE(IterMarkNode);
+
+// Printer: IterMark(<source>, extent=<extent>)
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+    .set_dispatch<IterMarkNode>([](const ObjectRef& node, ReprPrinter* p) {
+      auto* op = static_cast<const IterMarkNode*>(node.get());
+      p->stream << "IterMark(" << op->source << ", extent=" << op->extent << ")";
+    });
+
+// Identity split of a whole mark: lower_factor = 1, extent = mark extent, scale = 1.
+IterSplitExpr::IterSplitExpr(IterMark source) {
+  auto node = make_object<IterSplitExprNode>();
+  const auto dtype = source->source->dtype;
+  auto unit = make_const(dtype, 1);
+  node->dtype = dtype;
+  node->extent = source->extent;
+  node->source = std::move(source);
+  node->lower_factor = unit;
+  node->scale = unit;
+  data_ = std::move(node);
+}
+
+// Fully specified split; the node dtype follows the dtype of the mark's source.
+IterSplitExpr::IterSplitExpr(IterMark source, PrimExpr lower_factor, PrimExpr extent,
+                             PrimExpr scale) {
+  auto n = make_object<IterSplitExprNode>();
+  n->dtype = source->source->dtype;
+  n->source = std::move(source);
+  n->lower_factor = std::move(lower_factor);
+  n->extent = std::move(extent);
+  n->scale = std::move(scale);
+  data_ = std::move(n);
+}
+
+// FFI constructor exposed to Python as arith.IterSplitExpr.
+TVM_REGISTER_GLOBAL("arith.IterSplitExpr")
+    .set_body_typed([](IterMark source, PrimExpr lower_factor, PrimExpr extent, PrimExpr scale) {
+      return IterSplitExpr(source, lower_factor, extent, scale);
+    });
+
+TVM_REGISTER_NODE_TYPE(IterSplitExprNode);
+
+// Printer: IterSplit(<source>, lower_factor=..., extent=..., scale=...)
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+    .set_dispatch<IterSplitExprNode>([](const ObjectRef& node, ReprPrinter* p) {
+      auto* op = static_cast<const IterSplitExprNode*>(node.get());
+      p->stream << "IterSplit(" << op->source << ", lower_factor=" << op->lower_factor
+                << ", extent=" << op->extent << ", scale=" << op->scale << ")";
+    });
+
+// Build a sum node; its dtype is taken from the base offset.
+IterSumExpr::IterSumExpr(Array<IterSplitExpr> args, PrimExpr base) {
+  auto node = make_object<IterSumExprNode>();
+  node->dtype = base->dtype;
+  node->base = std::move(base);
+  node->args = std::move(args);
+  data_ = std::move(node);
+}
+
+// FFI constructor exposed to Python as arith.IterSumExpr.
+TVM_REGISTER_GLOBAL("arith.IterSumExpr")
+    .set_body_typed([](Array<IterSplitExpr> args, PrimExpr base) {
+      return IterSumExpr(std::move(args), std::move(base));
+    });
+
+TVM_REGISTER_NODE_TYPE(IterSumExprNode);
+
+// Printer: IterSum(<args>, <base>)
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+    .set_dispatch<IterSumExprNode>([](const ObjectRef& node, ReprPrinter* p) {
+      const auto* sum = static_cast<const IterSumExprNode*>(node.get());
+      p->stream << "IterSum(" << sum->args << ", " << sum->base << ")";
+    });
+
+/*!
+ * \brief Collector that collects
+ *  the outgoing split reference of each IterMark.
+ *
+ *  These out-going splits can then be used to
+ *  check if the iterators are independent.
+ */
+class IterMarkSplitCollector {
+ public:
+  // mark all IterMarks that are visited.
+  std::unordered_set<IterMark, ObjectPtrHash, ObjectPtrEqual> visited_;
+  // each iter mark to its outgoing splits that are referenced.
+  std::unordered_map<IterMark, std::vector<IterSplitExpr>, ObjectPtrHash, ObjectPtrEqual>
+      mark2splits_;
+  /*!
+   * \brief Collect all mark2splits recursively from indices.
+   * \param indices The iterator of interest.
+   */
+  void Collect(const Array<IterSumExpr>& indices) {
+    for (const IterSumExpr& sum_expr : indices) {
+      for (const IterSplitExpr& split : sum_expr->args) {
+        this->CollectInternal(split->source);
+        mark2splits_[split->source].push_back(split);
+      }
+    }
+  }
+
+  /*!
+   * \brief Visit a mark once; when its source is a fused IterSumExpr,
+   *  recursively record the splits that the fused sum consumes.
+   * \param mark The mark to visit.
+   */
+  void CollectInternal(const IterMark& mark) {
+    if (visited_.count(mark)) return;
+    visited_.insert(mark);
+    if (auto* op = mark->source.as<IterSumExprNode>()) {
+      for (const IterSplitExpr& split : op->args) {
+        this->CollectInternal(split->source);
+        mark2splits_[split->source].push_back(split);
+      }
+    }
+  }
+};
+
+/*!
+ * \brief Rewriter that turns index expressions into IterMapExpr form
+ *  (IterSplitExpr / IterSumExpr) when possible.
+ *
+ *  Input iterators are registered in the constructor; Add, Sub, Mul,
+ *  FloorDiv and FloorMod over them are rewritten into iter map expressions.
+ *  Whenever an iter map expression reaches a parent that is not handled
+ *  here, or a sum cannot be fused into a single iterator, the case is
+ *  counted in unresolved_count().
+ */
+class IterMapRewriter : public ExprMutator {
+ public:
+  using Parent = ExprMutator;
+
+  /*!
+   * \brief Constructor.
+   * \param analyzer Analyzer used to prove equality and divisibility.
+   * \param input_iters Map from each input variable to its iteration range.
+   */
+  explicit IterMapRewriter(Analyzer* analyzer, const Map<Var, Range>& input_iters)
+      : analyzer_(analyzer) {
+    for (auto kv : input_iters) {
+      const auto& vrng = kv.second;
+      if (is_zero(vrng->min)) {
+        IterMark mark(kv.first, vrng->extent);
+        var_map_[kv.first] = IterSplitExpr(mark);
+        input_marks_.push_back(mark);
+      } else {
+        // Mark the zero-based iterator (var - min) and keep min as the base
+        // offset, so that var is rewritten to IterSplit(mark) + min.
+        IterMark mark(kv.first - vrng->min, vrng->extent);
+        auto sum_expr = ToIterSumExpr(IterSplitExpr(mark));
+        sum_expr.CopyOnWrite()->base = vrng->min;
+        var_map_[kv.first] = sum_expr;
+        input_marks_.push_back(mark);
+      }
+    }
+  }
+
+  /*! \return The number of unresolved cases encountered so far. */
+  size_t unresolved_count() const { return unresolved_count_; }
+
+  /*!
+   * \brief Rewrite an index expression into an IterSumExpr.
+   *  If the resulting sum cannot be fused into a single arg,
+   *  unresolved_count() is incremented.
+   * \param expr The index expression.
+   * \return The rewritten expression.
+   */
+  IterSumExpr Rewrite(PrimExpr expr) {
+    return NormalizeToIterWithOffset(ToIterSumExpr(DirectMutate(expr)));
+  }
+
+  bool CheckBijective(const Array<IterSumExpr>& indices) {
+    // This function checks two conditions:
+    // - C0: Each iter mark should be fully covered by non-overlapping splits.
+    // - C1: All of the input iterators are used.
+    //
+    // Example: given x in [0, 8) y in [0, 6)
+    // - indices = [x, x+1, y] won't pass because x and x+1 contribute
+    //   two splits that overlap with each other.
+    // - indices = [x / 4, x % 4, y] will pass because x / 4 and x % 4
+    //   contribute two non-overlapping splits that cover x.
+    // - indices = [x / 4, x % 4] won't pass because y is not used.
+    //
+    IterMarkSplitCollector collector;
+    // We can check that for each iter mark:
+    // All the splits that refer to the iter mark cover its extent.
+    // The splits do not overlap with each other.
+    collector.Collect(indices);
+    for (const IterMark& mark : collector.visited_) {
+      if (TryNormalizeSplits(mark, collector.mark2splits_[mark]).size() == 0) return false;
+    }
+    // all input marks must be visited
+    for (const IterMark& mark : input_marks_) {
+      if (collector.visited_.count(mark) == 0) return false;
+    }
+    return true;
+  }
+
+  // override the original mutate function.
+  // Handled nodes visit their operands through DirectMutate, so an IterMapExpr
+  // seen here surfaces under an unhandled parent and is counted as unresolved.
+  PrimExpr VisitExpr(const PrimExpr& input_expr) final {
+    auto expr = ExprMutator::VisitExpr(input_expr);
+    if (expr->IsInstance<IterMapExprNode>()) {
+      ++unresolved_count_;
+    }
+    return expr;
+  }
+
+  // Normal mutation without normalization.
+  PrimExpr DirectMutate(PrimExpr expr) { return ExprMutator::VisitExpr(expr); }
+
+  PrimExpr VisitExpr_(const VarNode* op) final;
+  PrimExpr VisitExpr_(const AddNode* op) final;
+  PrimExpr VisitExpr_(const SubNode* op) final;
+  PrimExpr VisitExpr_(const MulNode* op) final;
+  PrimExpr VisitExpr_(const FloorDivNode* op) final;
+  PrimExpr VisitExpr_(const FloorModNode* op) final;
+
+ private:
+  // Hash for de-duplicating sums in sum_fuse_map_.
+  // For now only the sources of the args are hashed; IterSumEqual does the full check.
+  struct IterSumHash {
+    size_t operator()(const IterSumExpr& value) const {
+      size_t hash = value->args.size();
+      for (size_t i = 0; i < value->args.size(); ++i) {
+        hash = support::HashCombine(hash,
+                                    std::hash<const Object*>()(value->args[i]->source.get()));
+      }
+      return hash;
+    }
+  };
+
+  // Equality for sum_fuse_map_: same base, and args pairwise referring to the
+  // same mark with equal lower_factor, scale and extent.
+  struct IterSumEqual {
+    bool operator()(const IterSumExpr& lhs, const IterSumExpr& rhs) const {
+      tir::ExprDeepEqual equal;
+      if (lhs->args.size() != rhs->args.size()) return false;
+      if (!equal(lhs->base, rhs->base)) return false;
+      for (size_t i = 0; i < lhs->args.size(); ++i) {
+        auto lvalue = lhs->args[i];
+        auto rvalue = rhs->args[i];
+        if (!lvalue->source.same_as(rvalue->source)) return false;
+        if (!equal(lvalue->lower_factor, rvalue->lower_factor)) return false;
+        if (!equal(lvalue->scale, rvalue->scale)) return false;
+        if (!equal(lvalue->extent, rvalue->extent)) return false;
+      }
+      return true;
+    }
+  };
+
+  // Internal analyzer
+  Analyzer* analyzer_;
+  // Counter to keep track of unresolved cases.
+  size_t unresolved_count_{0};
+  // The var map
+  std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> var_map_;
+  // input iter marks
+  std::vector<IterMark> input_marks_;
+  // The canonical map for sum
+  std::unordered_map<IterSumExpr, IterSplitExpr, IterSumHash, IterSumEqual> sum_fuse_map_;
+
+  /*!
+   * \brief Verify that splits fully cover mark in a non-overlapping fashion.
+   *        If verification passes, return splits from outermost to innermost order.
+   *        If not, return an empty array.
+   * \param mark The iterator of interest.
+   * \param splits The splits to be verified.
+   * \return The normalized splits.
+   */
+  Array<IterSplitExpr> TryNormalizeSplits(const IterMark& mark,
+                                          const std::vector<IterSplitExpr>& splits) {
+    std::vector<bool> used(splits.size(), false);
+    std::vector<IterSplitExpr> iters;
+    PrimExpr expected_lower_factor = make_const(mark->source->dtype, 1);
+
+    // Chain splits from the innermost outward: each next split must start
+    // exactly where the previous one ended (lower_factor == product so far).
+    for (size_t i = 0; i < splits.size(); ++i) {
+      size_t j = 0;
+      for (; j < splits.size(); ++j) {
+        if (!used[j] && CanProveEqual(splits[j]->lower_factor, expected_lower_factor)) break;
+      }
+      if (j == splits.size()) {
+        return Array<IterSplitExpr>();
+      }
+      used[j] = true;
+      iters.push_back(splits[j]);
+      expected_lower_factor *= splits[j]->extent;
+    }
+    // The chained splits must cover the whole extent of the mark.
+    if (!CanProveEqual(expected_lower_factor, mark->extent)) return Array<IterSplitExpr>();
+    return Array<IterSplitExpr>(iters.rbegin(), iters.rend());
+  }
+
+  /*!
+   * \brief Normalize expr to an iterator + offset.
+   * \param expr The input expression.
+   * \return The Normalized expression.
+   */
+  IterSumExpr NormalizeToIterWithOffset(IterSumExpr expr) {
+    if (expr->args.size() <= 1) return expr;
+    // Fuse with a zero base, then restore the original base as the offset.
+    PrimExpr base = expr->base;
+    expr.CopyOnWrite()->base = make_zero(expr->dtype);
+    auto opt = TryFuseIters(expr);
+    expr.CopyOnWrite()->base = base;
+    if (opt) {
+      expr.CopyOnWrite()->args = Array<IterSplitExpr>({opt.value()});
+      return expr;
+    } else {
+      ++unresolved_count_;
+      return expr;
+    }
+  }
+
+  // Prove lhs == rhs: exact comparison for integer constants, analyzer otherwise.
+  bool CanProveEqual(PrimExpr lhs, PrimExpr rhs) {
+    const auto* clhs = lhs.as<IntImmNode>();
+    const auto* crhs = rhs.as<IntImmNode>();
+    if (clhs && crhs) return clhs->value == crhs->value;
+    return analyzer_->CanProve(lhs - rhs == 0);
+  }
+
+  /*!
+   * \brief Create a IterSumExpr from expr.
+   * \param expr The input expr.
+   * \return The transformed IterSumExpr.
+   */
+  IterSumExpr ToIterSumExpr(PrimExpr expr) {
+    if (const auto* op = expr.as<IterSumExprNode>()) {
+      return GetRef<IterSumExpr>(op);
+    } else if (const auto* op = expr.as<IterSplitExprNode>()) {
+      return IterSumExpr({GetRef<IterSplitExpr>(op)}, make_zero(expr->dtype));
+    } else {
+      CHECK(!expr->IsInstance<IterMapExprNode>());
+      return IterSumExpr({}, expr);
+    }
+  }
+
+  // Try to fuse the args of a zero-based IterSum into a single IterMark.
+  // Succeeds when the args, ordered by scale, form a dense mixed-radix sum
+  // (each scale equals the product of the extents of the args below it).
+  // Fused marks are cached in sum_fuse_map_ so equal sums share one mark.
+  Optional<IterSplitExpr> TryFuseIters(IterSumExpr expr) {
+    if (!is_zero(expr->base)) return NullOpt;
+    if (expr->args.size() == 1) return expr->args[0];
+    // select the iterators in order
+    std::vector<bool> visited(expr->args.size(), false);
+    std::vector<IterSplitExpr> iters;
+    iters.reserve(expr->args.size());
+    // canonicalize the expression
+    // check if it can be remapped into a fused pattern.
+    PrimExpr expected_scale = make_const(expr->base->dtype, 1);
+    for (size_t i = 0; i < expr->args.size(); ++i) {
+      size_t j = 0;
+      for (; j < expr->args.size(); ++j) {
+        if (!visited[j] && CanProveEqual(expr->args[j]->scale, expected_scale)) break;
+      }
+      if (j == expr->args.size()) {
+        return NullOpt;
+      }
+      visited[j] = true;
+      iters.push_back(expr->args[j]);
+      expected_scale *= expr->args[j]->extent;
+    }
+    // update the iterator to use the canonicalized form
+    expr.CopyOnWrite()->args = Array<IterSplitExpr>(iters.rbegin(), iters.rend());
+    auto it = sum_fuse_map_.find(expr);
+    if (it != sum_fuse_map_.end()) return it->second;
+    auto mark = IterMark(expr, expected_scale);
+    IterSplitExpr split(mark);
+    sum_fuse_map_[expr] = split;
+    return split;
+  }
+
+  // Prove lhs % rhs == 0: exact for integer constants, analyzer otherwise.
+  bool CanProveDivisible(PrimExpr lhs, PrimExpr rhs) {
+    const auto* clhs = lhs.as<IntImmNode>();
+    const auto* crhs = rhs.as<IntImmNode>();
+    if (clhs && crhs) return clhs->value % crhs->value == 0;
+    return analyzer_->CanProve(floormod(lhs, rhs) == 0);
+  }
+
+  PrimExpr SplitFloorDivConst(IterSplitExpr lhs, PrimExpr rhs);
+  PrimExpr SplitFloorModConst(IterSplitExpr lhs, PrimExpr rhs);
+
+  // lhs += sign * rhs. A split matching an existing arg (same mark, lower_factor
+  // and extent) has its scale merged into that arg; otherwise it is appended.
+  static void AddToLhs(IterSumExprNode* lhs, IterSplitExpr rhs, int sign) {
+    tir::ExprDeepEqual equal;
+    for (size_t i = 0; i < lhs->args.size(); ++i) {
+      IterSplitExpr lvalue = lhs->args[i];
+      if (lvalue->source.same_as(rhs->source) && equal(lvalue->lower_factor, rhs->lower_factor) &&
+          equal(lvalue->extent, rhs->extent)) {
+        if (sign > 0) {
+          rhs.CopyOnWrite()->scale = lvalue->scale + rhs->scale;
+        } else {
+          rhs.CopyOnWrite()->scale = lvalue->scale - rhs->scale;
+        }
+        lhs->args.Set(i, rhs);
+        return;
+      }
+    }
+    if (sign > 0) {
+      lhs->args.push_back(rhs);
+    } else {
+      rhs.CopyOnWrite()->scale = make_zero(rhs->scale.dtype()) - rhs->scale;
+      lhs->args.push_back(rhs);
+    }
+  }
+
+  // lhs += sign * rhs for a whole sum: merge each arg, then adjust the base.
+  static void AddToLhs(IterSumExprNode* lhs, IterSumExpr rhs, int sign) {
+    for (size_t i = 0; i < rhs->args.size(); ++i) {
+      AddToLhs(lhs, rhs->args[i], sign);
+    }
+    if (sign > 0) {
+      lhs->base += rhs->base;
+    } else {
+      lhs->base -= rhs->base;
+    }
+  }
+
+  // lhs *= rhs: scale every arg and the base.
+  static void MulToLhs(IterSumExprNode* lhs, PrimExpr rhs) {
+    for (size_t i = 0; i < lhs->args.size(); ++i) {
+      IterSplitExpr lvalue = lhs->args[i];
+      lvalue.CopyOnWrite()->scale *= rhs;
+      lhs->args.Set(i, lvalue);
+    }
+    lhs->base *= rhs;
+  }
+};
+
+Array<IterSumExpr> DetectIterMap(const Array<PrimExpr>& indices, const Map<Var, Range>& input_iters,
+                                 arith::Analyzer* analyzer) {
+  // Overall detection algorithm is divided into two steps:
+  // - Step0: IterMapRewriter rewrites each index to use IterMapExpr patterns,
+  //   bailing out as soon as any part cannot be resolved.
+  // - Step1: IterMapRewriter::CheckBijective checks that the resulting splits
+  //   are independent and that every input iterator is used.
+  IterMapRewriter rewriter(analyzer, input_iters);
+  Array<IterSumExpr> results;
+
+  for (const PrimExpr& value : indices) {
+    results.push_back(rewriter.Rewrite(value));
+    if (rewriter.unresolved_count() != 0) return Array<IterSumExpr>();
+  }
+  if (!rewriter.CheckBijective(results)) return Array<IterSumExpr>();
+
+  return results;
+}
+
+// FFI entry point: runs detection with a fresh Analyzer (no extra context).
+TVM_REGISTER_GLOBAL("arith.DetectIterMap")
+    .set_body_typed([](const Array<PrimExpr>& indices, const Map<Var, Range>& input_iters) {
+      arith::Analyzer ana;
+      return DetectIterMap(indices, input_iters, &ana);
+    });
+
+// Replace an input iterator by its iter map form; other vars are kept as-is.
+PrimExpr IterMapRewriter::VisitExpr_(const VarNode* op) {
+  Var var = GetRef<Var>(op);
+  auto found = var_map_.find(var);
+  if (found == var_map_.end()) {
+    return std::move(var);
+  }
+  return found->second;
+}
+
+// a + b: fold constants, keep plain adds untouched, and otherwise accumulate
+// b into the canonical IterSumExpr form of a.
+PrimExpr IterMapRewriter::VisitExpr_(const AddNode* op) {
+  if (!IsIndexType(op->dtype)) return Parent::VisitExpr_(op);
+
+  PrimExpr lhs = this->DirectMutate(op->a);
+  PrimExpr rhs = this->DirectMutate(op->b);
+
+  PrimExpr folded = TryConstFold<Add>(lhs, rhs);
+  if (folded.defined()) return folded;
+
+  const bool lhs_is_iter = lhs->IsInstance<IterMapExprNode>();
+  const bool rhs_is_iter = rhs->IsInstance<IterMapExprNode>();
+  // Neither operand involves an iterator: reuse or rebuild a plain Add.
+  if (!lhs_is_iter && !rhs_is_iter) {
+    if (!op->a.same_as(lhs) || !op->b.same_as(rhs)) return Add(lhs, rhs);
+    return GetRef<PrimExpr>(op);
+  }
+
+  IterSumExpr sum = ToIterSumExpr(std::move(lhs));
+  if (!rhs_is_iter) {
+    sum.CopyOnWrite()->base += rhs;
+  } else if (const auto* rhs_sum = rhs.as<IterSumExprNode>()) {
+    AddToLhs(sum.CopyOnWrite(), GetRef<IterSumExpr>(rhs_sum), 1);
+  } else if (const auto* rhs_split = rhs.as<IterSplitExprNode>()) {
+    AddToLhs(sum.CopyOnWrite(), GetRef<IterSplitExpr>(rhs_split), 1);
+  } else {
+    AddToLhs(sum.CopyOnWrite(), ToIterSumExpr(rhs), 1);
+  }
+  return std::move(sum);
+}
+
+/*!
+ * \brief Rewrite a - b. When either operand is an iterator map, the terms of \p b
+ *        are added to an IterSumExpr built from \p a with coefficient -1.
+ */
+PrimExpr IterMapRewriter::VisitExpr_(const SubNode* op) {
+  if (!IsIndexType(op->dtype)) {
+    return Parent::VisitExpr_(op);
+  }
+
+  PrimExpr a = this->DirectMutate(op->a);
+  PrimExpr b = this->DirectMutate(op->b);
+
+  // const folding
+  PrimExpr const_res = TryConstFold<Sub>(a, b);
+  if (const_res.defined()) return const_res;
+
+  // does not contain iter map: reuse the original node when nothing changed.
+  if (!a->IsInstance<IterMapExprNode>() && !b->IsInstance<IterMapExprNode>()) {
+    if (op->a.same_as(a) && op->b.same_as(b)) {
+      return GetRef<PrimExpr>(op);
+    } else {
+      return Sub(a, b);
+    }
+  }
+
+  // canonical form simplification.
+  IterSumExpr ret = ToIterSumExpr(std::move(a));
+
+  // The locals below are deliberately not named `op` so they do not shadow the parameter.
+  if (!b->IsInstance<IterMapExprNode>()) {
+    // A non-iterator operand only shifts the base.
+    ret.CopyOnWrite()->base -= b;
+  } else if (const auto* sum = b.as<IterSumExprNode>()) {
+    AddToLhs(ret.CopyOnWrite(), GetRef<IterSumExpr>(sum), -1);
+  } else if (const auto* split = b.as<IterSplitExprNode>()) {
+    AddToLhs(ret.CopyOnWrite(), GetRef<IterSplitExpr>(split), -1);
+  } else {
+    AddToLhs(ret.CopyOnWrite(), ToIterSumExpr(b), -1);
+  }
+  return std::move(ret);
+}
+
+/*!
+ * \brief Rewrite a * b. At most one operand may be an iterator map; the other one
+ *        is applied to it as a multiplicative factor.
+ */
+PrimExpr IterMapRewriter::VisitExpr_(const MulNode* op) {
+  if (!IsIndexType(op->dtype)) {
+    return Parent::VisitExpr_(op);
+  }
+  PrimExpr a = this->DirectMutate(op->a);
+  PrimExpr b = this->DirectMutate(op->b);
+
+  PrimExpr folded = TryConstFold<Mul>(a, b);
+  if (folded.defined()) return folded;
+
+  const bool a_is_iter = a->IsInstance<IterMapExprNode>();
+  const bool b_is_iter = b->IsInstance<IterMapExprNode>();
+
+  if (!a_is_iter && !b_is_iter) {
+    // Plain arithmetic: reuse the original node when neither operand changed.
+    if (op->a.same_as(a) && op->b.same_as(b)) return GetRef<PrimExpr>(op);
+    return Mul(a, b);
+  }
+  if (a_is_iter && b_is_iter) {
+    // Iterator * iterator is not expressible here; record it as unresolved.
+    ++unresolved_count_;
+    return Mul(a, b);
+  }
+
+  // Exactly one operand is an iterator map; make it `a`.
+  if (!a_is_iter) std::swap(a, b);
+
+  if (a->IsInstance<IterSumExprNode>()) {
+    IterSumExpr sum = Downcast<IterSumExpr>(std::move(a));
+    MulToLhs(sum.CopyOnWrite(), b);
+    return std::move(sum);
+  }
+  CHECK(a->IsInstance<IterSplitExprNode>());
+  IterSplitExpr split = Downcast<IterSplitExpr>(std::move(a));
+  split.CopyOnWrite()->scale *= b;
+  return std::move(split);
+}
+
+/*!
+ * \brief Rewrite floordiv(lhs, rhs) for a split iterator \p lhs and a non-iterator \p rhs.
+ *
+ * On failure unresolved_count_ is incremented and a plain floordiv is returned.
+ */
+PrimExpr IterMapRewriter::SplitFloorDivConst(IterSplitExpr lhs, PrimExpr rhs) {
+  // Dividing by one is the identity.
+  if (is_one(rhs)) return std::move(lhs);
+
+  if (!is_one(lhs->scale)) {
+    if (CanProveDivisible(lhs->scale, rhs)) {
+      // The divisor is absorbed entirely by the scale.
+      lhs.CopyOnWrite()->scale = floordiv(lhs->scale, rhs);
+      return std::move(lhs);
+    }
+    if (!CanProveDivisible(rhs, lhs->scale)) {
+      // Neither divides the other: give up on this expression.
+      ++unresolved_count_;
+      return floordiv(lhs, rhs);
+    }
+    // The scale is absorbed by the divisor; continue with scale 1.
+    rhs = floordiv(rhs, lhs->scale);
+    lhs.CopyOnWrite()->scale = make_const(rhs->dtype, 1);
+  }
+
+  if (!CanProveDivisible(lhs->extent, rhs)) {
+    ++unresolved_count_;
+    return floordiv(lhs, rhs);
+  }
+  // Move the divisor into lower_factor and shrink the extent accordingly.
+  auto* node = lhs.CopyOnWrite();
+  node->lower_factor *= rhs;
+  node->extent = analyzer_->Simplify(floordiv(node->extent, rhs));
+  return std::move(lhs);
+}
+
+/*!
+ * \brief Rewrite floordiv(a, b) where \p a may be an iterator map and \p b must not be.
+ *
+ * An IterSumExpr dividend is first fused into a single split via TryFuseIters and then
+ * handled by SplitFloorDivConst. Anything else increments unresolved_count_ and is
+ * returned as a plain FloorDiv.
+ */
+PrimExpr IterMapRewriter::VisitExpr_(const FloorDivNode* op) {
+  if (!IsIndexType(op->dtype)) {
+    return Parent::VisitExpr_(op);
+  }
+
+  PrimExpr a = this->DirectMutate(op->a);
+  PrimExpr b = this->DirectMutate(op->b);
+
+  // const folding
+  PrimExpr const_res = TryConstFold<FloorDiv>(a, b);
+  if (const_res.defined()) return const_res;
+
+  // does not contain iter map.
+  if (!a->IsInstance<IterMapExprNode>() && !b->IsInstance<IterMapExprNode>()) {
+    if (op->a.same_as(a) && op->b.same_as(b)) {
+      return GetRef<PrimExpr>(op);
+    } else {
+      return FloorDiv(a, b);
+    }
+  }
+
+  if (b->IsInstance<IterMapExprNode>()) {
+    // cannot divide by an iterator, mark as unresolved.
+    ++unresolved_count_;
+    return FloorDiv(a, b);
+  }
+
+  if (a->IsInstance<IterSumExprNode>()) {
+    // Copy instead of moving `a`: it is still needed to build the unresolved
+    // FloorDiv below, and moving it into Downcast could leave it empty.
+    IterSumExpr ret = Downcast<IterSumExpr>(a);
+    if (auto opt = TryFuseIters(ret)) {
+      return SplitFloorDivConst(opt.value(), b);
+    } else {
+      ++unresolved_count_;
+      return FloorDiv(a, b);
+    }
+  } else {
+    CHECK(a->IsInstance<IterSplitExprNode>());
+    IterSplitExpr ret = Downcast<IterSplitExpr>(std::move(a));
+    return SplitFloorDivConst(ret, b);
+  }
+}
+
+/*!
+ * \brief Rewrite floormod(lhs, rhs) for a split iterator \p lhs and a non-iterator \p rhs.
+ *
+ * On failure unresolved_count_ is incremented and a plain floormod is returned.
+ */
+PrimExpr IterMapRewriter::SplitFloorModConst(IterSplitExpr lhs, PrimExpr rhs) {
+  if (is_one(rhs)) return make_zero(lhs->dtype);
+  if (!is_one(lhs->scale)) {
+    if (CanProveDivisible(lhs->scale, rhs)) {
+      // The split's value is a multiple of its scale, hence of rhs.
+      return make_zero(lhs->dtype);
+    } else {
+      if (CanProveDivisible(rhs, lhs->scale)) {
+        // floormod(v * s, s * k) == floormod(v, k) * s: keep the scale and reduce
+        // the modulus to k = rhs / s. (floormod(rhs, s) would be 0 here, because
+        // rhs is a multiple of s.)
+        rhs = floordiv(rhs, lhs->scale);
+      } else {
+        // mark as unresolved.
+        ++unresolved_count_;
+        return floormod(lhs, rhs);
+      }
+    }
+  }
+
+  if (CanProveDivisible(lhs->extent, rhs)) {
+    // The modulus only narrows the extent; lower_factor and scale are kept.
+    lhs.CopyOnWrite()->extent = rhs;
+    return std::move(lhs);
+  } else {
+    // mark as unresolved.
+    ++unresolved_count_;
+    return floormod(lhs, rhs);
+  }
+}
+
+/*!
+ * \brief Rewrite floormod(a, b) where \p a may be an iterator map and \p b must not be.
+ *
+ * An IterSumExpr operand is first fused into a single split via TryFuseIters and then
+ * handled by SplitFloorModConst. Anything else increments unresolved_count_ and is
+ * returned as a plain FloorMod.
+ */
+PrimExpr IterMapRewriter::VisitExpr_(const FloorModNode* op) {
+  if (!IsIndexType(op->dtype)) {
+    return Parent::VisitExpr_(op);
+  }
+
+  PrimExpr a = this->DirectMutate(op->a);
+  PrimExpr b = this->DirectMutate(op->b);
+
+  // const folding
+  PrimExpr const_res = TryConstFold<FloorMod>(a, b);
+  if (const_res.defined()) return const_res;
+
+  // does not contain iter map.
+  if (!a->IsInstance<IterMapExprNode>() && !b->IsInstance<IterMapExprNode>()) {
+    if (op->a.same_as(a) && op->b.same_as(b)) {
+      return GetRef<PrimExpr>(op);
+    } else {
+      return FloorMod(a, b);
+    }
+  }
+
+  if (b->IsInstance<IterMapExprNode>()) {
+    // cannot mod by an iterator, mark as unresolved.
+    ++unresolved_count_;
+    return FloorMod(a, b);
+  }
+
+  if (a->IsInstance<IterSumExprNode>()) {
+    // Copy instead of moving `a`: it is still needed to build the unresolved
+    // FloorMod below, and moving it into Downcast could leave it empty.
+    IterSumExpr ret = Downcast<IterSumExpr>(a);
+    if (auto opt = TryFuseIters(ret)) {
+      return SplitFloorModConst(opt.value(), b);
+    } else {
+      ++unresolved_count_;
+      return FloorMod(a, b);
+    }
+  } else {
+    CHECK(a->IsInstance<IterSplitExprNode>());
+    IterSplitExpr ret = Downcast<IterSplitExpr>(std::move(a));
+    return SplitFloorModConst(ret, b);
+  }
+}
+
+}  // namespace arith
+}  // namespace tvm
diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc
index c237edc..cb8ef01 100644
--- a/src/arith/rewrite_simplify.cc
+++ b/src/arith/rewrite_simplify.cc
@@ -882,6 +882,9 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorModNode* op) {
     TVM_TRY_REWRITE_IF(floormod(x + y * c1, c2), floormod(x, c2),
                        c2.Eval()->value > 0 && c1.Eval()->value % 
c2.Eval()->value == 0);
 
+    TVM_TRY_REWRITE(floormod(x * y, y), ZeroWithTypeLike(x));
+    TVM_TRY_REWRITE(floormod(y * x, y), ZeroWithTypeLike(y));
+
     // try modular analysis
     if (floormod(x, c1).Match(ret)) {
       ModularSet mod = analyzer_->modular_set(x.Eval());
diff --git a/src/node/structural_hash.cc b/src/node/structural_hash.cc
index d21cb1f..1122b8e 100644
--- a/src/node/structural_hash.cc
+++ b/src/node/structural_hash.cc
@@ -28,6 +28,8 @@
 #include <algorithm>
 #include <unordered_map>
 
+#include "../support/util.h"
+
 namespace tvm {
 
 // Define the dispatch functio here since primary user is in this file.
@@ -163,7 +165,7 @@ class VarCountingSHashHandler : public SHashReducer::Handler {
     // combine in the reverse order of the stack.
     size_t reduced_hash = task.reduced_hash;
     for (size_t i = result_stack_.size(); i != stack_begin; --i) {
-      reduced_hash = HashCombine(reduced_hash, result_stack_[i - 1]);
+      reduced_hash = support::HashCombine(reduced_hash, result_stack_[i - 1]);
     }
     result_stack_.resize(stack_begin);
     return reduced_hash;
@@ -186,8 +188,8 @@ class VarCountingSHashHandler : public SHashReducer::Handler {
           // Append the graph node counter to the hash
           // so that we can distinguish DAG from trees.
           if (entry.graph_node_hash) {
-            entry.reduced_hash =
-                HashCombine(entry.reduced_hash, 
std::hash<size_t>()(graph_node_counter_++));
+            entry.reduced_hash = support::HashCombine(entry.reduced_hash,
+                                                      
std::hash<size_t>()(graph_node_counter_++));
           }
           hash_memo_[entry.object] = entry.reduced_hash;
         }
@@ -229,16 +231,6 @@ class VarCountingSHashHandler : public SHashReducer::Handler {
     vtable_->SHashReduce(object.get(), SHashReducer(this, map_free_vars));
   }
 
-  /*!
-   * \brief Combine two hash values into a single one.
-   * \param key The left operand.
-   * \param value The right operand.
-   * \return the combined result.
-   */
-  size_t HashCombine(size_t key, size_t value) {
-    return key ^ (value + 0x9e3779b9 + (key << 6) + (key >> 2));
-  }
-
  private:
   // free var counter.
   size_t free_var_counter_{0};
diff --git a/src/support/util.h b/src/support/util.h
index 859b372..5020df2 100644
--- a/src/support/util.h
+++ b/src/support/util.h
@@ -152,6 +152,16 @@ inline int Execute(std::string cmd, std::string* err_msg) {
   return 255;
 }
 
+/*!
+ * \brief Fold the hash \p value into the running hash \p key.
+ *
+ * Uses the boost::hash_combine mixing formula; the result depends on argument order.
+ * \param key The accumulated hash (left operand).
+ * \param value The hash to mix in (right operand).
+ * \return The combined hash.
+ */
+inline size_t HashCombine(size_t key, size_t value) {
+  const size_t mixed = value + 0x9e3779b9 + (key << 6) + (key >> 2);
+  return key ^ mixed;
+}
+
 }  // namespace support
 }  // namespace tvm
 #endif  // TVM_SUPPORT_UTIL_H_
diff --git a/tests/python/unittest/test_arith_iter_affine_map.py b/tests/python/unittest/test_arith_iter_affine_map.py
new file mode 100644
index 0000000..9fb0988
--- /dev/null
+++ b/tests/python/unittest/test_arith_iter_affine_map.py
@@ -0,0 +1,176 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm
+import tvm.testing
+from tvm import te
+
+
+def ifuse(inputs):
+    """Fuse iterators"""
+    value, extent = 0, 1
+    for i, ext in inputs:
+        value = value * ext + i
+        extent = extent * ext
+    return (value, extent)
+
+
+def isplit(axis, factor):
+    """Split iterators"""
+    fld = tvm.tir.floordiv
+    flm = tvm.tir.floormod
+    return [
+        (fld(axis[0], factor), fld(axis[1] + (factor - 1), factor)),
+        (flm(axis[0], factor), factor),
+    ]
+
+
+def var_dom(iters):
+    """Get domains of iterators"""
+    return {var: tvm.ir.Range(0, ext) for var, ext in iters}
+
+
+def assert_iter_sum_pattern(sum_expr, extent, base, scale=1):
+    """Check the sum expr have the right pattern."""
+    assert isinstance(sum_expr, tvm.arith.IterSumExpr)
+    if extent == 1:
+        assert len(sum_expr.args) == 0
+    else:
+        assert len(sum_expr.args) == 1
+        tvm.testing.assert_prim_expr_equal(sum_expr.args[0].extent, extent)
+        tvm.testing.assert_prim_expr_equal(sum_expr.args[0].scale, scale)
+    tvm.testing.assert_prim_expr_equal(sum_expr.base, base)
+
+
+def test_trivial():
+    x = tvm.tir.Var("x", "int32"), 3
+    y = tvm.tir.Var("y", "int32"), 4
+
+    res = tvm.arith.detect_iter_map([x[0], y[0], 3], var_dom([x, y]))
+
+    assert len(res) == 3
+    assert_iter_sum_pattern(res[0], 3, 0)
+    assert_iter_sum_pattern(res[1], 4, 0)
+    assert_iter_sum_pattern(res[2], 1, 3)
+
+    res = tvm.arith.detect_iter_map([x[0], 3], var_dom([x, y]))
+    assert len(res) == 0
+
+    # not independent
+    res = tvm.arith.detect_iter_map([x[0], x[0], 3], var_dom([x, y]))
+    assert len(res) == 0
+
+
+def test_fuse():
+    x = tvm.tir.Var("x", "int32")
+    y = tvm.tir.Var("y", "int32")
+    c = tvm.tir.SizeVar("c", "int32")
+
+    res = tvm.arith.detect_iter_map([y * 3 + 1 + c + x], var_dom([(x, 3), (y, 
4)]))
+    assert len(res) == 1
+    assert_iter_sum_pattern(res[0], 12, 1 + c)
+
+    res = tvm.arith.detect_iter_map([ifuse([(x, 3), (y, 4)])[0]], var_dom([(x, 
3), (y, 4)]))
+    assert len(res) == 1
+    assert_iter_sum_pattern(res[0], 12, 0)
+
+    # fuse with symbolic factor
+    res = tvm.arith.detect_iter_map([(y + 1) * c + x], var_dom([(x, c), (y, 
4)]))
+    assert len(res) == 1
+    assert_iter_sum_pattern(res[0], 4 * c, c)
+
+    # duplication
+    res = tvm.arith.detect_iter_map([y * 3 + x, y], var_dom([(x, 3), (y, 4)]))
+    assert len(res) == 0
+
+    # duplication 2
+    res = tvm.arith.detect_iter_map([y, x + 1, y], var_dom([(x, 3), (y, 4)]))
+    assert len(res) == 0
+
+    # factor mismatch
+    res = tvm.arith.detect_iter_map([y * 4 + x], var_dom([(x, 3), (y, 4)]))
+    assert len(res) == 0
+
+
+def test_split():
+    x = tvm.tir.Var("x", "int32")
+    y = tvm.tir.Var("y", "int32")
+    z = tvm.tir.Var("y", "int32")
+    c0 = tvm.tir.SizeVar("c0", "int32")
+    c1 = tvm.tir.SizeVar("c1", "int32")
+    c2 = tvm.tir.SizeVar("c1", "int32")
+    fld = tvm.tir.floordiv
+    flm = tvm.tir.floormod
+
+    res = tvm.arith.detect_iter_map([fld(x, 3), flm(x, 3) * 2 + c1], 
var_dom([(x, 24)]))
+
+    assert len(res) == 2
+    assert_iter_sum_pattern(res[0], 8, 0)
+    assert_iter_sum_pattern(res[1], 3, c1, 2)
+
+    res = tvm.arith.detect_iter_map([fld(x, 6), fld(flm(x, 6), 2), flm(x, 2)], 
var_dom([(x, 24)]))
+
+    assert len(res) == 3
+    assert_iter_sum_pattern(res[0], 4, 0)
+    assert_iter_sum_pattern(res[1], 3, 0)
+    assert_iter_sum_pattern(res[2], 2, 0)
+
+    # simple symbolic bound
+    # TODO(tvm-team) improve symbolic divisible check to enable
+    # more complicated symbolic bound
+    res = tvm.arith.detect_iter_map([fld(x, c0), flm(x, c0)], var_dom([(x, c1 
* c0)]))
+
+    assert len(res) == 2
+    assert_iter_sum_pattern(res[0], c1, 0)
+    assert_iter_sum_pattern(res[1], c0, 0)
+
+
+def test_compound():
+    x = tvm.tir.Var("x", "int32"), 10
+    y = tvm.tir.Var("y", "int32"), 9
+
+    xo, xi = isplit(x, 5)
+    yo, yi = isplit(y, 3)
+    z = ifuse([yo, xo, yi])
+
+    res = tvm.arith.detect_iter_map([z[0], xi[0]], var_dom([x, y]))
+
+    assert len(res) == 2
+    assert_iter_sum_pattern(res[0], 18, 0)
+    assert_iter_sum_pattern(res[1], 5, 0)
+    # reconstruct the pattern manually
+    mx = tvm.arith.IterMark(x[0], 10)
+    my = tvm.arith.IterMark(y[0], 9)
+
+    xoscale = 3
+    xiscale = 1
+    yoscale = 6
+    yiscale = 1
+    mxo = tvm.arith.IterSplitExpr(mx, 5, 2, xoscale)
+    mxi = tvm.arith.IterSplitExpr(mx, 1, 5, xiscale)
+    myo = tvm.arith.IterSplitExpr(my, 3, 3, yoscale)
+    myi = tvm.arith.IterSplitExpr(my, 1, 3, yiscale)
+
+    mz = tvm.arith.IterMark(tvm.arith.IterSumExpr([myo, mxo, myi], 0), 18)
+    sz = tvm.arith.IterSumExpr([tvm.arith.IterSplitExpr(mz, 1, 18, 1)], 0)
+    tvm.ir.assert_structural_equal(sz, res[0])
+
+
+if __name__ == "__main__":
+    test_split()
+    test_trivial()
+    test_fuse()
+    test_compound()

Reply via email to