tqchen commented on a change in pull request #6667: URL: https://github.com/apache/incubator-tvm/pull/6667#discussion_r504256910
########## File path: src/arith/iter_affine_map.cc ########## @@ -0,0 +1,703 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/arith/iter_affine_map.cc + */ +#include <tvm/arith/analyzer.h> +#include <tvm/arith/iter_affine_map.h> +#include <tvm/tir/analysis.h> +#include <tvm/tir/expr.h> +#include <tvm/tir/expr_functor.h> +#include <tvm/tir/op.h> + +#include "../support/util.h" +#include "const_fold.h" + +namespace tvm { +namespace arith { + +using namespace tir; + +/* Construct an IterMark node holding the given source expression and extent. */ +IterMark::IterMark(PrimExpr source, PrimExpr extent) { + auto n = make_object<IterMarkNode>(); + n->source = std::move(source); + n->extent = std::move(extent); + data_ = std::move(n); +} + +TVM_REGISTER_GLOBAL("arith.IterMark").set_body_typed([](PrimExpr source, PrimExpr extent) { + return IterMark(source, extent); +}); + +TVM_REGISTER_NODE_TYPE(IterMarkNode); + +/* Printer: renders the node as IterMark(<source>, extent=<extent>), closing the parenthesis like the IterSum printer does. */ +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch<IterMarkNode>([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast<const IterMarkNode*>(node.get()); + p->stream << "IterMark(" << op->source << ", extent=" << op->extent << ")"; + }); + +IterSplitExpr::IterSplitExpr(IterMark source) { + auto n = make_object<IterSplitExprNode>(); + auto one = 
make_const(source->source->dtype, 1); + n->dtype = source->source->dtype; + n->source = std::move(source); + n->extent = n->source->extent; + n->lower_factor = one; + n->scale = one; + data_ = std::move(n); +} + +/* Construct a split of source with explicit lower_factor, extent and scale; dtype is taken from the mark's source expression. */ +IterSplitExpr::IterSplitExpr(IterMark source, PrimExpr lower_factor, PrimExpr extent, + PrimExpr scale) { + auto n = make_object<IterSplitExprNode>(); + n->dtype = source->source->dtype; + n->source = std::move(source); + n->lower_factor = std::move(lower_factor); + n->extent = std::move(extent); + n->scale = std::move(scale); + data_ = std::move(n); +} + +TVM_REGISTER_GLOBAL("arith.IterSplitExpr") + .set_body_typed([](IterMark source, PrimExpr lower_factor, PrimExpr extent, PrimExpr scale) { + return IterSplitExpr(source, lower_factor, extent, scale); + }); + +TVM_REGISTER_NODE_TYPE(IterSplitExprNode); + +/* Printer: renders the node as IterSplit(<source>, lower_factor=..., extent=..., scale=...), including the closing parenthesis. */ +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch<IterSplitExprNode>([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast<const IterSplitExprNode*>(node.get()); + p->stream << "IterSplit(" << op->source << ", lower_factor=" << op->lower_factor + << ", extent=" << op->extent << ", scale=" << op->scale << ")"; + }); + +/* Construct a sum of splits plus a base offset; dtype is taken from base. */ +IterSumExpr::IterSumExpr(Array<IterSplitExpr> args, PrimExpr base) { + auto n = make_object<IterSumExprNode>(); + n->dtype = base->dtype; + n->args = std::move(args); + n->base = std::move(base); + data_ = std::move(n); +} + +TVM_REGISTER_GLOBAL("arith.IterSumExpr") + .set_body_typed([](Array<IterSplitExpr> args, PrimExpr base) { + return IterSumExpr(args, base); + }); + +TVM_REGISTER_NODE_TYPE(IterSumExprNode); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch<IterSumExprNode>([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast<const IterSumExprNode*>(node.get()); + p->stream << "IterSum(" << op->args << ", " << op->base << ")"; + }); + +/*! + * \brief Util to check if all splits in the sumexpr are + * independent and complete (covers all the original iter space). 
+ * + */ +class IterMarkSplitCollector { + public: + // All IterMarks that have been visited so far. + std::unordered_set<IterMark, ObjectPtrHash, ObjectPtrEqual> visited_; + // Maps each iter mark to the splits that reference it as their source. + std::unordered_map<IterMark, std::vector<IterSplitExpr>, ObjectPtrHash, ObjectPtrEqual> + mark2splits_; + /*! + * \brief Collect all mark2splits recursively from indices. + * \param indices The iterator sum expressions to collect from. + */ + void Collect(const Array<IterSumExpr>& indices) { + for (IterSumExpr sum_expr : indices) { + for (IterSplitExpr split : sum_expr->args) { + this->CollectInternal(split->source); + mark2splits_[split->source].push_back(split); + } + } + } + + /*! + * \brief Visit mark at most once; when its source is an IterSumExpr, recurse into + * the source of each of its splits and record those splits in mark2splits_. + * \param mark The iter mark to visit. + */ + void CollectInternal(const IterMark& mark) { + if (visited_.count(mark)) return; + visited_.insert(mark); + if (auto* op = mark->source.as<IterSumExprNode>()) { + for (IterSplitExpr split : op->args) { + this->CollectInternal(split->source); + mark2splits_[split->source].push_back(split); + } + } + } +}; + +// Rewriter that rewrites expressions to use IterMapExpr patterns. +class IterMapRewriter : public ExprMutator { + public: + using Parent = ExprMutator; + + explicit IterMapRewriter(Analyzer* analyzer, const Map<Var, Range>& input_iters) + : analyzer_(analyzer) { + for (auto kv : input_iters) { + const auto& vrng = kv.second; + if (is_zero(vrng->min)) { + IterMark mark(kv.first, vrng->extent); + var_map_[kv.first] = IterSplitExpr(mark); + input_marks_.push_back(mark); + } else { + // Non-zero min: the mark iterates over (var - min) and min is kept as the base of the sum. + IterMark mark(kv.first - vrng->min, vrng->extent); + auto sum_expr = ToIterSumExpr(IterSplitExpr(mark)); + sum_expr.CopyOnWrite()->base = vrng->min; + var_map_[kv.first] = sum_expr; + input_marks_.push_back(mark); + } + } + } + + size_t unresolved_count() const { return unresolved_count_; } + + IterSumExpr Rewrite(PrimExpr expr) { + return NormalizeToIterWithOffset(ToIterSumExpr(DirectMutate(expr))); + } + + bool CheckBijective(const Array<IterSumExpr>& indices) { + IterMarkSplitCollector collector; + // We can check that 
for each iter mark: + // All the splits that refer to the itermark cover its extent. + // The splits do not overlap with each other. + collector.Collect(indices); + for (IterMark mark : collector.visited_) { + if (TryNormalizeSplits(mark, collector.mark2splits_[mark]).size() == 0) return false; + } + // all input marks must be visited + for (auto mark : input_marks_) { + if (collector.visited_.count(mark) == 0) return false; + } + return true; + } + + // Override of the generic mutate entry point: a result that is still an IterMapExpr is counted as unresolved. + PrimExpr VisitExpr(const PrimExpr& input_expr) final { + auto expr = ExprMutator::VisitExpr(input_expr); + if (expr->IsInstance<IterMapExprNode>()) { + ++unresolved_count_; + } + return expr; + } + + // Plain mutation through ExprMutator::VisitExpr, without the unresolved-count bookkeeping or normalization. + PrimExpr DirectMutate(PrimExpr expr) { return ExprMutator::VisitExpr(expr); } + + PrimExpr VisitExpr_(const VarNode* op) final; + PrimExpr VisitExpr_(const AddNode* op) final; + PrimExpr VisitExpr_(const SubNode* op) final; + PrimExpr VisitExpr_(const MulNode* op) final; + PrimExpr VisitExpr_(const FloorDivNode* op) final; + PrimExpr VisitExpr_(const FloorModNode* op) final; + + private: + // Temporary hash used to de-duplicate IterSumExpr keys. + struct IterSumHash { + size_t operator()(const IterSumExpr& value) const { + // for now only hash on source index. 
+ size_t hash = value->args.size(); + for (size_t i = 0; i < value->args.size(); ++i) { + hash = support::HashCombine(hash, std::hash<const Object*>()(value->args[i]->source.get())); + } + return hash; + } + }; + + struct IterSumEqual { + bool operator()(const IterSumExpr& lhs, const IterSumExpr& rhs) const { + tir::ExprDeepEqual equal; + if (lhs->args.size() != rhs->args.size()) return false; + if (!equal(lhs->base, rhs->base)) return false; + for (size_t i = 0; i < lhs->args.size(); ++i) { + auto lvalue = lhs->args[i]; + auto rvalue = lhs->args[i]; + if (!lvalue->source.same_as(rvalue->source)) return false; + if (!equal(lvalue->lower_factor, rvalue->lower_factor)) return false; + if (!equal(lvalue->scale, rvalue->scale)) return false; + if (!equal(lvalue->extent, rvalue->extent)) return false; + } + return true; + } + }; + + // Internal analyzer + Analyzer* analyzer_; + // Counter to keep track of unresolved cases. + int unresolved_count_{0}; + // The var map + std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> var_map_; + // input iter marks + std::vector<IterMark> input_marks_; + // The canonical map for sum + std::unordered_map<IterSumExpr, IterSplitExpr, IterSumHash, IterSumEqual> sum_fuse_map_; + + /*! + * \brief Verify that splits fully covers mark in a non-overlapping fashion. + * If verification passes, return splits from outermost to inner most order. + * If not, return an empty array + * \param mark The iterator of interest. + * \param splits The splits to be verified. + * \return The normalized splits. 
+ */ + Array<IterSplitExpr> TryNormalizeSplits(const IterMark& mark, + const std::vector<IterSplitExpr>& splits) { + std::vector<bool> used(splits.size(), false); + std::vector<IterSplitExpr> iters; + PrimExpr expected_lower_factor = make_const(mark->source->dtype, 1); + + for (size_t i = 0; i < splits.size(); ++i) { + size_t j = 0; + /* Find an unused split whose lower_factor equals the product of the extents selected so far. */ + for (; j < splits.size(); ++j) { + if (used[j]) continue; + if (!used[j] && CanProveEqual(splits[j]->lower_factor, expected_lower_factor)) break; + } + if (j == splits.size()) { + return Array<IterSplitExpr>(); + } + used[j] = true; + iters.push_back(splits[j]); + expected_lower_factor *= splits[j]->extent; + } + /* The accumulated product of extents must equal the extent of the mark; iters was filled innermost-first, so it is returned reversed. */ + if (!CanProveEqual(expected_lower_factor, mark->extent)) return Array<IterSplitExpr>(); + return Array<IterSplitExpr>(iters.rbegin(), iters.rend()); + } + + /*! + * \brief Normalize expr to an iterator + offset. + * \param expr The input expression. + * \return The Normalized expression. + */ + IterSumExpr NormalizeToIterWithOffset(IterSumExpr expr) { + if (expr->args.size() <= 1) return expr; + /* TryFuseIters rejects a non-zero base, so the base is zeroed for the call and restored afterwards. */ + PrimExpr base = expr->base; + expr.CopyOnWrite()->base = make_zero(expr->dtype); + auto opt = TryFuseIters(expr); + expr.CopyOnWrite()->base = base; + if (opt) { + expr.CopyOnWrite()->args = Array<IterSplitExpr>({opt.value()}); + return expr; + } else { + ++unresolved_count_; + return expr; + } + } + + /*! + * \brief Check whether lhs == rhs can be established: two IntImm constants are + * compared by value, anything else is delegated to the analyzer. + */ + bool CanProveEqual(PrimExpr lhs, PrimExpr rhs) { + const auto* clhs = lhs.as<IntImmNode>(); + const auto* crhs = rhs.as<IntImmNode>(); + if (clhs && crhs) return clhs->value == crhs->value; + return analyzer_->CanProve(lhs - rhs == 0); + } + + /*! + * \brief Create an IterSumExpr from expr. + * \param expr The input expr. + * \return The transformed IterSumExpr. 
+ */ + IterSumExpr ToIterSumExpr(PrimExpr expr) { + if (const auto* op = expr.as<IterSumExprNode>()) { + return GetRef<IterSumExpr>(op); + } else if (const auto* op = expr.as<IterSplitExprNode>()) { + return IterSumExpr({GetRef<IterSplitExpr>(op)}, make_zero(expr->dtype)); + } else { + CHECK(!expr->IsInstance<IterMapExprNode>()); + return IterSumExpr({}, expr); + } + } + + // Try to normalize IterSum into a fused IterMark + // return a corresponding splitexpr if needed. + Optional<IterSplitExpr> TryFuseIters(IterSumExpr expr) { + if (!is_zero(expr->base)) return NullOpt; + if (expr->args.size() == 1) return expr->args[0]; + // select the iterators in order + std::vector<bool> visited(expr->args.size(), false); + std::vector<IterSplitExpr> iters; + iters.reserve(expr->args.size()); + // canonicalize the expression + // check if it can be remapped into a fused pattern. + PrimExpr expected_scale = make_const(expr->base->dtype, 1); + for (size_t i = 0; i < expr->args.size(); ++i) { + size_t j = 0; + for (; j < expr->args.size(); ++j) { + if (!visited[j] && CanProveEqual(expr->args[j]->scale, expected_scale)) break; + } + if (j == expr->args.size()) { + return NullOpt; + } + visited[j] = true; + iters.push_back(expr->args[j]); + expected_scale *= expr->args[j]->extent; + } + // update the iterator to use the canonicalized form + expr.CopyOnWrite()->args = Array<IterSplitExpr>(iters.rbegin(), iters.rend()); + auto it = sum_fuse_map_.find(expr); + if (it != sum_fuse_map_.end()) return it->second; + auto mark = IterMark(expr, expected_scale); + IterSplitExpr split(mark); + sum_fuse_map_[expr] = split; + return split; + } + + bool CanProveDivisible(PrimExpr lhs, PrimExpr rhs) { + const auto* clhs = lhs.as<IntImmNode>(); + const auto* crhs = rhs.as<IntImmNode>(); + if (clhs && crhs) return clhs->value % crhs->value == 0; + return analyzer_->CanProve(floormod(lhs, rhs) == 0); + } + + PrimExpr SplitFloorDivConst(IterSplitExpr lhs, PrimExpr rhs); + PrimExpr 
SplitFloorModConst(IterSplitExpr lhs, PrimExpr rhs); + + /*! + * \brief Add (sign > 0) or subtract (otherwise) the split rhs to/from lhs in place. + * If lhs already holds a split with the same source and deep-equal lower_factor and + * extent, the two scales are combined into that slot; otherwise rhs is appended, + * with its scale negated when subtracting. + */ + static void AddToLhs(IterSumExprNode* lhs, IterSplitExpr rhs, int sign) { + tir::ExprDeepEqual equal; + for (size_t i = 0; i < lhs->args.size(); ++i) { + IterSplitExpr lvalue = lhs->args[i]; + if (lvalue->source.same_as(rhs->source) && equal(lvalue->lower_factor, rhs->lower_factor) && + equal(lvalue->extent, rhs->extent)) { + if (sign > 0) { + rhs.CopyOnWrite()->scale = lvalue->scale + rhs->scale; + } else { + rhs.CopyOnWrite()->scale = lvalue->scale - rhs->scale; + } + lhs->args.Set(i, rhs); + return; + } + } + if (sign > 0) { + lhs->args.push_back(rhs); + } else { + rhs.CopyOnWrite()->scale = make_zero(rhs->scale.dtype()) - rhs->scale; + lhs->args.push_back(rhs); + } + } + + /*! + * \brief Add (sign > 0) or subtract (otherwise) every split of rhs, and its base, to/from lhs in place. + */ + static void AddToLhs(IterSumExprNode* lhs, IterSumExpr rhs, int sign) { + for (size_t i = 0; i < rhs->args.size(); ++i) { + AddToLhs(lhs, rhs->args[i], sign); + } + if (sign > 0) { + lhs->base += rhs->base; + } else { + lhs->base -= rhs->base; + } + } + + /*! + * \brief Multiply the scale of every split in lhs, and the base of lhs, by rhs in place. + */ + static void MulToLhs(IterSumExprNode* lhs, PrimExpr rhs) { + for (size_t i = 0; i < lhs->args.size(); ++i) { + IterSplitExpr lvalue = lhs->args[i]; + lvalue.CopyOnWrite()->scale *= rhs; + lhs->args.Set(i, lvalue); + } + lhs->base *= rhs; + } +}; + +Array<IterSumExpr> DetectIterMap(const Array<PrimExpr>& indices, const Map<Var, Range>& input_iters, + arith::Analyzer* analyzer) { + // Overall detection algorithm is divided into two steps: + // - Step0: IterMapRewriter rewrites the expression to use IterMapExpr patterns. + // - Step1: IterIndependenceChecker checks if the iterators are independent. 
+ IterMapRewriter rewriter(analyzer, input_iters); + Array<IterSumExpr> results; + + for (PrimExpr value : indices) { + results.push_back(rewriter.Rewrite(value)); + if (rewriter.unresolved_count() != 0) return Array<IterSumExpr>(); + } + if (!rewriter.CheckBijective(results)) return Array<IterSumExpr>(); Review comment: It also checks all the intermediate marks(including the input) are being covered without overlapping ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org