This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new e32d47e [Arith] Inverse affine map (#8384)
e32d47e is described below
commit e32d47e9f5fca5772e027a9595620436285e295d
Author: Wuwei Lin <[email protected]>
AuthorDate: Sat Jul 3 20:59:16 2021 -0400
[Arith] Inverse affine map (#8384)
* [Arith] Inverse affine map
* [Arith] Inverse affine map
* Update iter_affine_map.h
* Update iter_affine_map.h
* Update iter_affine_map.py
* Topology order visit
* doc
* fix
* address comments
* lint
* remove print
---
include/tvm/arith/iter_affine_map.h | 21 +++
python/tvm/arith/__init__.py | 7 +-
python/tvm/arith/iter_affine_map.py | 27 ++++
src/arith/iter_affine_map.cc | 142 +++++++++++++++++++++
.../python/unittest/test_arith_iter_affine_map.py | 61 +++++++++
5 files changed, 257 insertions(+), 1 deletion(-)
diff --git a/include/tvm/arith/iter_affine_map.h
b/include/tvm/arith/iter_affine_map.h
index 641d0e0..d671339 100644
--- a/include/tvm/arith/iter_affine_map.h
+++ b/include/tvm/arith/iter_affine_map.h
@@ -284,6 +284,27 @@ Array<IterSumExpr> DetectIterMap(const Array<PrimExpr>&
indices, const Map<Var,
arith::Analyzer* analyzer);
/*!
+ * \brief Apply the inverse of the affine transformation to the outputs.
+ *
+ * Similar to back-propagation, starting from the outputs, it visits the DAG of the expressions
+ * in reverse topological order and applies the inverse of the affine transformation until it
+ * reaches the input. The affine iter map is required to be bijective.
+ *
+ * For example, iter_map = [l0 // 16, l0 % 16], outputs = [output_0, output_1],
+ * the inverse of the affine transformation specified by `iter_map` will be applied to `outputs`
+ * and the result will be {l0: ((output_0*16) + output_1)}.
+ *
+ * \sa DetectIterMap
+ *
+ * \param iter_map The bijective affine iter map.
+ * \param outputs The outputs of the affine transformation, one expression per entry of iter_map.
+ *
+ * \return The map from the input to the transformed result.
+ */
+Map<Var, PrimExpr> InverseAffineIterMap(const Array<IterSumExpr>& iter_map,
+ const Array<PrimExpr> outputs);
+
+/*!
* \brief Detect if bindings can be written as
* [a_0*e_0 + b_0 + c_0, a_1*e_1 + b_1, ..., a_n*e_n + b_n]
*
diff --git a/python/tvm/arith/__init__.py b/python/tvm/arith/__init__.py
index d1e4431..f5a0478 100644
--- a/python/tvm/arith/__init__.py
+++ b/python/tvm/arith/__init__.py
@@ -22,4 +22,9 @@ from .bound import deduce_bound
from .pattern import detect_linear_equation, detect_clip_bound
from .int_solver import solve_linear_equations, solve_linear_inequalities
from .iter_affine_map import IterMapExpr, IterMark, IterSplitExpr, IterSumExpr
-from .iter_affine_map import detect_iter_map, normalize_iter_map_to_expr,
subspace_divide
+from .iter_affine_map import (
+ detect_iter_map,
+ normalize_iter_map_to_expr,
+ subspace_divide,
+ inverse_affine_iter_map,
+)
diff --git a/python/tvm/arith/iter_affine_map.py
b/python/tvm/arith/iter_affine_map.py
index bfd5dfa..85513ec 100644
--- a/python/tvm/arith/iter_affine_map.py
+++ b/python/tvm/arith/iter_affine_map.py
@@ -173,3 +173,30 @@ def subspace_divide(bindings, input_iters, sub_iters,
predicate=True, require_bi
Empty array if no match can be found.
"""
return _ffi_api.SubspaceDivide(bindings, input_iters, sub_iters,
predicate, require_bijective)
+
+
+def inverse_affine_iter_map(iter_map, outputs):
+ """Apply the inverse of the affine transformation to the outputs.
+
+ Similar to back-propagation, starting from the outputs, it visits the DAG of the expressions
+ in reverse topological order and applies the inverse of the affine transformation until it
+ reaches the input. The affine iter map is required to be bijective.
+
+ For example, iter_map = [l0 // 16, l0 % 16], outputs = [output_0, output_1],
+ the inverse of the affine transformation specified by `iter_map` will be applied to `outputs`
+ and the result will be {l0: ((output_0*16) + output_1)}.
+
+ See also :any:`detect_iter_map`.
+
+ Parameters
+ ----------
+ iter_map : List[IterSumExpr]
+ The bijective affine iter map.
+ outputs : List[PrimExpr]
+ The outputs of the affine transformation, one expression per entry of iter_map.
+
+ Returns
+ -------
+ results : Map[Var, PrimExpr]
+ The map from the input to the transformed result.
+ """
+ return _ffi_api.InverseAffineIterMap(iter_map, outputs)
diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc
index c1daae9..e885195 100644
--- a/src/arith/iter_affine_map.cc
+++ b/src/arith/iter_affine_map.cc
@@ -1385,5 +1385,147 @@ TVM_REGISTER_GLOBAL("arith.SubspaceDivide")
return SubspaceDivide(bindings, root_iters, sub_iters, predicate,
require_bijective, &ana);
});
+/*!
+ * \brief Back-propagate output expressions through an affine iter map to recover its inputs.
+ *
+ * Every IterSumExpr / IterSplitExpr reachable from the iter map gets an accumulator that is
+ * initialized to zero. The nodes are then visited in reverse topological order (a node is
+ * visited before the nodes it reads from), each one pushing the inverse of its own
+ * transformation into the accumulator of its source, until the input variables are reached.
+ */
+class InverseAffineIterMapTransformer {
+ public:
+  explicit InverseAffineIterMapTransformer(Analyzer* analyzer) : analyzer_(analyzer) {}
+
+  /*!
+   * \brief Run the back propagation.
+   * \param iter_map The affine iter map to invert.
+   * \param outputs The value of each entry of iter_map; must have the same size as iter_map.
+   * \return The map from each input variable reached from iter_map to its inverse expression.
+   */
+  Map<Var, PrimExpr> operator()(const Array<IterSumExpr>& iter_map,
+                                const Array<PrimExpr>& outputs) {
+    ICHECK(iter_map.size() == outputs.size());
+    std::vector<const IterMapExprNode*> visit_order = ReverseTopologyOrder(iter_map);
+
+    // initialize back propagation accumulator
+    for (const IterMapExprNode* node : visit_order) {
+      backprop_.Set(GetRef<IterMapExpr>(node), Integer(0));
+    }
+    // the roots start from the given outputs instead of zero
+    for (size_t i = 0; i < iter_map.size(); i++) {
+      backprop_.Set(iter_map[i], outputs[i]);
+    }
+
+    // run back propagation
+    for (const IterMapExprNode* node : visit_order) {
+      if (node->IsInstance<IterSumExprNode>()) {
+        Visit_(Downcast<IterSumExpr>(GetRef<IterMapExpr>(node)));
+      } else {
+        ICHECK(node->IsInstance<IterSplitExprNode>());
+        Visit_(Downcast<IterSplitExpr>(GetRef<IterMapExpr>(node)));
+      }
+    }
+    // inverse_ is a member, so it has to be moved explicitly to avoid a copy
+    return std::move(inverse_);
+  }
+
+ private:
+  /*!
+   * \brief Push the accumulated value of a sum expression to its components.
+   * \param iter_map_expr The sum expression being inverted.
+   */
+  void Visit_(const IterSumExpr& iter_map_expr) {
+    PrimExpr input = backprop_.at(iter_map_expr) - iter_map_expr->base;
+
+    // Case 1: Propagate to the input node directly when the sum expression has only one
+    // component. The component contributes `value * scale` to the sum, so the scale has to be
+    // divided out, exactly as in the multi-component case below. floordiv folds to `input`
+    // when the scale is the constant 1.
+    if (iter_map_expr->args.size() == 1) {
+      const IterSplitExpr& source = iter_map_expr->args[0];
+      backprop_.Set(source, backprop_.at(source) + floordiv(input, source->scale));
+      return;
+    }
+
+    // Case 2: If the sum expression has multiple components, match the fuse pattern and then
+    // split the sum expression for each component.
+    // For example, consider the iterator i1[dom = (0, 16)], i2[dom = (0, 8)], fusing i1 and i2
+    // we will have i1_i2_fused[dom = (0, 64)]. During back propagation, we need to split the
+    // propagated value to get the corresponding components of i1 and i2, which are
+    // floordiv(i1_i2_fused, 8) and floormod(i1_i2_fused, 8), respectively.
+    Array<IterSplitExpr> splits = MatchFusePattern(iter_map_expr);
+    ICHECK(!splits.empty());
+
+    for (const IterSplitExpr& split : splits) {
+      backprop_.Set(split,
+                    backprop_.at(split) + floormod(floordiv(input, split->scale), split->extent));
+    }
+  }
+
+  /*!
+   * \brief Collect every node reachable from iter_map, ordered so that each node appears
+   *  before the nodes it reads from (reversed post-order of a depth-first traversal).
+   * \param iter_map The roots of the traversal.
+   * \return The nodes in visiting order for the back propagation.
+   */
+  std::vector<const IterMapExprNode*> ReverseTopologyOrder(const Array<IterSumExpr>& iter_map) {
+    std::vector<const IterMapExprNode*> post_dfs_order;
+    std::unordered_map<IterMapExpr, bool, ObjectPtrHash, ObjectPtrEqual> visited;
+
+    std::function<void(const IterMapExpr&)> fvisit = [&](const IterMapExpr& expr) {
+      if (visited[expr]) {
+        return;
+      }
+      visited[expr] = true;
+      if (const auto* sum_expr = expr.as<IterSumExprNode>()) {
+        for (const IterSplitExpr& child : sum_expr->args) {
+          fvisit(child);
+        }
+      } else {
+        const auto* split_expr = expr.as<IterSplitExprNode>();
+        ICHECK(split_expr);
+        // a split whose source is not an IterMapExpr is a leaf of the traversal
+        if (const auto* source = split_expr->source->source.as<IterMapExprNode>()) {
+          fvisit(GetRef<IterMapExpr>(source));
+        }
+      }
+      post_dfs_order.push_back(expr.get());
+    };
+    for (const IterSumExpr& expr : iter_map) {
+      fvisit(expr);
+    }
+    std::reverse(post_dfs_order.begin(), post_dfs_order.end());
+    return post_dfs_order;
+  }
+
+  /*!
+   * \brief Push the accumulated value of a split expression to its source.
+   *
+   * The value is scaled by lower_factor and added either to the accumulator of the source
+   * sum expression, or to the result entry of the source variable.
+   *
+   * \param iter_map_expr The split expression being inverted.
+   */
+  void Visit_(const IterSplitExpr& iter_map_expr) {
+    PrimExpr input = backprop_.at(iter_map_expr) * iter_map_expr->lower_factor;
+    const IterMark& source = iter_map_expr->source;
+    if (source->source.as<IterSumExprNode>()) {
+      IterSumExpr source_expr = Downcast<IterSumExpr>(source->source);
+      backprop_.Set(source_expr, backprop_.at(source_expr) + input);
+    } else {
+      Var source_var = Downcast<Var>(source->source);
+      if (inverse_.count(source_var)) {
+        inverse_.Set(source_var, inverse_.at(source_var) + input);
+      } else {
+        inverse_.Set(source_var, input);
+      }
+    }
+  }
+
+  /*!
+   * \brief Order the components of a fused sum expression by increasing scale.
+   *
+   * Starting from the component with the smallest constant scale, each following component
+   * must have a scale equal to the previous scale multiplied by the previous extent.
+   * The check fails if no such chain covering all components exists.
+   *
+   * \param sum_expr The sum expression to match.
+   * \return The components of sum_expr in the matched order.
+   */
+  Array<IterSplitExpr> MatchFusePattern(const IterSumExpr& sum_expr) {
+    IntImm base_scale(nullptr);
+    size_t base_index = 0;
+    for (size_t i = 0; i < sum_expr->args.size(); ++i) {
+      if (const auto* op = sum_expr->args[i]->scale.as<IntImmNode>()) {
+        if (!base_scale.defined() || op->value < base_scale->value) {
+          base_scale = GetRef<IntImm>(op);
+          base_index = i;
+        }
+      }
+    }
+    ICHECK(base_scale.defined());
+    std::vector<IterSplitExpr> iters;
+    std::vector<bool> visited(sum_expr->args.size(), false);
+    PrimExpr expected_scale = base_scale;
+    for (size_t i = 0; i < sum_expr->args.size(); i++) {
+      // the first search starts at the component that provided the base scale
+      size_t j = i == 0 ? base_index : 0;
+      for (; j < sum_expr->args.size(); ++j) {
+        if (!visited[j] && analyzer_->CanProveEqual(sum_expr->args[j]->scale, expected_scale))
+          break;
+      }
+      ICHECK(j != sum_expr->args.size());
+      visited[j] = true;
+      iters.push_back(sum_expr->args[j]);
+      expected_scale *= sum_expr->args[j]->extent;
+    }
+    return iters;
+  }
+
+  Analyzer* analyzer_;
+  Map<IterMapExpr, PrimExpr> backprop_;  // the accumulator of back propagation
+  Map<Var, PrimExpr> inverse_;           // the result of inverse transformation
+};
+
+/*!
+ * \brief Invert iter_map on the given outputs using a locally created analyzer.
+ * \param iter_map The affine iter map to invert.
+ * \param outputs The value of each entry of iter_map.
+ * \return The map from each input variable to its inverse expression.
+ */
+Map<Var, PrimExpr> InverseAffineIterMap(const Array<IterSumExpr>& iter_map,
+                                        const Array<PrimExpr> outputs) {
+  Analyzer analyzer;
+  InverseAffineIterMapTransformer transformer(&analyzer);
+  return transformer(iter_map, outputs);
+}
+
+// Register InverseAffineIterMap under the global name "arith.InverseAffineIterMap".
+TVM_REGISTER_GLOBAL("arith.InverseAffineIterMap").set_body_typed(InverseAffineIterMap);
+
} // namespace arith
} // namespace tvm
diff --git a/tests/python/unittest/test_arith_iter_affine_map.py
b/tests/python/unittest/test_arith_iter_affine_map.py
index 7bfdfc6..b34acb9 100644
--- a/tests/python/unittest/test_arith_iter_affine_map.py
+++ b/tests/python/unittest/test_arith_iter_affine_map.py
@@ -643,6 +643,66 @@ def test_normalize_iter_map_to_expr():
tvm.ir.assert_structural_equal(tvm.arith.normalize_iter_map_to_expr(res[1]),
flm(x[0], 5))
+def test_inverse_affine_iter_map():
+ # Checks inverse_affine_iter_map on three iter maps: a simple split/fuse, a compound
+ # split/fuse over three iterators, and a diamond-shaped DAG. Each result is compared with a
+ # hand-written inverse by simplifying their difference to 0.
+ analyzer = tvm.arith.Analyzer()
+ l0 = create_iter("l0", 64)
+ l1 = create_iter("l1", 64)
+ l2 = create_iter("l2", 64)
+
+ # simple case
+ l0_0, l0_1 = isplit(l0, 16)
+ l1_0, l1_1 = isplit(l1, 4)
+ l0_1_l1_1_fused = ifuse([l0_1, l1_1])
+
+ iter_map = tvm.arith.detect_iter_map([l0_1_l1_1_fused[0], l0_0[0],
l1_0[0]], var_dom([l0, l1]))
+ # one fresh int32 variable per entry of iter_map stands in for that entry's value
+ outputs = [tvm.tir.Var("output_{}".format(i), "int32") for i in
range(len(iter_map))]
+ res = tvm.arith.inverse_affine_iter_map(iter_map, outputs)
+ assert len(res) == 2
+ # expected inverses: outputs[0] is the fused value, outputs[1] and outputs[2] the l0_0 and l1_0 parts
+ l0_inverse = floormod(floordiv(outputs[0], 4), 16) + outputs[1] * 16
+ l1_inverse = floormod(outputs[0], 4) + outputs[2] * 4
+ assert analyzer.simplify(res[l0[0]] - l0_inverse) == 0
+ assert analyzer.simplify(res[l1[0]] - l1_inverse) == 0
+
+ # compound case
+ l0_0, l0_1 = isplit(l0, 16)
+ l1_0, l1_1 = isplit(l1, 4)
+ # l2 is split twice; l2_1 is rebound by the second split
+ l2_1, l2_2 = isplit(l2, 4)
+ l2_0, l2_1 = isplit(l2_1, 4)
+
+ l0_1_l2_1_l1_1_l2_0_fused = ifuse([l0_1, l2_1, l1_1, l2_0])
+
+ iter_map = tvm.arith.detect_iter_map(
+ [l0_1_l2_1_l1_1_l2_0_fused[0], l0_0[0], l2_2[0], l1_0[0]],
var_dom([l0, l1, l2])
+ )
+ outputs = [tvm.tir.Var("output_{}".format(i), "int32") for i in
range(len(iter_map))]
+ res = tvm.arith.inverse_affine_iter_map(iter_map, outputs)
+ assert len(res) == 3
+ l0_inverse = floormod(floordiv(outputs[0], 64), 16) + outputs[1] * 16
+ l1_inverse = floormod(floordiv(outputs[0], 4), 4) + outputs[3] * 4
+ l2_inverse = (
+ floormod(outputs[0], 4) * 16 + floormod(floordiv(outputs[0], 16), 4) *
4 + outputs[2]
+ )
+
+ assert analyzer.simplify(res[l0[0]] - l0_inverse) == 0
+ assert analyzer.simplify(res[l1[0]] - l1_inverse) == 0
+ assert analyzer.simplify(res[l2[0]] - l2_inverse) == 0
+
+ # diamond-shape DAG
+ # l1 and l2 are rebound to fused expressions from here on; l0 is the only input iterator
+ l0_0, l0_1 = isplit(l0, 16)
+ l1 = ifuse([l0_1, l0_0])
+ l1_0, l1_1 = isplit(l1, 8)
+ l2 = ifuse([l1_1, l1_0])
+
+ iter_map = tvm.arith.detect_iter_map([l2[0]], var_dom([l0]))
+ outputs = [tvm.tir.Var("output_{}".format(i), "int32") for i in
range(len(iter_map))]
+ res = tvm.arith.inverse_affine_iter_map(iter_map, outputs)
+ assert len(res) == 1
+ # the expected inverse is built in two steps: first undo l2 to get l1, then undo l1 to get l0
+ l1_inverse = floormod(outputs[0], 8) * 8 + floormod(floordiv(outputs[0],
8), 8)
+ l0_inverse = floormod(l1_inverse, 4) * 16 + floormod(floordiv(l1_inverse,
4), 16)
+
+ assert analyzer.simplify(res[l0[0]] - l0_inverse) == 0
+
+
if __name__ == "__main__":
test_split()
test_trivial()
@@ -652,3 +712,4 @@ if __name__ == "__main__":
test_normalize_iter_map_to_expr()
test_subspace_division()
test_complex()
+ test_inverse_affine_iter_map()