csullivan commented on code in PR #12863: URL: https://github.com/apache/tvm/pull/12863#discussion_r988193162
########## include/tvm/arith/analyzer.h: ########## @@ -275,6 +275,36 @@ class RewriteSimplifier { */ std::function<void()> EnterConstraint(const PrimExpr& constraint); + /*! \brief Flags to enable more computationally-intensive simplifications + * + * These simplifications may be required for specific schedules, but + * would impose too high a compile-time cost to enable by default. + * They can be enabled on an as-needed basis by calling + * `RewriteSimplifier::SetEnabledFeatures` prior to using + * `RewriteSimplifier::operator()`. + */ + enum Feature { + // No features enabled + kNone = 0, + + /* When simplifying an inequality, attempt to use scope-based knowns. + * + * Example: + * if_then_else(i<j && j<k, i<k, false) => if_then_else(i<j && j<k, true, false) + */ + kTransitivelyProveInequalities = (1 << 0), Review Comment: I assume these are powers of two so that features can be combined as bit flags. Do we expect additional entries in the future? The bitwise shift implicitly suggests so. If that's the case, a short comment explaining this would help demystify it. ########## tests/python/unittest/test_tir_transform_simplify.py: ########## @@ -138,6 +138,20 @@ def sls(n, d): class BaseBeforeAfter(tvm.testing.CompareBeforeAfter): Review Comment: I'll assume the runtime isn't significantly affected by enabling transitively_prove_inequalities for the existing tests, in addition to the new ones you're adding here. If this wasn't intentional, feel free to add a new base class. ########## tests/python/unittest/test_tir_transform_simplify.py: ########## @@ -547,5 +561,129 @@ def before(A: T.Buffer[16, "float32"]): expected = before +class TestRemoveTransitivelyProvableCondition(BaseBeforeAfter): + """Remove comparisons that may be proven using multiple others + + For example, the `0 < i` and `i <= j` conditions can be used to prove + that `0 < j`.
+ """ + + i, j, k = [tvm.tir.Var(name, "int32") for name in "ijk"] + zero = tvm.tir.IntImm("int32", 0) + + test_case = tvm.testing.parameter( + (tvm.tir.all(zero < i, i <= j), zero < j, True), + # Transitive comparisons from LT + (tvm.tir.all(i < j, j < k), i < k, True), + (tvm.tir.all(i < j, j == k), i < k, True), + (tvm.tir.all(i < j, j <= k), i < k, True), + (tvm.tir.all(i < j, j > k), i < k, False), + (tvm.tir.all(i < j, j >= k), i < k, False), + (tvm.tir.all(i < j, j != k), i < k, False), + # Transitive comparisons from LE + (tvm.tir.all(i <= j, j < k), i < k, True), + (tvm.tir.all(i <= j, j == k), i == k, False), + (tvm.tir.all(i <= j, j == k), i <= k, True), + (tvm.tir.all(i <= j, j <= k), i <= k, True), + (tvm.tir.all(i <= j, j <= k), i < k, False), + (tvm.tir.all(i <= j, j > k), i < k, False), + (tvm.tir.all(i <= j, j >= k), i < k, False), + (tvm.tir.all(i <= j, j != k), i < k, False), + # Transitive comparisons from GT + (tvm.tir.all(i > j, j > k), i > k, True), + (tvm.tir.all(i > j, j == k), i > k, True), + (tvm.tir.all(i > j, j >= k), i > k, True), + (tvm.tir.all(i > j, j < k), i > k, False), + (tvm.tir.all(i > j, j <= k), i > k, False), + (tvm.tir.all(i > j, j != k), i > k, False), + # Transitive comparisons from GE + (tvm.tir.all(i >= j, j > k), i > k, True), + (tvm.tir.all(i >= j, j == k), i == k, False), + (tvm.tir.all(i >= j, j == k), i >= k, True), + (tvm.tir.all(i >= j, j >= k), i >= k, True), + (tvm.tir.all(i >= j, j >= k), i > k, False), + (tvm.tir.all(i >= j, j < k), i > k, False), + (tvm.tir.all(i >= j, j <= k), i > k, False), + (tvm.tir.all(i >= j, j != k), i > k, False), + # GT or LT may be used to prove NE + (tvm.tir.all(i == j, j != k), i != k, True), + (tvm.tir.all(i == j, j < k), i != k, True), + (tvm.tir.all(i == j, j > k), i != k, True), + (tvm.tir.all(i == j, j != k), i < k, False), + (tvm.tir.all(i == j, j != k), i > k, False), + # Because these are integers, x<y is equivalent to x <= y-1, + # and may be used in equivalent 
simplifications. + (tvm.tir.all(i < j, j < k), i < k, True), Review Comment: Duplicates [L577](https://github.com/apache/tvm/pull/12863/files#diff-d7436dae3a0ec5555c249400c293d8d035753562f13fddadc0ebadf4f4c0d997R577). I think you want `(tvm.tir.all(i <= j-1, j < k), i < k, True),`. ########## include/tvm/arith/analyzer.h: ########## @@ -317,6 +347,82 @@ class CanonicalSimplifier { Impl* impl_; }; +/*! \brief Structure for representing result of known + * + * Values are assigned to allow these flags to be used in bitwise + * operations. + */ +enum class CompareResult : int { + kInconsistent = 0, + kEQ = 1, + kLT = 2, + kLE = 3, + kGT = 4, + kGE = 5, + kNE = 6, + kUnknown = 7 +}; + +inline constexpr CompareResult operator&(CompareResult lhs, CompareResult rhs) { + return CompareResult(static_cast<int>(lhs) & static_cast<int>(rhs)); +} +inline constexpr CompareResult operator|(CompareResult lhs, CompareResult rhs) { + return CompareResult(static_cast<int>(lhs) | static_cast<int>(rhs)); +} + +/*! + * \brief Using previously specified knowns, compare the expressions provided + * + * Given known expressions [(a OP b), (b OP c), ..., (y OP z)], search + * for a known result for `(a OP z)`. + */ +class TransitiveComparisonAnalyzer { + public: + /* \brief Using previously specified knowns, compare the expressions provided + * + * \param lhs The left-hand side of the comparison + * + * \param rhs The right-hand side of the comparison + * + * \return The most specific result that can be proven about the + * comparison. If nothing can be proven, returns kUnknown. + */ + CompareResult TryCompare(const PrimExpr& lhs, const PrimExpr& rhs); + + /*! \brief Bind a variable as being equal to a known expression + * + * \param var The variable of interest. + * \param expr The bound expression + * \param allow_override Whether to allow override of existing information. 
+ */ + void Bind(const Var& var, const PrimExpr& expr, bool allow_override = false); Review Comment: Please annotate these declarations with `TVM_DLL`, here and elsewhere. ########## src/arith/transitive_comparison_analyzer.cc: ########## @@ -0,0 +1,683 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file tvm/arith/transitive_comparison_analyzer.cc + */ + +#include <tvm/arith/analyzer.h> +#include <tvm/tir/analysis.h> +#include <tvm/tir/expr.h> + +#include <optional> +#include <vector> + +#include "constraint_extract.h" +#include "pattern_match.h" + +namespace tvm { +namespace arith { + +using namespace tir; + +class TransitiveComparisonAnalyzer::Impl { + public: + /* \brief Using previously specified knowns, compare the expressions provided + * + * \param lhs The left-hand side of the comparison + * + * \param rhs The right-hand side of the comparison + * + * \return The most specific result that can be proven about the + * comparison. If nothing can be proven, returns kUnknown. + */ + CompareResult TryCompare(const PrimExpr& lhs, const PrimExpr& rhs) const; + + /*! \brief Bind a variable as being equal to a known expression + * + * \param var The variable of interest.
+ * \param expr The bound expression + * \param allow_override Whether to allow override of existing information. + */ + void Bind(const tir::Var& var, const PrimExpr& expr, bool allow_override = false); + + /*! \brief Bind a variable as being within a specified range + * + * \param var The variable of interest. + * \param range The known range + * \param allow_override Whether to allow override of existing information. + */ + void Bind(const tir::Var& var, const Range& expr, bool allow_override = false); + + /*! + * \brief Update the internal state to enter constraint. + * \param constraint A constraint expression. + * + * \return An exit function that must be called to cleanup. May be + * `nullptr`, if no cleanup is required. + */ + std::function<void()> EnterConstraint(const PrimExpr& expr); + + private: + // Utility class to avoid needing to repeatedly call ExprDeepEqual + enum class Key : size_t {}; Review Comment: This functions as an ID for an expression, such that once a relationship between two keys (e.g. equality) has been established, it can be looked up later without re-evaluating expression equality. Do I read your intent correctly? ########## src/arith/transitive_comparison_analyzer.cc: ########## @@ -0,0 +1,683 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied.
See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file tvm/arith/transitive_comparison_analyzer.cc + */ + +#include <tvm/arith/analyzer.h> +#include <tvm/tir/analysis.h> +#include <tvm/tir/expr.h> + +#include <optional> +#include <vector> + +#include "constraint_extract.h" +#include "pattern_match.h" + +namespace tvm { +namespace arith { + +using namespace tir; + +class TransitiveComparisonAnalyzer::Impl { + public: + /* \brief Using previously specified knowns, compare the expressions provided + * + * \param lhs The left-hand side of the comparison + * + * \param rhs The right-hand side of the comparison + * + * \return The most specific result that can be proven about the + * comparison. If nothing can be proven, returns kUnknown. + */ + CompareResult TryCompare(const PrimExpr& lhs, const PrimExpr& rhs) const; + + /*! \brief Bind a variable as being equal to a known expression + * + * \param var The variable of interest. + * \param expr The bound expression + * \param allow_override Whether to allow override of existing information. + */ + void Bind(const tir::Var& var, const PrimExpr& expr, bool allow_override = false); + + /*! \brief Bind a variable as being within a specified range + * + * \param var The variable of interest. + * \param range The known range + * \param allow_override Whether to allow override of existing information. + */ + void Bind(const tir::Var& var, const Range& expr, bool allow_override = false); + + /*! + * \brief Update the internal state to enter constraint. + * \param constraint A constraint expression. + * + * \return An exit function that must be called to cleanup. May be + * `nullptr`, if no cleanup is required. 
+ */ + std::function<void()> EnterConstraint(const PrimExpr& expr); + + private: + // Utility class to avoid needing to repeatedly call ExprDeepEqual + enum class Key : size_t {}; + std::optional<Key> ExprToPreviousKey(const PrimExpr& expr) const; + Key ExprToKey(const PrimExpr& expr); + std::unordered_map<PrimExpr, Key, StructuralHash, StructuralEqual> expr_to_key; + + /*! \brief Internal representation of a comparison operator */ + struct Comparison { + /*! \brief Construct a comparison that represents `lhs OP rhs + + * offset`, where the operation is specified by the CompareResult. + */ + Comparison(Key lhs, Key rhs, int64_t offset, CompareResult result); + + /*! \brief Utility function to validate that all GT and LT results + * have been normalized out + */ + bool IsNormalized() const; + + /*! \brief Move the specified expression to the LHS. + * + * \param new_lhs The argument that should be moved to the LHS of the + * comparison. + * + * \return If possible, returns a comparison that is equivalent to + * the current comparison, but with the specified LHS. If not + * possible, returns nullopt. + */ + std::optional<Comparison> WithLHS(Key new_lhs) const; + + /*! \brief Create the negation of the current comparison */ + Comparison Negated() const; + + /*! \brief Check the this comparison implies + * + * Returns true if this comparison being true implies that the + * other comparison must also be true. Returns false if the other + * comparison cannot be shown to be true. + */ + bool Implies(const Comparison& other) const; + + // The LHS of the comparison + Key lhs_; + + // The RHS of the comparison, not including any constant offset. + Key rhs_; + + // Additive offset on rhs + int64_t offset_{0}; + + // The comparison operator. + CompareResult result_{CompareResult::kInconsistent}; + }; + + /*! \brief Generate a Comparison representing the given expression */ + std::optional<Comparison> FromExpr(const PrimExpr& expr); + + /*! 
\brief Utility function used by Bind and EnterConstraint + * + * \param expr The comparison expression, to be converted into + * internal Comparison objects. + * + * \param vec The vector to which the Comparison objects should be + * appended. + */ + void AddKnown(const PrimExpr& expr, std::vector<Comparison>* vec); + + /*! \brief Attempt to compare, starting at the lhs. + * + * Taking each available `Comparison` as a node edge, search for a + * path from lhs to rhs. For example, the priors (a<=b), (b<=c+1) + * and (c<=d-5) can be used to prove that (a<=d-4). + * + * \param lhs The left-hand side of the comparison + * + * \param rhs The right-hand side of the comparison + * + * \return The result of the comparison + */ + CompareResult TryCompareFromLHS(Key lhs_key, Key rhs_key, int64_t offset, const PrimExpr& lhs, + const PrimExpr& rhs) const; + + /*! \brief Previous Range bindings + * + * Tracked separatedly to handle the `allow_override` option used by + * all sub-analyzers when binding variables. + */ + Map<Var, Range> prev_bindings_; + + /*! \brief Known comparisons based on definitionally-true statements + * + * For example, a Let binding, or the range of an iterator. + */ + std::vector<Comparison> knowns_; + + /*! \brief Known comparisons based on of scope-based statements + * + * For example, the condition of an IfThenElse, which is known to be + * true while within the if scope. + */ + std::vector<Comparison> scoped_knowns_; +}; + +namespace { + +// Internal utility, return the CompareResult resulting from swapping +// the left-hand side with the right-hand side. 
+CompareResult Reverse(CompareResult res) { + switch (res) { + case CompareResult::kInconsistent: + return CompareResult::kInconsistent; + case CompareResult::kEQ: + return CompareResult::kEQ; + case CompareResult::kLT: + return CompareResult::kGT; + case CompareResult::kLE: + return CompareResult::kGE; + case CompareResult::kGT: + return CompareResult::kLT; + case CompareResult::kGE: + return CompareResult::kLE; + case CompareResult::kNE: + return CompareResult::kNE; + case CompareResult::kUnknown: + return CompareResult::kUnknown; + default: + LOG(FATAL) << "Invalid CompareResult: " << static_cast<int>(res); + return CompareResult::kInconsistent; + } +} + +// Internal utility, return the CompareResult resulting from negating +// the comparison. +CompareResult Negate(CompareResult res) { + switch (res) { + case CompareResult::kInconsistent: + return CompareResult::kInconsistent; + case CompareResult::kUnknown: + return CompareResult::kUnknown; + default: + return CompareResult(~static_cast<int>(res) & static_cast<int>(CompareResult::kUnknown)); + } +} + +// Internal utility, extract constant offsets out of the two sides of +// a comparison. Given lhs and rhs, return a tuple of three elements +// (lhs_inner, rhs_inner, offset), such that (lhs OP rhs) and +// (lhs_inner OP rhs_inner + offset) are equivalent. 
+std::tuple<PrimExpr, PrimExpr, int64_t> ExtractOffsets(const PrimExpr& lhs, const PrimExpr& rhs) { + auto extract_offset = [](const PrimExpr& expr) -> std::pair<PrimExpr, int64_t> { + PVar<PrimExpr> x; + PVar<IntImm> c; + if ((x + c).Match(expr)) { + return {x.Eval(), c.Eval()->value}; + } else if ((x - c).Match(expr)) { + return {x.Eval(), -c.Eval()->value}; + } else if (c.Match(expr)) { + return {0, c.Eval()->value}; + } else { + return {expr, 0}; + } + }; + + auto lhs_split = extract_offset(lhs); + auto rhs_split = extract_offset(rhs); + return {lhs_split.first, rhs_split.first, rhs_split.second - lhs_split.second}; +} + +} // namespace + +std::optional<TransitiveComparisonAnalyzer::Impl::Comparison> +TransitiveComparisonAnalyzer::Impl::FromExpr(const PrimExpr& expr) { + CompareResult res; + PVar<PrimExpr> x, y; + if ((x <= y).Match(expr)) { + res = CompareResult::kLE; + } else if ((x >= y).Match(expr)) { + res = CompareResult::kGE; + } else if ((x < y).Match(expr)) { + res = CompareResult::kLT; + } else if ((x > y).Match(expr)) { + res = CompareResult::kGT; + } else if ((x == y).Match(expr)) { + res = CompareResult::kEQ; + } else if ((x != y).Match(expr)) { + res = CompareResult::kNE; + } else { + return std::nullopt; + } + + PrimExpr lhs_expr = x.Eval(); + PrimExpr rhs_expr = y.Eval(); + + if (lhs_expr.as<IntImmNode>() && rhs_expr.as<IntImmNode>()) { + return std::nullopt; + } + + auto [lhs, rhs, offset] = ExtractOffsets(lhs_expr, rhs_expr); + Key lhs_key = ExprToKey(lhs); + Key rhs_key = ExprToKey(rhs); + + return Comparison(lhs_key, rhs_key, offset, res); +} + +TransitiveComparisonAnalyzer::Impl::Comparison::Comparison(Key lhs, Key rhs, int64_t offset, + CompareResult result) + : lhs_(lhs), rhs_(rhs), offset_(offset), result_(result) { + if (result_ == CompareResult::kLT) { + result_ = CompareResult::kLE; + offset_ -= 1; + } + if (result_ == CompareResult::kGT) { + result_ = CompareResult::kGE; + offset_ += 1; + } +} + 
+std::optional<TransitiveComparisonAnalyzer::Impl::Key> +TransitiveComparisonAnalyzer::Impl::ExprToPreviousKey(const PrimExpr& expr) const { + auto it = expr_to_key.find(expr); + if (it != expr_to_key.end()) { + return it->second; + } else { + return std::nullopt; + } +} + +TransitiveComparisonAnalyzer::Impl::Key TransitiveComparisonAnalyzer::Impl::ExprToKey( + const PrimExpr& expr) { + if (auto prev = ExprToPreviousKey(expr)) { + return prev.value(); + } else { + Key new_key = Key(expr_to_key.size()); + expr_to_key[expr] = new_key; + return new_key; + } +} + +bool TransitiveComparisonAnalyzer::Impl::Comparison::IsNormalized() const { + // These < and > should be removed during normalization. + return result_ != CompareResult::kLT && result_ != CompareResult::kGT; +} + +std::optional<TransitiveComparisonAnalyzer::Impl::Comparison> +TransitiveComparisonAnalyzer::Impl::Comparison::WithLHS(Key new_lhs) const { + if (new_lhs == lhs_) { + return *this; + } else if (new_lhs == rhs_) { + return Comparison(rhs_, lhs_, -offset_, Reverse(result_)); + } else { + return std::nullopt; + } +} + +TransitiveComparisonAnalyzer::Impl::Comparison +TransitiveComparisonAnalyzer::Impl::Comparison::Negated() const { + return Comparison(lhs_, rhs_, offset_, Negate(result_)); +} + +bool TransitiveComparisonAnalyzer::Impl::Comparison::Implies( + const TransitiveComparisonAnalyzer::Impl::Comparison& other) const { + ICHECK(lhs_ == other.lhs_); + ICHECK(rhs_ == other.rhs_); + ICHECK(IsNormalized()); + ICHECK(other.IsNormalized()); + + if (result_ == other.result_ && offset_ == other.offset_) { + // if c1 == c2, x != y + c1 => x != y + c2 + // if c1 == c2, x == y + c1 => x == y + c2 + return true; + } + + if (other.result_ == CompareResult::kLE && offset_ <= other.offset_) { + if (result_ == CompareResult::kEQ || result_ == CompareResult::kLE) { + // if c1 <= c2, x <= y + c1 => x <= y + c2 + // if c1 <= c2, x == y + c1 => x <= y + c2 + return true; + } + } + + if (other.result_ == 
CompareResult::kGE && offset_ >= other.offset_) { + if (result_ == CompareResult::kEQ || result_ == CompareResult::kGE) { + // if c1 >= c2, x == y + c1 => x >= y + c2 + // if c1 >= c2, x >= y + c1 => x >= y + c2 + return true; + } + } + + if (other.result_ == CompareResult::kNE) { + if (result_ == CompareResult::kEQ && offset_ != other.offset_) { + // if c1 != c2, x == y + c1 => x != y + c2 + return true; + } + + if (result_ == CompareResult::kLE && offset_ < other.offset_) { + // if c1 < c2, x <= y + c1 => x < y + c2 => x != y + c2 + return true; + } + + if (result_ == CompareResult::kGE && offset_ > other.offset_) { + // if c1 != c2, x >= y + c1 => x > y + c2 => x != y + c2 + return true; + } + } + + return false; +} + +TransitiveComparisonAnalyzer::TransitiveComparisonAnalyzer() : impl_(std::make_unique<Impl>()) {} +TransitiveComparisonAnalyzer::~TransitiveComparisonAnalyzer() {} + +CompareResult TransitiveComparisonAnalyzer::TryCompare(const PrimExpr& lhs, const PrimExpr& rhs) { + return impl_->TryCompare(lhs, rhs); +} + +void TransitiveComparisonAnalyzer::Bind(const Var& var, const PrimExpr& expr, bool allow_override) { + impl_->Bind(var, expr, allow_override); +} +void TransitiveComparisonAnalyzer::Bind(const Var& var, const Range& range, bool allow_override) { + impl_->Bind(var, range, allow_override); +} + +std::function<void()> TransitiveComparisonAnalyzer::EnterConstraint(const PrimExpr& constraint) { + return impl_->EnterConstraint(constraint); +} + +void TransitiveComparisonAnalyzer::Impl::AddKnown(const PrimExpr& expr, + std::vector<Comparison>* vec) { + for (const auto& subexpr : ExtractConstraints(expr)) { + if (tir::SideEffect(expr) <= tir::CallEffectKind::kPure) { + if (auto cmp = FromExpr(subexpr)) { + vec->push_back(cmp.value()); + } + } + } +} + +void TransitiveComparisonAnalyzer::Impl::Bind(const tir::Var& var, const Range& range, + bool allow_override) { + auto it = prev_bindings_.find(var); + if (it != prev_bindings_.end()) { + ExprDeepEqual 
expr_equal; + bool differs_from_previous = !expr_equal(range->min, (*it).second->min) || + !expr_equal(range->extent, (*it).second->extent); + if (differs_from_previous) { + ICHECK(allow_override) << "Binding of variable " << var << " as " << range + << " conflicts with previous binding as " << (*it).second; + if (auto key = ExprToPreviousKey(var)) { + knowns_.erase(std::remove_if(knowns_.begin(), knowns_.end(), + [&](const auto& known) { return known.lhs_ == key.value(); }), + knowns_.end()); + } + } + } + + prev_bindings_.Set(var, range); + + if (is_const_int(range->extent, 1)) { + AddKnown(var == range->min, &knowns_); + } else { + AddKnown(var >= range->min, &knowns_); + AddKnown(var < range->min + range->extent, &knowns_); + } +} + +void TransitiveComparisonAnalyzer::Impl::Bind(const tir::Var& var, const PrimExpr& expr, + bool allow_override) { + Bind(var, Range::FromMinExtent(expr, 1), allow_override); +} + +std::function<void()> TransitiveComparisonAnalyzer::Impl::EnterConstraint(const PrimExpr& expr) { + size_t old_literal_size = scoped_knowns_.size(); + AddKnown(expr, &scoped_knowns_); + size_t new_literal_size = scoped_knowns_.size(); + + PrimExpr temp = expr; + auto frecover = [old_literal_size, new_literal_size, this, temp]() { Review Comment: nit: Can define `temp` here in the lambda capture list, `[..., temp = expr]() {...};` since it isn't used elsewhere. ########## src/arith/transitive_comparison_analyzer.cc: ########## @@ -0,0 +1,683 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file tvm/arith/transitive_comparison_analyzer.cc + */ + +#include <tvm/arith/analyzer.h> +#include <tvm/tir/analysis.h> +#include <tvm/tir/expr.h> + +#include <optional> +#include <vector> + +#include "constraint_extract.h" +#include "pattern_match.h" + +namespace tvm { +namespace arith { + +using namespace tir; + +class TransitiveComparisonAnalyzer::Impl { + public: + /* \brief Using previously specified knowns, compare the expressions provided + * + * \param lhs The left-hand side of the comparison + * + * \param rhs The right-hand side of the comparison + * + * \return The most specific result that can be proven about the + * comparison. If nothing can be proven, returns kUnknown. + */ + CompareResult TryCompare(const PrimExpr& lhs, const PrimExpr& rhs) const; + + /*! \brief Bind a variable as being equal to a known expression + * + * \param var The variable of interest. + * \param expr The bound expression + * \param allow_override Whether to allow override of existing information. + */ + void Bind(const tir::Var& var, const PrimExpr& expr, bool allow_override = false); + + /*! \brief Bind a variable as being within a specified range + * + * \param var The variable of interest. + * \param range The known range + * \param allow_override Whether to allow override of existing information. + */ + void Bind(const tir::Var& var, const Range& expr, bool allow_override = false); + + /*! + * \brief Update the internal state to enter constraint. + * \param constraint A constraint expression. 
+ * + * \return An exit function that must be called to cleanup. May be + * `nullptr`, if no cleanup is required. + */ + std::function<void()> EnterConstraint(const PrimExpr& expr); + + private: + // Utility class to avoid needing to repeatedly call ExprDeepEqual + enum class Key : size_t {}; + std::optional<Key> ExprToPreviousKey(const PrimExpr& expr) const; + Key ExprToKey(const PrimExpr& expr); + std::unordered_map<PrimExpr, Key, StructuralHash, StructuralEqual> expr_to_key; + + /*! \brief Internal representation of a comparison operator */ + struct Comparison { + /*! \brief Construct a comparison that represents `lhs OP rhs + + * offset`, where the operation is specified by the CompareResult. + */ + Comparison(Key lhs, Key rhs, int64_t offset, CompareResult result); + + /*! \brief Utility function to validate that all GT and LT results + * have been normalized out + */ + bool IsNormalized() const; + + /*! \brief Move the specified expression to the LHS. + * + * \param new_lhs The argument that should be moved to the LHS of the + * comparison. + * + * \return If possible, returns a comparison that is equivalent to + * the current comparison, but with the specified LHS. If not + * possible, returns nullopt. + */ + std::optional<Comparison> WithLHS(Key new_lhs) const; + + /*! \brief Create the negation of the current comparison */ + Comparison Negated() const; + + /*! \brief Check the this comparison implies + * + * Returns true if this comparison being true implies that the + * other comparison must also be true. Returns false if the other + * comparison cannot be shown to be true. + */ + bool Implies(const Comparison& other) const; + + // The LHS of the comparison + Key lhs_; + + // The RHS of the comparison, not including any constant offset. + Key rhs_; + + // Additive offset on rhs + int64_t offset_{0}; + + // The comparison operator. + CompareResult result_{CompareResult::kInconsistent}; + }; + + /*! 
\brief Generate a Comparison representing the given expression */ + std::optional<Comparison> FromExpr(const PrimExpr& expr); + + /*! \brief Utility function used by Bind and EnterConstraint + * + * \param expr The comparison expression, to be converted into + * internal Comparison objects. + * + * \param vec The vector to which the Comparison objects should be + * appended. + */ + void AddKnown(const PrimExpr& expr, std::vector<Comparison>* vec); + + /*! \brief Attempt to compare, starting at the lhs. + * + * Taking each available `Comparison` as a node edge, search for a + * path from lhs to rhs. For example, the priors (a<=b), (b<=c+1) + * and (c<=d-5) can be used to prove that (a<=d-4). + * + * \param lhs The left-hand side of the comparison + * + * \param rhs The right-hand side of the comparison + * + * \return The result of the comparison + */ + CompareResult TryCompareFromLHS(Key lhs_key, Key rhs_key, int64_t offset, const PrimExpr& lhs, + const PrimExpr& rhs) const; + + /*! \brief Previous Range bindings + * + * Tracked separatedly to handle the `allow_override` option used by + * all sub-analyzers when binding variables. + */ + Map<Var, Range> prev_bindings_; + + /*! \brief Known comparisons based on definitionally-true statements + * + * For example, a Let binding, or the range of an iterator. + */ + std::vector<Comparison> knowns_; + + /*! \brief Known comparisons based on of scope-based statements + * + * For example, the condition of an IfThenElse, which is known to be + * true while within the if scope. + */ + std::vector<Comparison> scoped_knowns_; +}; + +namespace { + +// Internal utility, return the CompareResult resulting from swapping +// the left-hand side with the right-hand side. 
+CompareResult Reverse(CompareResult res) { + switch (res) { + case CompareResult::kInconsistent: + return CompareResult::kInconsistent; + case CompareResult::kEQ: + return CompareResult::kEQ; + case CompareResult::kLT: + return CompareResult::kGT; + case CompareResult::kLE: + return CompareResult::kGE; + case CompareResult::kGT: + return CompareResult::kLT; + case CompareResult::kGE: + return CompareResult::kLE; + case CompareResult::kNE: + return CompareResult::kNE; + case CompareResult::kUnknown: + return CompareResult::kUnknown; + default: + LOG(FATAL) << "Invalid CompareResult: " << static_cast<int>(res); + return CompareResult::kInconsistent; + } +} + +// Internal utility, return the CompareResult resulting from negating +// the comparison. +CompareResult Negate(CompareResult res) { + switch (res) { + case CompareResult::kInconsistent: + return CompareResult::kInconsistent; + case CompareResult::kUnknown: + return CompareResult::kUnknown; + default: + return CompareResult(~static_cast<int>(res) & static_cast<int>(CompareResult::kUnknown)); + } +} + +// Internal utility, extract constant offsets out of the two sides of +// a comparison. Given lhs and rhs, return a tuple of three elements +// (lhs_inner, rhs_inner, offset), such that (lhs OP rhs) and +// (lhs_inner OP rhs_inner + offset) are equivalent. 
+std::tuple<PrimExpr, PrimExpr, int64_t> ExtractOffsets(const PrimExpr& lhs, const PrimExpr& rhs) { + auto extract_offset = [](const PrimExpr& expr) -> std::pair<PrimExpr, int64_t> { + PVar<PrimExpr> x; + PVar<IntImm> c; + if ((x + c).Match(expr)) { + return {x.Eval(), c.Eval()->value}; + } else if ((x - c).Match(expr)) { + return {x.Eval(), -c.Eval()->value}; + } else if (c.Match(expr)) { + return {0, c.Eval()->value}; + } else { + return {expr, 0}; + } + }; + + auto lhs_split = extract_offset(lhs); + auto rhs_split = extract_offset(rhs); + return {lhs_split.first, rhs_split.first, rhs_split.second - lhs_split.second}; +} + +} // namespace + +std::optional<TransitiveComparisonAnalyzer::Impl::Comparison> +TransitiveComparisonAnalyzer::Impl::FromExpr(const PrimExpr& expr) { + CompareResult res; + PVar<PrimExpr> x, y; + if ((x <= y).Match(expr)) { + res = CompareResult::kLE; + } else if ((x >= y).Match(expr)) { + res = CompareResult::kGE; + } else if ((x < y).Match(expr)) { + res = CompareResult::kLT; + } else if ((x > y).Match(expr)) { + res = CompareResult::kGT; + } else if ((x == y).Match(expr)) { + res = CompareResult::kEQ; + } else if ((x != y).Match(expr)) { + res = CompareResult::kNE; + } else { + return std::nullopt; + } + + PrimExpr lhs_expr = x.Eval(); + PrimExpr rhs_expr = y.Eval(); + + if (lhs_expr.as<IntImmNode>() && rhs_expr.as<IntImmNode>()) { + return std::nullopt; + } + + auto [lhs, rhs, offset] = ExtractOffsets(lhs_expr, rhs_expr); + Key lhs_key = ExprToKey(lhs); + Key rhs_key = ExprToKey(rhs); + + return Comparison(lhs_key, rhs_key, offset, res); +} + +TransitiveComparisonAnalyzer::Impl::Comparison::Comparison(Key lhs, Key rhs, int64_t offset, + CompareResult result) + : lhs_(lhs), rhs_(rhs), offset_(offset), result_(result) { + if (result_ == CompareResult::kLT) { + result_ = CompareResult::kLE; + offset_ -= 1; + } + if (result_ == CompareResult::kGT) { + result_ = CompareResult::kGE; + offset_ += 1; + } +} + 
+std::optional<TransitiveComparisonAnalyzer::Impl::Key> +TransitiveComparisonAnalyzer::Impl::ExprToPreviousKey(const PrimExpr& expr) const { + auto it = expr_to_key.find(expr); + if (it != expr_to_key.end()) { + return it->second; + } else { + return std::nullopt; + } +} + +TransitiveComparisonAnalyzer::Impl::Key TransitiveComparisonAnalyzer::Impl::ExprToKey( + const PrimExpr& expr) { + if (auto prev = ExprToPreviousKey(expr)) { + return prev.value(); + } else { + Key new_key = Key(expr_to_key.size()); + expr_to_key[expr] = new_key; + return new_key; + } +} + +bool TransitiveComparisonAnalyzer::Impl::Comparison::IsNormalized() const { + // These < and > should be removed during normalization. + return result_ != CompareResult::kLT && result_ != CompareResult::kGT; +} + +std::optional<TransitiveComparisonAnalyzer::Impl::Comparison> +TransitiveComparisonAnalyzer::Impl::Comparison::WithLHS(Key new_lhs) const { + if (new_lhs == lhs_) { + return *this; + } else if (new_lhs == rhs_) { + return Comparison(rhs_, lhs_, -offset_, Reverse(result_)); + } else { + return std::nullopt; + } +} + +TransitiveComparisonAnalyzer::Impl::Comparison +TransitiveComparisonAnalyzer::Impl::Comparison::Negated() const { + return Comparison(lhs_, rhs_, offset_, Negate(result_)); +} + +bool TransitiveComparisonAnalyzer::Impl::Comparison::Implies( + const TransitiveComparisonAnalyzer::Impl::Comparison& other) const { + ICHECK(lhs_ == other.lhs_); + ICHECK(rhs_ == other.rhs_); + ICHECK(IsNormalized()); + ICHECK(other.IsNormalized()); + + if (result_ == other.result_ && offset_ == other.offset_) { + // if c1 == c2, x != y + c1 => x != y + c2 + // if c1 == c2, x == y + c1 => x == y + c2 + return true; + } + + if (other.result_ == CompareResult::kLE && offset_ <= other.offset_) { + if (result_ == CompareResult::kEQ || result_ == CompareResult::kLE) { + // if c1 <= c2, x <= y + c1 => x <= y + c2 + // if c1 <= c2, x == y + c1 => x <= y + c2 + return true; + } + } + + if (other.result_ == 
CompareResult::kGE && offset_ >= other.offset_) { + if (result_ == CompareResult::kEQ || result_ == CompareResult::kGE) { + // if c1 >= c2, x == y + c1 => x >= y + c2 + // if c1 >= c2, x >= y + c1 => x >= y + c2 + return true; + } + } + + if (other.result_ == CompareResult::kNE) { + if (result_ == CompareResult::kEQ && offset_ != other.offset_) { + // if c1 != c2, x == y + c1 => x != y + c2 + return true; + } + + if (result_ == CompareResult::kLE && offset_ < other.offset_) { + // if c1 < c2, x <= y + c1 => x < y + c2 => x != y + c2 + return true; + } + + if (result_ == CompareResult::kGE && offset_ > other.offset_) { + // if c1 != c2, x >= y + c1 => x > y + c2 => x != y + c2 + return true; Review Comment: Should we also check the value of `other.result_` to guard against an erroneous match, e.g. `other.result_ == CompareResult::kEQ` / `x == y + c2`? ########## src/arith/transitive_comparison_analyzer.cc: ########## @@ -0,0 +1,683 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*!
+ * \file tvm/arith/transitive_comparison_analyzer.cc + */ + +#include <tvm/arith/analyzer.h> +#include <tvm/tir/analysis.h> +#include <tvm/tir/expr.h> + +#include <optional> +#include <vector> + +#include "constraint_extract.h" +#include "pattern_match.h" + +namespace tvm { +namespace arith { + +using namespace tir; + +class TransitiveComparisonAnalyzer::Impl { + public: + /* \brief Using previously specified knowns, compare the expressions provided + * + * \param lhs The left-hand side of the comparison + * + * \param rhs The right-hand side of the comparison + * + * \return The most specific result that can be proven about the + * comparison. If nothing can be proven, returns kUnknown. + */ + CompareResult TryCompare(const PrimExpr& lhs, const PrimExpr& rhs) const; + + /*! \brief Bind a variable as being equal to a known expression + * + * \param var The variable of interest. + * \param expr The bound expression + * \param allow_override Whether to allow override of existing information. + */ + void Bind(const tir::Var& var, const PrimExpr& expr, bool allow_override = false); + + /*! \brief Bind a variable as being within a specified range + * + * \param var The variable of interest. + * \param range The known range + * \param allow_override Whether to allow override of existing information. + */ + void Bind(const tir::Var& var, const Range& expr, bool allow_override = false); + + /*! + * \brief Update the internal state to enter constraint. + * \param constraint A constraint expression. + * + * \return An exit function that must be called to cleanup. May be + * `nullptr`, if no cleanup is required. 
+ */ + std::function<void()> EnterConstraint(const PrimExpr& expr); + + private: + // Utility class to avoid needing to repeatedly call ExprDeepEqual + enum class Key : size_t {}; + std::optional<Key> ExprToPreviousKey(const PrimExpr& expr) const; + Key ExprToKey(const PrimExpr& expr); + std::unordered_map<PrimExpr, Key, StructuralHash, StructuralEqual> expr_to_key; + + /*! \brief Internal representation of a comparison operator */ + struct Comparison { + /*! \brief Construct a comparison that represents `lhs OP rhs + + * offset`, where the operation is specified by the CompareResult. + */ + Comparison(Key lhs, Key rhs, int64_t offset, CompareResult result); + + /*! \brief Utility function to validate that all GT and LT results + * have been normalized out + */ + bool IsNormalized() const; + + /*! \brief Move the specified expression to the LHS. + * + * \param new_lhs The argument that should be moved to the LHS of the + * comparison. + * + * \return If possible, returns a comparison that is equivalent to + * the current comparison, but with the specified LHS. If not + * possible, returns nullopt. + */ + std::optional<Comparison> WithLHS(Key new_lhs) const; + + /*! \brief Create the negation of the current comparison */ + Comparison Negated() const; + + /*! \brief Check the this comparison implies + * + * Returns true if this comparison being true implies that the + * other comparison must also be true. Returns false if the other + * comparison cannot be shown to be true. + */ + bool Implies(const Comparison& other) const; + + // The LHS of the comparison + Key lhs_; + + // The RHS of the comparison, not including any constant offset. + Key rhs_; + + // Additive offset on rhs + int64_t offset_{0}; + + // The comparison operator. + CompareResult result_{CompareResult::kInconsistent}; + }; + + /*! \brief Generate a Comparison representing the given expression */ + std::optional<Comparison> FromExpr(const PrimExpr& expr); + + /*! 
\brief Utility function used by Bind and EnterConstraint + * + * \param expr The comparison expression, to be converted into + * internal Comparison objects. + * + * \param vec The vector to which the Comparison objects should be + * appended. + */ + void AddKnown(const PrimExpr& expr, std::vector<Comparison>* vec); + + /*! \brief Attempt to compare, starting at the lhs. + * + * Taking each available `Comparison` as a node edge, search for a + * path from lhs to rhs. For example, the priors (a<=b), (b<=c+1) + * and (c<=d-5) can be used to prove that (a<=d-4). + * + * \param lhs The left-hand side of the comparison + * + * \param rhs The right-hand side of the comparison + * + * \return The result of the comparison + */ + CompareResult TryCompareFromLHS(Key lhs_key, Key rhs_key, int64_t offset, const PrimExpr& lhs, + const PrimExpr& rhs) const; + + /*! \brief Previous Range bindings + * + * Tracked separatedly to handle the `allow_override` option used by + * all sub-analyzers when binding variables. + */ + Map<Var, Range> prev_bindings_; + + /*! \brief Known comparisons based on definitionally-true statements + * + * For example, a Let binding, or the range of an iterator. + */ + std::vector<Comparison> knowns_; + + /*! \brief Known comparisons based on of scope-based statements + * + * For example, the condition of an IfThenElse, which is known to be + * true while within the if scope. + */ + std::vector<Comparison> scoped_knowns_; +}; + +namespace { + +// Internal utility, return the CompareResult resulting from swapping +// the left-hand side with the right-hand side. 
+CompareResult Reverse(CompareResult res) { + switch (res) { + case CompareResult::kInconsistent: + return CompareResult::kInconsistent; + case CompareResult::kEQ: + return CompareResult::kEQ; + case CompareResult::kLT: + return CompareResult::kGT; + case CompareResult::kLE: + return CompareResult::kGE; + case CompareResult::kGT: + return CompareResult::kLT; + case CompareResult::kGE: + return CompareResult::kLE; + case CompareResult::kNE: + return CompareResult::kNE; + case CompareResult::kUnknown: + return CompareResult::kUnknown; + default: + LOG(FATAL) << "Invalid CompareResult: " << static_cast<int>(res); + return CompareResult::kInconsistent; + } +} + +// Internal utility, return the CompareResult resulting from negating +// the comparison. +CompareResult Negate(CompareResult res) { + switch (res) { + case CompareResult::kInconsistent: + return CompareResult::kInconsistent; + case CompareResult::kUnknown: + return CompareResult::kUnknown; + default: + return CompareResult(~static_cast<int>(res) & static_cast<int>(CompareResult::kUnknown)); + } +} + +// Internal utility, extract constant offsets out of the two sides of +// a comparison. Given lhs and rhs, return a tuple of three elements +// (lhs_inner, rhs_inner, offset), such that (lhs OP rhs) and +// (lhs_inner OP rhs_inner + offset) are equivalent. 
+std::tuple<PrimExpr, PrimExpr, int64_t> ExtractOffsets(const PrimExpr& lhs, const PrimExpr& rhs) { + auto extract_offset = [](const PrimExpr& expr) -> std::pair<PrimExpr, int64_t> { + PVar<PrimExpr> x; + PVar<IntImm> c; + if ((x + c).Match(expr)) { + return {x.Eval(), c.Eval()->value}; + } else if ((x - c).Match(expr)) { + return {x.Eval(), -c.Eval()->value}; + } else if (c.Match(expr)) { + return {0, c.Eval()->value}; + } else { + return {expr, 0}; + } + }; + + auto lhs_split = extract_offset(lhs); + auto rhs_split = extract_offset(rhs); + return {lhs_split.first, rhs_split.first, rhs_split.second - lhs_split.second}; +} + +} // namespace + +std::optional<TransitiveComparisonAnalyzer::Impl::Comparison> +TransitiveComparisonAnalyzer::Impl::FromExpr(const PrimExpr& expr) { + CompareResult res; + PVar<PrimExpr> x, y; + if ((x <= y).Match(expr)) { + res = CompareResult::kLE; + } else if ((x >= y).Match(expr)) { + res = CompareResult::kGE; + } else if ((x < y).Match(expr)) { + res = CompareResult::kLT; + } else if ((x > y).Match(expr)) { + res = CompareResult::kGT; + } else if ((x == y).Match(expr)) { + res = CompareResult::kEQ; + } else if ((x != y).Match(expr)) { + res = CompareResult::kNE; + } else { + return std::nullopt; + } + + PrimExpr lhs_expr = x.Eval(); + PrimExpr rhs_expr = y.Eval(); + + if (lhs_expr.as<IntImmNode>() && rhs_expr.as<IntImmNode>()) { + return std::nullopt; + } + + auto [lhs, rhs, offset] = ExtractOffsets(lhs_expr, rhs_expr); + Key lhs_key = ExprToKey(lhs); + Key rhs_key = ExprToKey(rhs); + + return Comparison(lhs_key, rhs_key, offset, res); +} + +TransitiveComparisonAnalyzer::Impl::Comparison::Comparison(Key lhs, Key rhs, int64_t offset, + CompareResult result) + : lhs_(lhs), rhs_(rhs), offset_(offset), result_(result) { + if (result_ == CompareResult::kLT) { + result_ = CompareResult::kLE; + offset_ -= 1; + } + if (result_ == CompareResult::kGT) { + result_ = CompareResult::kGE; + offset_ += 1; + } +} + 
+std::optional<TransitiveComparisonAnalyzer::Impl::Key> +TransitiveComparisonAnalyzer::Impl::ExprToPreviousKey(const PrimExpr& expr) const { + auto it = expr_to_key.find(expr); + if (it != expr_to_key.end()) { + return it->second; + } else { + return std::nullopt; + } +} + +TransitiveComparisonAnalyzer::Impl::Key TransitiveComparisonAnalyzer::Impl::ExprToKey( + const PrimExpr& expr) { + if (auto prev = ExprToPreviousKey(expr)) { + return prev.value(); + } else { + Key new_key = Key(expr_to_key.size()); + expr_to_key[expr] = new_key; + return new_key; + } +} + +bool TransitiveComparisonAnalyzer::Impl::Comparison::IsNormalized() const { + // These < and > should be removed during normalization. + return result_ != CompareResult::kLT && result_ != CompareResult::kGT; +} + +std::optional<TransitiveComparisonAnalyzer::Impl::Comparison> +TransitiveComparisonAnalyzer::Impl::Comparison::WithLHS(Key new_lhs) const { + if (new_lhs == lhs_) { + return *this; + } else if (new_lhs == rhs_) { + return Comparison(rhs_, lhs_, -offset_, Reverse(result_)); + } else { + return std::nullopt; + } +} + +TransitiveComparisonAnalyzer::Impl::Comparison +TransitiveComparisonAnalyzer::Impl::Comparison::Negated() const { + return Comparison(lhs_, rhs_, offset_, Negate(result_)); +} + +bool TransitiveComparisonAnalyzer::Impl::Comparison::Implies( + const TransitiveComparisonAnalyzer::Impl::Comparison& other) const { + ICHECK(lhs_ == other.lhs_); + ICHECK(rhs_ == other.rhs_); + ICHECK(IsNormalized()); + ICHECK(other.IsNormalized()); + + if (result_ == other.result_ && offset_ == other.offset_) { + // if c1 == c2, x != y + c1 => x != y + c2 + // if c1 == c2, x == y + c1 => x == y + c2 + return true; + } + + if (other.result_ == CompareResult::kLE && offset_ <= other.offset_) { + if (result_ == CompareResult::kEQ || result_ == CompareResult::kLE) { + // if c1 <= c2, x <= y + c1 => x <= y + c2 + // if c1 <= c2, x == y + c1 => x <= y + c2 + return true; + } + } + + if (other.result_ == 
CompareResult::kGE && offset_ >= other.offset_) { + if (result_ == CompareResult::kEQ || result_ == CompareResult::kGE) { + // if c1 >= c2, x == y + c1 => x >= y + c2 + // if c1 >= c2, x >= y + c1 => x >= y + c2 + return true; + } + } + + if (other.result_ == CompareResult::kNE) { + if (result_ == CompareResult::kEQ && offset_ != other.offset_) { + // if c1 != c2, x == y + c1 => x != y + c2 + return true; + } + + if (result_ == CompareResult::kLE && offset_ < other.offset_) { + // if c1 < c2, x <= y + c1 => x < y + c2 => x != y + c2 + return true; + } + + if (result_ == CompareResult::kGE && offset_ > other.offset_) { + // if c1 != c2, x >= y + c1 => x > y + c2 => x != y + c2 + return true; + } + } + + return false; +} + +TransitiveComparisonAnalyzer::TransitiveComparisonAnalyzer() : impl_(std::make_unique<Impl>()) {} +TransitiveComparisonAnalyzer::~TransitiveComparisonAnalyzer() {} + +CompareResult TransitiveComparisonAnalyzer::TryCompare(const PrimExpr& lhs, const PrimExpr& rhs) { + return impl_->TryCompare(lhs, rhs); +} + +void TransitiveComparisonAnalyzer::Bind(const Var& var, const PrimExpr& expr, bool allow_override) { + impl_->Bind(var, expr, allow_override); +} +void TransitiveComparisonAnalyzer::Bind(const Var& var, const Range& range, bool allow_override) { + impl_->Bind(var, range, allow_override); +} + +std::function<void()> TransitiveComparisonAnalyzer::EnterConstraint(const PrimExpr& constraint) { + return impl_->EnterConstraint(constraint); +} + +void TransitiveComparisonAnalyzer::Impl::AddKnown(const PrimExpr& expr, + std::vector<Comparison>* vec) { + for (const auto& subexpr : ExtractConstraints(expr)) { + if (tir::SideEffect(expr) <= tir::CallEffectKind::kPure) { + if (auto cmp = FromExpr(subexpr)) { + vec->push_back(cmp.value()); + } + } + } +} + +void TransitiveComparisonAnalyzer::Impl::Bind(const tir::Var& var, const Range& range, + bool allow_override) { + auto it = prev_bindings_.find(var); + if (it != prev_bindings_.end()) { + ExprDeepEqual 
expr_equal; + bool differs_from_previous = !expr_equal(range->min, (*it).second->min) || + !expr_equal(range->extent, (*it).second->extent); + if (differs_from_previous) { + ICHECK(allow_override) << "Binding of variable " << var << " as " << range + << " conflicts with previous binding as " << (*it).second; + if (auto key = ExprToPreviousKey(var)) { + knowns_.erase(std::remove_if(knowns_.begin(), knowns_.end(), + [&](const auto& known) { return known.lhs_ == key.value(); }), + knowns_.end()); + } + } + } + + prev_bindings_.Set(var, range); + + if (is_const_int(range->extent, 1)) { + AddKnown(var == range->min, &knowns_); + } else { + AddKnown(var >= range->min, &knowns_); + AddKnown(var < range->min + range->extent, &knowns_); + } +} + +void TransitiveComparisonAnalyzer::Impl::Bind(const tir::Var& var, const PrimExpr& expr, + bool allow_override) { + Bind(var, Range::FromMinExtent(expr, 1), allow_override); +} + +std::function<void()> TransitiveComparisonAnalyzer::Impl::EnterConstraint(const PrimExpr& expr) { + size_t old_literal_size = scoped_knowns_.size(); + AddKnown(expr, &scoped_knowns_); + size_t new_literal_size = scoped_knowns_.size(); + + PrimExpr temp = expr; + auto frecover = [old_literal_size, new_literal_size, this, temp]() { + ICHECK_EQ(scoped_knowns_.size(), new_literal_size); + scoped_knowns_.erase(scoped_knowns_.begin() + old_literal_size, scoped_knowns_.end()); + }; + return frecover; +} + +CompareResult TransitiveComparisonAnalyzer::Impl::TryCompare(const PrimExpr& lhs_expr, + const PrimExpr& rhs_expr) const { + // Currently only supports integer checks + if (!lhs_expr.dtype().is_int() || !rhs_expr.dtype().is_int()) { + return CompareResult::kUnknown; + } + + // Bail out early if possible. This int check should have been + // constant-folded earlier, so this check shouldn't occur. 
+ auto* x_int = lhs_expr.as<IntImmNode>(); + auto* y_int = rhs_expr.as<IntImmNode>(); + if (x_int && y_int) { + if (x_int->value < y_int->value) { + return CompareResult::kLT; + } else if (x_int->value > y_int->value) { + return CompareResult::kGT; + } else { + return CompareResult::kEQ; + } + } + + auto [lhs, rhs, offset] = ExtractOffsets(lhs_expr, rhs_expr); + auto lhs_key = ExprToPreviousKey(lhs); + auto rhs_key = ExprToPreviousKey(rhs); + + if (!lhs_key.has_value() || !rhs_key.has_value()) { + return CompareResult::kUnknown; + } + + auto from_lhs = TryCompareFromLHS(lhs_key.value(), rhs_key.value(), offset, lhs, rhs); + auto from_rhs = Reverse(TryCompareFromLHS(rhs_key.value(), lhs_key.value(), -offset, rhs, lhs)); + auto output = from_lhs & from_rhs; + + return output; +} + +CompareResult TransitiveComparisonAnalyzer::Impl::TryCompareFromLHS( + Key lhs_key_input, Key rhs_key_input, int64_t offset_input, const PrimExpr& lhs_input, + const PrimExpr& rhs_input) const { + Key lhs_key = lhs_key_input; + Key rhs_key = rhs_key_input; + int64_t offset = offset_input; + + // Everything in `to_visit` has lhs as its lhs. + std::unordered_set<Key> seen; + std::unordered_set<Key> to_visit; + std::unordered_map<Key, std::vector<Comparison>> compared_to_x; + + // Utility function to add a new known statement + auto declare_known = [&](Comparison cmp) { + auto& prev_knowns = compared_to_x[cmp.rhs_]; + + for (auto& prev_known : prev_knowns) { + if (prev_known.Implies(cmp)) { + return; + } + } + + if (cmp.rhs_ != rhs_key && !seen.count(cmp.rhs_)) { + to_visit.insert(cmp.rhs_); + seen.insert(cmp.rhs_); + } + + for (auto& prev_known : prev_knowns) { + if (cmp.Implies(prev_known)) { + prev_known = cmp; + return; + } + } + + prev_knowns.push_back(cmp); + }; + + // Initialize the search based on any known (in)equalities that use + // the LHS of the comparison. 
+ for (const auto& known : knowns_) { + if (auto normalized = known.WithLHS(lhs_key)) { + declare_known(normalized.value()); + } + } + for (const auto& known : scoped_knowns_) { + if (auto normalized = known.WithLHS(lhs_key)) { + declare_known(normalized.value()); + } + } + + // Walk through the space of all comparisons that can be made with + // LHS. + while (to_visit.size()) { + Key middle_key = *to_visit.begin(); + to_visit.erase(to_visit.begin()); + + std::vector<Comparison>& prev_knowns_using_middle = compared_to_x.at(middle_key); + ICHECK(compared_to_x.count(middle_key)); + + std::vector<Comparison> new_knowns_using_lhs; + + auto attempt_transitive = [&](Comparison cmp) { + ICHECK(cmp.IsNormalized()); + + Key right_key = cmp.rhs_; + + if (right_key == lhs_key) { + return; + } + + for (const auto& prev : prev_knowns_using_middle) { + CompareResult new_result = CompareResult::kUnknown; + int64_t new_offset = prev.offset_ + cmp.offset_; + + if (prev.result_ == CompareResult::kEQ) { + // x == y + c1 && y OP z + c2, x OP z + (c1 + c2) + new_result = cmp.result_; + } else if (cmp.result_ == CompareResult::kEQ) { + // x OP y + c1 && y == z + c2, x OP z + (c1 + c2) + new_result = prev.result_; + } else if (prev.result_ == cmp.result_ && + (prev.result_ == CompareResult::kLE || prev.result_ == CompareResult::kGE)) { + // x <= y + c1 && y <= z + c2, x <= z + (c1 + c2) + // x >= y + c1 && y >= z + c2, x >= z + (c1 + c2) + // + // This condition is much simpler to write than the + // equivalent handling of < or of >, which is why the + // inequalities are normalized to <= and to >=. + new_result = prev.result_; + } + + if (new_result != CompareResult::kUnknown) { + Comparison new_known(lhs_key, right_key, new_offset, new_result); + new_knowns_using_lhs.push_back(new_known); + } + } + }; + + // Attempt to prove a new comparison using one of the original + // known comparisons. 
We want to find a known such that + // `(LHS OP1 middle) && (middle OP2 right)` can be simplified Review Comment: nit: This switches naming convention: `LHS` is paired with `right` rather than `RHS`. Maybe just use left/middle/right. ########## include/tvm/arith/analyzer.h: ########## @@ -275,6 +275,36 @@ class RewriteSimplifier { */ std::function<void()> EnterConstraint(const PrimExpr& constraint); + /*! \brief Flags to enable more computationally-intensive simplifications + * + * These simplifications may be required for specific schedules, but + * would impose too high a compile-time cost to enable by default. + * They can be enabled on an as-needed basis by calling + * `RewriteSimplifier::SetEnabledFeatures` prior to using + * `RewriteSimplifier::operator()`. + */ + enum Feature { + // No features enabled + kNone = 0, + + /* When simplifying an inequality, attempt to use scope-based knowns. + * + * Example: + * if_then_else(i<j && j<k, i<k, false) => if_then_else(i<j && j<k, true, false) + */ + kTransitivelyProveInequalities = (1 << 0), + }; + + /*! \brief Enable an optional feature or features + * + * \param flags A bitwise OR of all optional features that should be + * enabled. + */ + void SetEnabledFeatures(Feature flags); Review Comment: Use of TVM_DLL on member functions. ########## src/arith/transitive_comparison_analyzer.cc: ########## @@ -0,0 +1,683 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file tvm/arith/transitive_comparison_analyzer.cc + */ + +#include <tvm/arith/analyzer.h> +#include <tvm/tir/analysis.h> +#include <tvm/tir/expr.h> + +#include <optional> +#include <vector> + +#include "constraint_extract.h" +#include "pattern_match.h" + +namespace tvm { +namespace arith { + +using namespace tir; + +class TransitiveComparisonAnalyzer::Impl { + public: + /* \brief Using previously specified knowns, compare the expressions provided + * + * \param lhs The left-hand side of the comparison + * + * \param rhs The right-hand side of the comparison + * + * \return The most specific result that can be proven about the + * comparison. If nothing can be proven, returns kUnknown. + */ + CompareResult TryCompare(const PrimExpr& lhs, const PrimExpr& rhs) const; + + /*! \brief Bind a variable as being equal to a known expression + * + * \param var The variable of interest. + * \param expr The bound expression + * \param allow_override Whether to allow override of existing information. + */ + void Bind(const tir::Var& var, const PrimExpr& expr, bool allow_override = false); + + /*! \brief Bind a variable as being within a specified range + * + * \param var The variable of interest. + * \param range The known range + * \param allow_override Whether to allow override of existing information. + */ + void Bind(const tir::Var& var, const Range& expr, bool allow_override = false); + + /*! + * \brief Update the internal state to enter constraint. + * \param constraint A constraint expression. 
+ * + * \return An exit function that must be called to cleanup. May be + * `nullptr`, if no cleanup is required. + */ + std::function<void()> EnterConstraint(const PrimExpr& expr); + + private: + // Utility class to avoid needing to repeatedly call ExprDeepEqual + enum class Key : size_t {}; + std::optional<Key> ExprToPreviousKey(const PrimExpr& expr) const; + Key ExprToKey(const PrimExpr& expr); + std::unordered_map<PrimExpr, Key, StructuralHash, StructuralEqual> expr_to_key; + + /*! \brief Internal representation of a comparison operator */ + struct Comparison { + /*! \brief Construct a comparison that represents `lhs OP rhs + + * offset`, where the operation is specified by the CompareResult. + */ + Comparison(Key lhs, Key rhs, int64_t offset, CompareResult result); + + /*! \brief Utility function to validate that all GT and LT results + * have been normalized out + */ + bool IsNormalized() const; + + /*! \brief Move the specified expression to the LHS. + * + * \param new_lhs The argument that should be moved to the LHS of the + * comparison. + * + * \return If possible, returns a comparison that is equivalent to + * the current comparison, but with the specified LHS. If not + * possible, returns nullopt. + */ + std::optional<Comparison> WithLHS(Key new_lhs) const; + + /*! \brief Create the negation of the current comparison */ + Comparison Negated() const; + + /*! \brief Check the this comparison implies + * + * Returns true if this comparison being true implies that the + * other comparison must also be true. Returns false if the other + * comparison cannot be shown to be true. + */ + bool Implies(const Comparison& other) const; + + // The LHS of the comparison + Key lhs_; + + // The RHS of the comparison, not including any constant offset. + Key rhs_; + + // Additive offset on rhs + int64_t offset_{0}; + + // The comparison operator. + CompareResult result_{CompareResult::kInconsistent}; + }; + + /*! 
\brief Generate a Comparison representing the given expression */ + std::optional<Comparison> FromExpr(const PrimExpr& expr); + + /*! \brief Utility function used by Bind and EnterConstraint + * + * \param expr The comparison expression, to be converted into + * internal Comparison objects. + * + * \param vec The vector to which the Comparison objects should be + * appended. + */ + void AddKnown(const PrimExpr& expr, std::vector<Comparison>* vec); + + /*! \brief Attempt to compare, starting at the lhs. + * + * Taking each available `Comparison` as a node edge, search for a Review Comment: nit: A diagram of the AST referenced by `node edge` in this line would make the following description quite clear. ########## include/tvm/arith/analyzer.h: ########## @@ -275,6 +275,36 @@ class RewriteSimplifier { */ std::function<void()> EnterConstraint(const PrimExpr& constraint); + /*! \brief Flags to enable more computationally-intensive simplifications + * + * These simplifications may be required for specific schedules, but + * would impose too high a compile-time cost to enable by default. + * They can be enabled on an as-needed basis by calling + * `RewriteSimplifier::SetEnabledFeatures` prior to using + * `RewriteSimplifier::operator()`. + */ + enum Feature { Review Comment: Extensions? ########## src/arith/transitive_comparison_analyzer.cc: ########## @@ -0,0 +1,683 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file tvm/arith/transitive_comparison_analyzer.cc + */ + +#include <tvm/arith/analyzer.h> +#include <tvm/tir/analysis.h> +#include <tvm/tir/expr.h> + +#include <optional> +#include <vector> + +#include "constraint_extract.h" +#include "pattern_match.h" + +namespace tvm { +namespace arith { + +using namespace tir; + +class TransitiveComparisonAnalyzer::Impl { + public: + /* \brief Using previously specified knowns, compare the expressions provided + * + * \param lhs The left-hand side of the comparison + * + * \param rhs The right-hand side of the comparison + * + * \return The most specific result that can be proven about the + * comparison. If nothing can be proven, returns kUnknown. + */ + CompareResult TryCompare(const PrimExpr& lhs, const PrimExpr& rhs) const; + + /*! \brief Bind a variable as being equal to a known expression + * + * \param var The variable of interest. + * \param expr The bound expression + * \param allow_override Whether to allow override of existing information. + */ + void Bind(const tir::Var& var, const PrimExpr& expr, bool allow_override = false); + + /*! \brief Bind a variable as being within a specified range + * + * \param var The variable of interest. + * \param range The known range + * \param allow_override Whether to allow override of existing information. + */ + void Bind(const tir::Var& var, const Range& expr, bool allow_override = false); + + /*! + * \brief Update the internal state to enter constraint. + * \param constraint A constraint expression. 
+ * + * \return An exit function that must be called to cleanup. May be + * `nullptr`, if no cleanup is required. + */ + std::function<void()> EnterConstraint(const PrimExpr& expr); + + private: + // Utility class to avoid needing to repeatedly call ExprDeepEqual + enum class Key : size_t {}; + std::optional<Key> ExprToPreviousKey(const PrimExpr& expr) const; + Key ExprToKey(const PrimExpr& expr); + std::unordered_map<PrimExpr, Key, StructuralHash, StructuralEqual> expr_to_key; + + /*! \brief Internal representation of a comparison operator */ + struct Comparison { + /*! \brief Construct a comparison that represents `lhs OP rhs + + * offset`, where the operation is specified by the CompareResult. + */ + Comparison(Key lhs, Key rhs, int64_t offset, CompareResult result); + + /*! \brief Utility function to validate that all GT and LT results + * have been normalized out + */ + bool IsNormalized() const; + + /*! \brief Move the specified expression to the LHS. + * + * \param new_lhs The argument that should be moved to the LHS of the + * comparison. + * + * \return If possible, returns a comparison that is equivalent to + * the current comparison, but with the specified LHS. If not + * possible, returns nullopt. + */ + std::optional<Comparison> WithLHS(Key new_lhs) const; + + /*! \brief Create the negation of the current comparison */ + Comparison Negated() const; + + /*! \brief Check the this comparison implies + * + * Returns true if this comparison being true implies that the + * other comparison must also be true. Returns false if the other + * comparison cannot be shown to be true. + */ + bool Implies(const Comparison& other) const; + + // The LHS of the comparison + Key lhs_; + + // The RHS of the comparison, not including any constant offset. + Key rhs_; + + // Additive offset on rhs + int64_t offset_{0}; + + // The comparison operator. + CompareResult result_{CompareResult::kInconsistent}; + }; + + /*! 
\brief Generate a Comparison representing the given expression */ + std::optional<Comparison> FromExpr(const PrimExpr& expr); + + /*! \brief Utility function used by Bind and EnterConstraint + * + * \param expr The comparison expression, to be converted into + * internal Comparison objects. + * + * \param vec The vector to which the Comparison objects should be + * appended. + */ + void AddKnown(const PrimExpr& expr, std::vector<Comparison>* vec); + + /*! \brief Attempt to compare, starting at the lhs. + * + * Taking each available `Comparison` as a node edge, search for a + * path from lhs to rhs. For example, the priors (a<=b), (b<=c+1) + * and (c<=d-5) can be used to prove that (a<=d-4). + * + * \param lhs The left-hand side of the comparison + * + * \param rhs The right-hand side of the comparison + * + * \return The result of the comparison + */ + CompareResult TryCompareFromLHS(Key lhs_key, Key rhs_key, int64_t offset, const PrimExpr& lhs, + const PrimExpr& rhs) const; + + /*! \brief Previous Range bindings + * + * Tracked separatedly to handle the `allow_override` option used by + * all sub-analyzers when binding variables. + */ + Map<Var, Range> prev_bindings_; + + /*! \brief Known comparisons based on definitionally-true statements + * + * For example, a Let binding, or the range of an iterator. + */ + std::vector<Comparison> knowns_; + + /*! \brief Known comparisons based on of scope-based statements Review Comment: Unclear sentence ########## src/arith/transitive_comparison_analyzer.cc: ########## @@ -0,0 +1,683 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file tvm/arith/transitive_comparison_analyzer.cc + */ + +#include <tvm/arith/analyzer.h> +#include <tvm/tir/analysis.h> +#include <tvm/tir/expr.h> + +#include <optional> +#include <vector> + +#include "constraint_extract.h" +#include "pattern_match.h" + +namespace tvm { +namespace arith { + +using namespace tir; + +class TransitiveComparisonAnalyzer::Impl { + public: + /* \brief Using previously specified knowns, compare the expressions provided + * + * \param lhs The left-hand side of the comparison + * + * \param rhs The right-hand side of the comparison + * + * \return The most specific result that can be proven about the + * comparison. If nothing can be proven, returns kUnknown. + */ + CompareResult TryCompare(const PrimExpr& lhs, const PrimExpr& rhs) const; + + /*! \brief Bind a variable as being equal to a known expression + * + * \param var The variable of interest. + * \param expr The bound expression + * \param allow_override Whether to allow override of existing information. + */ + void Bind(const tir::Var& var, const PrimExpr& expr, bool allow_override = false); + + /*! \brief Bind a variable as being within a specified range + * + * \param var The variable of interest. + * \param range The known range + * \param allow_override Whether to allow override of existing information. + */ + void Bind(const tir::Var& var, const Range& expr, bool allow_override = false); + + /*! + * \brief Update the internal state to enter constraint. + * \param constraint A constraint expression. 
+ * + * \return An exit function that must be called to cleanup. May be + * `nullptr`, if no cleanup is required. + */ + std::function<void()> EnterConstraint(const PrimExpr& expr); + + private: + // Utility class to avoid needing to repeatedly call ExprDeepEqual + enum class Key : size_t {}; + std::optional<Key> ExprToPreviousKey(const PrimExpr& expr) const; + Key ExprToKey(const PrimExpr& expr); + std::unordered_map<PrimExpr, Key, StructuralHash, StructuralEqual> expr_to_key; + + /*! \brief Internal representation of a comparison operator */ + struct Comparison { + /*! \brief Construct a comparison that represents `lhs OP rhs + + * offset`, where the operation is specified by the CompareResult. + */ + Comparison(Key lhs, Key rhs, int64_t offset, CompareResult result); + + /*! \brief Utility function to validate that all GT and LT results + * have been normalized out + */ + bool IsNormalized() const; + + /*! \brief Move the specified expression to the LHS. + * + * \param new_lhs The argument that should be moved to the LHS of the + * comparison. + * + * \return If possible, returns a comparison that is equivalent to + * the current comparison, but with the specified LHS. If not + * possible, returns nullopt. + */ + std::optional<Comparison> WithLHS(Key new_lhs) const; + + /*! \brief Create the negation of the current comparison */ + Comparison Negated() const; + + /*! \brief Check the this comparison implies + * + * Returns true if this comparison being true implies that the + * other comparison must also be true. Returns false if the other + * comparison cannot be shown to be true. + */ + bool Implies(const Comparison& other) const; + + // The LHS of the comparison + Key lhs_; + + // The RHS of the comparison, not including any constant offset. + Key rhs_; + + // Additive offset on rhs + int64_t offset_{0}; + + // The comparison operator. + CompareResult result_{CompareResult::kInconsistent}; + }; + + /*! 
\brief Generate a Comparison representing the given expression */ + std::optional<Comparison> FromExpr(const PrimExpr& expr); + + /*! \brief Utility function used by Bind and EnterConstraint + * + * \param expr The comparison expression, to be converted into + * internal Comparison objects. + * + * \param vec The vector to which the Comparison objects should be + * appended. + */ + void AddKnown(const PrimExpr& expr, std::vector<Comparison>* vec); + + /*! \brief Attempt to compare, starting at the lhs. + * + * Taking each available `Comparison` as a node edge, search for a + * path from lhs to rhs. For example, the priors (a<=b), (b<=c+1) + * and (c<=d-5) can be used to prove that (a<=d-4). + * + * \param lhs The left-hand side of the comparison + * + * \param rhs The right-hand side of the comparison + * + * \return The result of the comparison + */ + CompareResult TryCompareFromLHS(Key lhs_key, Key rhs_key, int64_t offset, const PrimExpr& lhs, + const PrimExpr& rhs) const; + + /*! \brief Previous Range bindings + * + * Tracked separatedly to handle the `allow_override` option used by + * all sub-analyzers when binding variables. + */ + Map<Var, Range> prev_bindings_; + + /*! \brief Known comparisons based on definitionally-true statements + * + * For example, a Let binding, or the range of an iterator. + */ + std::vector<Comparison> knowns_; + + /*! \brief Known comparisons based on of scope-based statements + * + * For example, the condition of an IfThenElse, which is known to be + * true while within the if scope. + */ + std::vector<Comparison> scoped_knowns_; +}; + +namespace { + +// Internal utility, return the CompareResult resulting from swapping +// the left-hand side with the right-hand side. 
+CompareResult Reverse(CompareResult res) { + switch (res) { + case CompareResult::kInconsistent: + return CompareResult::kInconsistent; + case CompareResult::kEQ: + return CompareResult::kEQ; + case CompareResult::kLT: + return CompareResult::kGT; + case CompareResult::kLE: + return CompareResult::kGE; + case CompareResult::kGT: + return CompareResult::kLT; + case CompareResult::kGE: + return CompareResult::kLE; + case CompareResult::kNE: + return CompareResult::kNE; + case CompareResult::kUnknown: + return CompareResult::kUnknown; + default: + LOG(FATAL) << "Invalid CompareResult: " << static_cast<int>(res); + return CompareResult::kInconsistent; + } +} + +// Internal utility, return the CompareResult resulting from negating +// the comparison. +CompareResult Negate(CompareResult res) { + switch (res) { + case CompareResult::kInconsistent: + return CompareResult::kInconsistent; + case CompareResult::kUnknown: + return CompareResult::kUnknown; + default: + return CompareResult(~static_cast<int>(res) & static_cast<int>(CompareResult::kUnknown)); + } +} + +// Internal utility, extract constant offsets out of the two sides of +// a comparison. Given lhs and rhs, return a tuple of three elements +// (lhs_inner, rhs_inner, offset), such that (lhs OP rhs) and +// (lhs_inner OP rhs_inner + offset) are equivalent. 
+std::tuple<PrimExpr, PrimExpr, int64_t> ExtractOffsets(const PrimExpr& lhs, const PrimExpr& rhs) { + auto extract_offset = [](const PrimExpr& expr) -> std::pair<PrimExpr, int64_t> { + PVar<PrimExpr> x; + PVar<IntImm> c; + if ((x + c).Match(expr)) { + return {x.Eval(), c.Eval()->value}; + } else if ((x - c).Match(expr)) { + return {x.Eval(), -c.Eval()->value}; + } else if (c.Match(expr)) { + return {0, c.Eval()->value}; + } else { + return {expr, 0}; + } + }; + + auto lhs_split = extract_offset(lhs); + auto rhs_split = extract_offset(rhs); + return {lhs_split.first, rhs_split.first, rhs_split.second - lhs_split.second}; +} + +} // namespace + +std::optional<TransitiveComparisonAnalyzer::Impl::Comparison> +TransitiveComparisonAnalyzer::Impl::FromExpr(const PrimExpr& expr) { + CompareResult res; + PVar<PrimExpr> x, y; + if ((x <= y).Match(expr)) { + res = CompareResult::kLE; + } else if ((x >= y).Match(expr)) { + res = CompareResult::kGE; + } else if ((x < y).Match(expr)) { + res = CompareResult::kLT; + } else if ((x > y).Match(expr)) { + res = CompareResult::kGT; + } else if ((x == y).Match(expr)) { + res = CompareResult::kEQ; + } else if ((x != y).Match(expr)) { + res = CompareResult::kNE; + } else { + return std::nullopt; + } + + PrimExpr lhs_expr = x.Eval(); + PrimExpr rhs_expr = y.Eval(); + + if (lhs_expr.as<IntImmNode>() && rhs_expr.as<IntImmNode>()) { + return std::nullopt; + } + + auto [lhs, rhs, offset] = ExtractOffsets(lhs_expr, rhs_expr); + Key lhs_key = ExprToKey(lhs); + Key rhs_key = ExprToKey(rhs); + + return Comparison(lhs_key, rhs_key, offset, res); +} + +TransitiveComparisonAnalyzer::Impl::Comparison::Comparison(Key lhs, Key rhs, int64_t offset, + CompareResult result) + : lhs_(lhs), rhs_(rhs), offset_(offset), result_(result) { + if (result_ == CompareResult::kLT) { + result_ = CompareResult::kLE; + offset_ -= 1; Review Comment: Any comment on why this representation is beneficial? E.g. 
to normalize, but perhaps a brief description on IsNormalized can provide clarity. ########## src/arith/transitive_comparison_analyzer.cc: ########## @@ -0,0 +1,683 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file tvm/arith/transitive_comparison_analyzer.cc + */ + +#include <tvm/arith/analyzer.h> +#include <tvm/tir/analysis.h> +#include <tvm/tir/expr.h> + +#include <optional> +#include <vector> + +#include "constraint_extract.h" +#include "pattern_match.h" + +namespace tvm { +namespace arith { + +using namespace tir; + +class TransitiveComparisonAnalyzer::Impl { + public: + /* \brief Using previously specified knowns, compare the expressions provided + * + * \param lhs The left-hand side of the comparison + * + * \param rhs The right-hand side of the comparison + * + * \return The most specific result that can be proven about the + * comparison. If nothing can be proven, returns kUnknown. + */ + CompareResult TryCompare(const PrimExpr& lhs, const PrimExpr& rhs) const; + + /*! \brief Bind a variable as being equal to a known expression + * + * \param var The variable of interest. + * \param expr The bound expression + * \param allow_override Whether to allow override of existing information. 
+ */ + void Bind(const tir::Var& var, const PrimExpr& expr, bool allow_override = false); + + /*! \brief Bind a variable as being within a specified range + * + * \param var The variable of interest. + * \param range The known range + * \param allow_override Whether to allow override of existing information. + */ + void Bind(const tir::Var& var, const Range& expr, bool allow_override = false); + + /*! + * \brief Update the internal state to enter constraint. + * \param constraint A constraint expression. + * + * \return An exit function that must be called to cleanup. May be + * `nullptr`, if no cleanup is required. + */ + std::function<void()> EnterConstraint(const PrimExpr& expr); + + private: + // Utility class to avoid needing to repeatedly call ExprDeepEqual + enum class Key : size_t {}; + std::optional<Key> ExprToPreviousKey(const PrimExpr& expr) const; + Key ExprToKey(const PrimExpr& expr); + std::unordered_map<PrimExpr, Key, StructuralHash, StructuralEqual> expr_to_key; + + /*! \brief Internal representation of a comparison operator */ + struct Comparison { + /*! \brief Construct a comparison that represents `lhs OP rhs + + * offset`, where the operation is specified by the CompareResult. + */ + Comparison(Key lhs, Key rhs, int64_t offset, CompareResult result); + + /*! \brief Utility function to validate that all GT and LT results + * have been normalized out + */ + bool IsNormalized() const; + + /*! \brief Move the specified expression to the LHS. + * + * \param new_lhs The argument that should be moved to the LHS of the + * comparison. + * + * \return If possible, returns a comparison that is equivalent to + * the current comparison, but with the specified LHS. If not + * possible, returns nullopt. + */ + std::optional<Comparison> WithLHS(Key new_lhs) const; + + /*! \brief Create the negation of the current comparison */ + Comparison Negated() const; + + /*! 
\brief Check the this comparison implies + * + * Returns true if this comparison being true implies that the + * other comparison must also be true. Returns false if the other + * comparison cannot be shown to be true. + */ + bool Implies(const Comparison& other) const; + + // The LHS of the comparison + Key lhs_; + + // The RHS of the comparison, not including any constant offset. + Key rhs_; + + // Additive offset on rhs + int64_t offset_{0}; + + // The comparison operator. + CompareResult result_{CompareResult::kInconsistent}; + }; + + /*! \brief Generate a Comparison representing the given expression */ + std::optional<Comparison> FromExpr(const PrimExpr& expr); + + /*! \brief Utility function used by Bind and EnterConstraint + * + * \param expr The comparison expression, to be converted into + * internal Comparison objects. + * + * \param vec The vector to which the Comparison objects should be + * appended. + */ + void AddKnown(const PrimExpr& expr, std::vector<Comparison>* vec); + + /*! \brief Attempt to compare, starting at the lhs. + * + * Taking each available `Comparison` as a node edge, search for a + * path from lhs to rhs. For example, the priors (a<=b), (b<=c+1) + * and (c<=d-5) can be used to prove that (a<=d-4). + * + * \param lhs The left-hand side of the comparison + * + * \param rhs The right-hand side of the comparison + * + * \return The result of the comparison + */ + CompareResult TryCompareFromLHS(Key lhs_key, Key rhs_key, int64_t offset, const PrimExpr& lhs, + const PrimExpr& rhs) const; + + /*! \brief Previous Range bindings + * + * Tracked separatedly to handle the `allow_override` option used by + * all sub-analyzers when binding variables. + */ + Map<Var, Range> prev_bindings_; + + /*! \brief Known comparisons based on definitionally-true statements + * + * For example, a Let binding, or the range of an iterator. + */ + std::vector<Comparison> knowns_; + + /*! 
\brief Known comparisons based on of scope-based statements + * + * For example, the condition of an IfThenElse, which is known to be + * true while within the if scope. + */ + std::vector<Comparison> scoped_knowns_; +}; + +namespace { + +// Internal utility, return the CompareResult resulting from swapping +// the left-hand side with the right-hand side. +CompareResult Reverse(CompareResult res) { + switch (res) { + case CompareResult::kInconsistent: + return CompareResult::kInconsistent; + case CompareResult::kEQ: + return CompareResult::kEQ; + case CompareResult::kLT: + return CompareResult::kGT; + case CompareResult::kLE: + return CompareResult::kGE; + case CompareResult::kGT: + return CompareResult::kLT; + case CompareResult::kGE: + return CompareResult::kLE; + case CompareResult::kNE: + return CompareResult::kNE; + case CompareResult::kUnknown: + return CompareResult::kUnknown; + default: + LOG(FATAL) << "Invalid CompareResult: " << static_cast<int>(res); + return CompareResult::kInconsistent; + } +} + +// Internal utility, return the CompareResult resulting from negating +// the comparison. +CompareResult Negate(CompareResult res) { + switch (res) { + case CompareResult::kInconsistent: + return CompareResult::kInconsistent; + case CompareResult::kUnknown: + return CompareResult::kUnknown; + default: + return CompareResult(~static_cast<int>(res) & static_cast<int>(CompareResult::kUnknown)); + } +} + +// Internal utility, extract constant offsets out of the two sides of +// a comparison. Given lhs and rhs, return a tuple of three elements +// (lhs_inner, rhs_inner, offset), such that (lhs OP rhs) and +// (lhs_inner OP rhs_inner + offset) are equivalent. 
+std::tuple<PrimExpr, PrimExpr, int64_t> ExtractOffsets(const PrimExpr& lhs, const PrimExpr& rhs) { + auto extract_offset = [](const PrimExpr& expr) -> std::pair<PrimExpr, int64_t> { + PVar<PrimExpr> x; + PVar<IntImm> c; + if ((x + c).Match(expr)) { + return {x.Eval(), c.Eval()->value}; + } else if ((x - c).Match(expr)) { + return {x.Eval(), -c.Eval()->value}; + } else if (c.Match(expr)) { + return {0, c.Eval()->value}; + } else { + return {expr, 0}; + } + }; + + auto lhs_split = extract_offset(lhs); + auto rhs_split = extract_offset(rhs); + return {lhs_split.first, rhs_split.first, rhs_split.second - lhs_split.second}; +} + +} // namespace + +std::optional<TransitiveComparisonAnalyzer::Impl::Comparison> +TransitiveComparisonAnalyzer::Impl::FromExpr(const PrimExpr& expr) { + CompareResult res; + PVar<PrimExpr> x, y; + if ((x <= y).Match(expr)) { + res = CompareResult::kLE; + } else if ((x >= y).Match(expr)) { + res = CompareResult::kGE; + } else if ((x < y).Match(expr)) { + res = CompareResult::kLT; + } else if ((x > y).Match(expr)) { + res = CompareResult::kGT; + } else if ((x == y).Match(expr)) { + res = CompareResult::kEQ; + } else if ((x != y).Match(expr)) { + res = CompareResult::kNE; + } else { + return std::nullopt; + } + + PrimExpr lhs_expr = x.Eval(); + PrimExpr rhs_expr = y.Eval(); + + if (lhs_expr.as<IntImmNode>() && rhs_expr.as<IntImmNode>()) { + return std::nullopt; + } + + auto [lhs, rhs, offset] = ExtractOffsets(lhs_expr, rhs_expr); + Key lhs_key = ExprToKey(lhs); + Key rhs_key = ExprToKey(rhs); + + return Comparison(lhs_key, rhs_key, offset, res); +} + +TransitiveComparisonAnalyzer::Impl::Comparison::Comparison(Key lhs, Key rhs, int64_t offset, + CompareResult result) + : lhs_(lhs), rhs_(rhs), offset_(offset), result_(result) { + if (result_ == CompareResult::kLT) { + result_ = CompareResult::kLE; + offset_ -= 1; + } + if (result_ == CompareResult::kGT) { + result_ = CompareResult::kGE; + offset_ += 1; + } +} + 
+std::optional<TransitiveComparisonAnalyzer::Impl::Key> +TransitiveComparisonAnalyzer::Impl::ExprToPreviousKey(const PrimExpr& expr) const { + auto it = expr_to_key.find(expr); + if (it != expr_to_key.end()) { + return it->second; + } else { + return std::nullopt; + } +} + +TransitiveComparisonAnalyzer::Impl::Key TransitiveComparisonAnalyzer::Impl::ExprToKey( + const PrimExpr& expr) { + if (auto prev = ExprToPreviousKey(expr)) { + return prev.value(); + } else { + Key new_key = Key(expr_to_key.size()); + expr_to_key[expr] = new_key; + return new_key; + } +} + +bool TransitiveComparisonAnalyzer::Impl::Comparison::IsNormalized() const { + // These < and > should be removed during normalization. + return result_ != CompareResult::kLT && result_ != CompareResult::kGT; +} + +std::optional<TransitiveComparisonAnalyzer::Impl::Comparison> +TransitiveComparisonAnalyzer::Impl::Comparison::WithLHS(Key new_lhs) const { + if (new_lhs == lhs_) { + return *this; + } else if (new_lhs == rhs_) { + return Comparison(rhs_, lhs_, -offset_, Reverse(result_)); + } else { + return std::nullopt; + } +} + +TransitiveComparisonAnalyzer::Impl::Comparison +TransitiveComparisonAnalyzer::Impl::Comparison::Negated() const { + return Comparison(lhs_, rhs_, offset_, Negate(result_)); +} + +bool TransitiveComparisonAnalyzer::Impl::Comparison::Implies( + const TransitiveComparisonAnalyzer::Impl::Comparison& other) const { + ICHECK(lhs_ == other.lhs_); + ICHECK(rhs_ == other.rhs_); + ICHECK(IsNormalized()); + ICHECK(other.IsNormalized()); + + if (result_ == other.result_ && offset_ == other.offset_) { + // if c1 == c2, x != y + c1 => x != y + c2 + // if c1 == c2, x == y + c1 => x == y + c2 + return true; + } + + if (other.result_ == CompareResult::kLE && offset_ <= other.offset_) { + if (result_ == CompareResult::kEQ || result_ == CompareResult::kLE) { + // if c1 <= c2, x <= y + c1 => x <= y + c2 + // if c1 <= c2, x == y + c1 => x <= y + c2 + return true; + } + } + + if (other.result_ == 
CompareResult::kGE && offset_ >= other.offset_) { + if (result_ == CompareResult::kEQ || result_ == CompareResult::kGE) { + // if c1 >= c2, x == y + c1 => x >= y + c2 + // if c1 >= c2, x >= y + c1 => x >= y + c2 + return true; + } + } + + if (other.result_ == CompareResult::kNE) { + if (result_ == CompareResult::kEQ && offset_ != other.offset_) { + // if c1 != c2, x == y + c1 => x != y + c2 + return true; + } + + if (result_ == CompareResult::kLE && offset_ < other.offset_) { + // if c1 < c2, x <= y + c1 => x < y + c2 => x != y + c2 + return true; + } + + if (result_ == CompareResult::kGE && offset_ > other.offset_) { + // if c1 != c2, x >= y + c1 => x > y + c2 => x != y + c2 + return true; + } + } + + return false; +} + +TransitiveComparisonAnalyzer::TransitiveComparisonAnalyzer() : impl_(std::make_unique<Impl>()) {} +TransitiveComparisonAnalyzer::~TransitiveComparisonAnalyzer() {} + +CompareResult TransitiveComparisonAnalyzer::TryCompare(const PrimExpr& lhs, const PrimExpr& rhs) { + return impl_->TryCompare(lhs, rhs); +} + +void TransitiveComparisonAnalyzer::Bind(const Var& var, const PrimExpr& expr, bool allow_override) { + impl_->Bind(var, expr, allow_override); +} +void TransitiveComparisonAnalyzer::Bind(const Var& var, const Range& range, bool allow_override) { + impl_->Bind(var, range, allow_override); +} + +std::function<void()> TransitiveComparisonAnalyzer::EnterConstraint(const PrimExpr& constraint) { + return impl_->EnterConstraint(constraint); +} + +void TransitiveComparisonAnalyzer::Impl::AddKnown(const PrimExpr& expr, + std::vector<Comparison>* vec) { + for (const auto& subexpr : ExtractConstraints(expr)) { + if (tir::SideEffect(expr) <= tir::CallEffectKind::kPure) { + if (auto cmp = FromExpr(subexpr)) { + vec->push_back(cmp.value()); + } + } + } +} + +void TransitiveComparisonAnalyzer::Impl::Bind(const tir::Var& var, const Range& range, + bool allow_override) { + auto it = prev_bindings_.find(var); + if (it != prev_bindings_.end()) { + ExprDeepEqual 
expr_equal; + bool differs_from_previous = !expr_equal(range->min, (*it).second->min) || + !expr_equal(range->extent, (*it).second->extent); + if (differs_from_previous) { + ICHECK(allow_override) << "Binding of variable " << var << " as " << range + << " conflicts with previous binding as " << (*it).second; + if (auto key = ExprToPreviousKey(var)) { + knowns_.erase(std::remove_if(knowns_.begin(), knowns_.end(), + [&](const auto& known) { return known.lhs_ == key.value(); }), + knowns_.end()); + } + } + } + + prev_bindings_.Set(var, range); + + if (is_const_int(range->extent, 1)) { + AddKnown(var == range->min, &knowns_); + } else { + AddKnown(var >= range->min, &knowns_); + AddKnown(var < range->min + range->extent, &knowns_); + } +} + +void TransitiveComparisonAnalyzer::Impl::Bind(const tir::Var& var, const PrimExpr& expr, + bool allow_override) { + Bind(var, Range::FromMinExtent(expr, 1), allow_override); +} + +std::function<void()> TransitiveComparisonAnalyzer::Impl::EnterConstraint(const PrimExpr& expr) { + size_t old_literal_size = scoped_knowns_.size(); + AddKnown(expr, &scoped_knowns_); + size_t new_literal_size = scoped_knowns_.size(); + + PrimExpr temp = expr; + auto frecover = [old_literal_size, new_literal_size, this, temp]() { + ICHECK_EQ(scoped_knowns_.size(), new_literal_size); + scoped_knowns_.erase(scoped_knowns_.begin() + old_literal_size, scoped_knowns_.end()); + }; + return frecover; +} + +CompareResult TransitiveComparisonAnalyzer::Impl::TryCompare(const PrimExpr& lhs_expr, + const PrimExpr& rhs_expr) const { + // Currently only supports integer checks + if (!lhs_expr.dtype().is_int() || !rhs_expr.dtype().is_int()) { + return CompareResult::kUnknown; + } + + // Bail out early if possible. This int check should have been + // constant-folded earlier, so this check shouldn't occur. 
+ auto* x_int = lhs_expr.as<IntImmNode>(); + auto* y_int = rhs_expr.as<IntImmNode>(); + if (x_int && y_int) { + if (x_int->value < y_int->value) { + return CompareResult::kLT; + } else if (x_int->value > y_int->value) { + return CompareResult::kGT; + } else { + return CompareResult::kEQ; + } + } + + auto [lhs, rhs, offset] = ExtractOffsets(lhs_expr, rhs_expr); + auto lhs_key = ExprToPreviousKey(lhs); + auto rhs_key = ExprToPreviousKey(rhs); + + if (!lhs_key.has_value() || !rhs_key.has_value()) { + return CompareResult::kUnknown; + } + + auto from_lhs = TryCompareFromLHS(lhs_key.value(), rhs_key.value(), offset, lhs, rhs); + auto from_rhs = Reverse(TryCompareFromLHS(rhs_key.value(), lhs_key.value(), -offset, rhs, lhs)); + auto output = from_lhs & from_rhs; + + return output; +} + +CompareResult TransitiveComparisonAnalyzer::Impl::TryCompareFromLHS( + Key lhs_key_input, Key rhs_key_input, int64_t offset_input, const PrimExpr& lhs_input, + const PrimExpr& rhs_input) const { + Key lhs_key = lhs_key_input; + Key rhs_key = rhs_key_input; + int64_t offset = offset_input; + + // Everything in `to_visit` has lhs as its lhs. + std::unordered_set<Key> seen; + std::unordered_set<Key> to_visit; + std::unordered_map<Key, std::vector<Comparison>> compared_to_x; + + // Utility function to add a new known statement + auto declare_known = [&](Comparison cmp) { + auto& prev_knowns = compared_to_x[cmp.rhs_]; Review Comment: I scratched my head for a while over whether `compared_to_x` only ever contained default-initialized vectors, until I noticed that you are updating the map value through a reference. It might have helped to spell out the type as `std::vector<Comparison>&` to draw attention to the fact that this is a reference into the container 🤷 ########## src/arith/transitive_comparison_analyzer.cc: ########## @@ -0,0 +1,683 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements.
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file tvm/arith/transitive_comparison_analyzer.cc + */ + +#include <tvm/arith/analyzer.h> +#include <tvm/tir/analysis.h> +#include <tvm/tir/expr.h> + +#include <optional> +#include <vector> + +#include "constraint_extract.h" +#include "pattern_match.h" + +namespace tvm { +namespace arith { + +using namespace tir; + +class TransitiveComparisonAnalyzer::Impl { + public: + /* \brief Using previously specified knowns, compare the expressions provided + * + * \param lhs The left-hand side of the comparison + * + * \param rhs The right-hand side of the comparison + * + * \return The most specific result that can be proven about the + * comparison. If nothing can be proven, returns kUnknown. + */ + CompareResult TryCompare(const PrimExpr& lhs, const PrimExpr& rhs) const; + + /*! \brief Bind a variable as being equal to a known expression + * + * \param var The variable of interest. + * \param expr The bound expression + * \param allow_override Whether to allow override of existing information. + */ + void Bind(const tir::Var& var, const PrimExpr& expr, bool allow_override = false); + + /*! \brief Bind a variable as being within a specified range + * + * \param var The variable of interest. 
+ * \param range The known range + * \param allow_override Whether to allow override of existing information. + */ + void Bind(const tir::Var& var, const Range& expr, bool allow_override = false); + + /*! + * \brief Update the internal state to enter constraint. + * \param constraint A constraint expression. + * + * \return An exit function that must be called to cleanup. May be + * `nullptr`, if no cleanup is required. + */ + std::function<void()> EnterConstraint(const PrimExpr& expr); + + private: + // Utility class to avoid needing to repeatedly call ExprDeepEqual + enum class Key : size_t {}; + std::optional<Key> ExprToPreviousKey(const PrimExpr& expr) const; + Key ExprToKey(const PrimExpr& expr); + std::unordered_map<PrimExpr, Key, StructuralHash, StructuralEqual> expr_to_key; + + /*! \brief Internal representation of a comparison operator */ + struct Comparison { + /*! \brief Construct a comparison that represents `lhs OP rhs + + * offset`, where the operation is specified by the CompareResult. + */ + Comparison(Key lhs, Key rhs, int64_t offset, CompareResult result); + + /*! \brief Utility function to validate that all GT and LT results + * have been normalized out + */ + bool IsNormalized() const; + + /*! \brief Move the specified expression to the LHS. + * + * \param new_lhs The argument that should be moved to the LHS of the + * comparison. + * + * \return If possible, returns a comparison that is equivalent to + * the current comparison, but with the specified LHS. If not + * possible, returns nullopt. + */ + std::optional<Comparison> WithLHS(Key new_lhs) const; + + /*! \brief Create the negation of the current comparison */ + Comparison Negated() const; + + /*! \brief Check the this comparison implies + * + * Returns true if this comparison being true implies that the + * other comparison must also be true. Returns false if the other + * comparison cannot be shown to be true. 
+ */ + bool Implies(const Comparison& other) const; + + // The LHS of the comparison + Key lhs_; + + // The RHS of the comparison, not including any constant offset. + Key rhs_; + + // Additive offset on rhs + int64_t offset_{0}; + + // The comparison operator. + CompareResult result_{CompareResult::kInconsistent}; + }; + + /*! \brief Generate a Comparison representing the given expression */ + std::optional<Comparison> FromExpr(const PrimExpr& expr); + + /*! \brief Utility function used by Bind and EnterConstraint + * + * \param expr The comparison expression, to be converted into + * internal Comparison objects. + * + * \param vec The vector to which the Comparison objects should be + * appended. + */ + void AddKnown(const PrimExpr& expr, std::vector<Comparison>* vec); + + /*! \brief Attempt to compare, starting at the lhs. + * + * Taking each available `Comparison` as a node edge, search for a + * path from lhs to rhs. For example, the priors (a<=b), (b<=c+1) + * and (c<=d-5) can be used to prove that (a<=d-4). + * + * \param lhs The left-hand side of the comparison + * + * \param rhs The right-hand side of the comparison + * + * \return The result of the comparison + */ + CompareResult TryCompareFromLHS(Key lhs_key, Key rhs_key, int64_t offset, const PrimExpr& lhs, + const PrimExpr& rhs) const; + + /*! \brief Previous Range bindings + * + * Tracked separatedly to handle the `allow_override` option used by + * all sub-analyzers when binding variables. + */ + Map<Var, Range> prev_bindings_; + + /*! \brief Known comparisons based on definitionally-true statements + * + * For example, a Let binding, or the range of an iterator. + */ + std::vector<Comparison> knowns_; + + /*! \brief Known comparisons based on of scope-based statements + * + * For example, the condition of an IfThenElse, which is known to be + * true while within the if scope. 
+ */ + std::vector<Comparison> scoped_knowns_; +}; + +namespace { + +// Internal utility, return the CompareResult resulting from swapping +// the left-hand side with the right-hand side. +CompareResult Reverse(CompareResult res) { + switch (res) { + case CompareResult::kInconsistent: + return CompareResult::kInconsistent; + case CompareResult::kEQ: + return CompareResult::kEQ; + case CompareResult::kLT: + return CompareResult::kGT; + case CompareResult::kLE: + return CompareResult::kGE; + case CompareResult::kGT: + return CompareResult::kLT; + case CompareResult::kGE: + return CompareResult::kLE; + case CompareResult::kNE: + return CompareResult::kNE; + case CompareResult::kUnknown: + return CompareResult::kUnknown; + default: + LOG(FATAL) << "Invalid CompareResult: " << static_cast<int>(res); + return CompareResult::kInconsistent; + } +} + +// Internal utility, return the CompareResult resulting from negating +// the comparison. +CompareResult Negate(CompareResult res) { + switch (res) { + case CompareResult::kInconsistent: + return CompareResult::kInconsistent; + case CompareResult::kUnknown: + return CompareResult::kUnknown; + default: + return CompareResult(~static_cast<int>(res) & static_cast<int>(CompareResult::kUnknown)); + } +} + +// Internal utility, extract constant offsets out of the two sides of +// a comparison. Given lhs and rhs, return a tuple of three elements +// (lhs_inner, rhs_inner, offset), such that (lhs OP rhs) and +// (lhs_inner OP rhs_inner + offset) are equivalent. 
+std::tuple<PrimExpr, PrimExpr, int64_t> ExtractOffsets(const PrimExpr& lhs, const PrimExpr& rhs) {
+  // Computes a - b, or returns nullopt if the result does not fit in an
+  // int64_t.  Signed overflow is undefined behavior, so the subtraction
+  // is performed in (wrapping) unsigned arithmetic and overflow is then
+  // detected from the signs of the operands and of the result.
+  auto checked_sub = [](int64_t a, int64_t b) -> std::optional<int64_t> {
+    int64_t diff = static_cast<int64_t>(static_cast<uint64_t>(a) - static_cast<uint64_t>(b));
+    // Overflow occurred iff a and b differ in sign and the result's sign differs from a's.
+    if (((a ^ b) & (a ^ diff)) < 0) {
+      return std::nullopt;
+    }
+    return diff;
+  };
+
+  // Split `expr` into (inner, offset) such that expr == inner + offset.
+  auto extract_offset = [&checked_sub](const PrimExpr& expr) -> std::pair<PrimExpr, int64_t> {
+    PVar<PrimExpr> x;
+    PVar<IntImm> c;
+    if ((x + c).Match(expr)) {
+      return {x.Eval(), c.Eval()->value};
+    } else if ((x - c).Match(expr)) {
+      // `x - c` is `x + (-c)`, unless -c is not representable (c == INT64_MIN),
+      // in which case the expression is kept whole.
+      if (auto neg_c = checked_sub(0, c.Eval()->value)) {
+        return {x.Eval(), neg_c.value()};
+      }
+      return {expr, 0};
+    } else if (c.Match(expr)) {
+      return {0, c.Eval()->value};
+    } else {
+      return {expr, 0};
+    }
+  };
+
+  auto [lhs_inner, lhs_offset] = extract_offset(lhs);
+  auto [rhs_inner, rhs_offset] = extract_offset(rhs);
+  if (auto offset = checked_sub(rhs_offset, lhs_offset)) {
+    return {lhs_inner, rhs_inner, offset.value()};
+  }
+  // The combined offset is not representable.  Comparing the original
+  // expressions with a zero offset is still an equivalent comparison.
+  return {lhs, rhs, 0};
+}
+
+}  // namespace
+
+std::optional<TransitiveComparisonAnalyzer::Impl::Comparison>
+TransitiveComparisonAnalyzer::Impl::FromExpr(const PrimExpr& expr) {
+  CompareResult res;
+  PVar<PrimExpr> x, y;
+  if ((x <= y).Match(expr)) {
+    res = CompareResult::kLE;
+  } else if ((x >= y).Match(expr)) {
+    res = CompareResult::kGE;
+  } else if ((x < y).Match(expr)) {
+    res = CompareResult::kLT;
+  } else if ((x > y).Match(expr)) {
+    res = CompareResult::kGT;
+  } else if ((x == y).Match(expr)) {
+    res = CompareResult::kEQ;
+  } else if ((x != y).Match(expr)) {
+    res = CompareResult::kNE;
+  } else {
+    return std::nullopt;
+  }
+
+  PrimExpr lhs_expr = x.Eval();
+  PrimExpr rhs_expr = y.Eval();
+
+  // Only integer comparisons are recorded.  TryCompare only answers
+  // integer queries, and the strict-to-non-strict rewrite performed by
+  // the Comparison constructor is only valid for integer operands.
+  if (!lhs_expr.dtype().is_int() || !rhs_expr.dtype().is_int()) {
+    return std::nullopt;
+  }
+
+  // A comparison between two constants carries no information about
+  // any other expression, so it is not useful as a known.
+  if (lhs_expr.as<IntImmNode>() && rhs_expr.as<IntImmNode>()) {
+    return std::nullopt;
+  }
+
+  auto [lhs, rhs, offset] = ExtractOffsets(lhs_expr, rhs_expr);
+  Key lhs_key = ExprToKey(lhs);
+  Key rhs_key = ExprToKey(rhs);
+
+  return Comparison(lhs_key, rhs_key, offset, res);
+}
+
+TransitiveComparisonAnalyzer::Impl::Comparison::Comparison(Key lhs, Key rhs, int64_t offset,
+                                                           CompareResult result)
+    : lhs_(lhs), rhs_(rhs), offset_(offset), result_(result) {
+  // Normalize strict inequalities into non-strict ones.  For integer
+  // operands, `x < y + c` is equivalent to `x <= y + (c - 1)` and
+  // `x > y + c` is equivalent to `x >= y + (c + 1)`.  Chaining
+  // comparisons transitively then only needs rules for <=, >=, == and !=.
+  if (result_ == CompareResult::kLT) {
+    result_ = CompareResult::kLE;
+    offset_ -= 1;
+  }
+  if (result_ == CompareResult::kGT) {
+    result_ = CompareResult::kGE;
+    offset_ += 1;
+  }
+}
+
+std::optional<TransitiveComparisonAnalyzer::Impl::Key> +TransitiveComparisonAnalyzer::Impl::ExprToPreviousKey(const PrimExpr& expr) const { + auto it = expr_to_key.find(expr); + if (it != expr_to_key.end()) { + return it->second; + } else { + return std::nullopt; + } +} + +TransitiveComparisonAnalyzer::Impl::Key TransitiveComparisonAnalyzer::Impl::ExprToKey( + const PrimExpr& expr) { + if (auto prev = ExprToPreviousKey(expr)) { + return prev.value(); + } else { + Key new_key = Key(expr_to_key.size()); + expr_to_key[expr] = new_key; + return new_key; + } +} + +bool TransitiveComparisonAnalyzer::Impl::Comparison::IsNormalized() const { + // These < and > should be removed during normalization. + return result_ != CompareResult::kLT && result_ != CompareResult::kGT; +} + +std::optional<TransitiveComparisonAnalyzer::Impl::Comparison> +TransitiveComparisonAnalyzer::Impl::Comparison::WithLHS(Key new_lhs) const { + if (new_lhs == lhs_) { + return *this; + } else if (new_lhs == rhs_) { + return Comparison(rhs_, lhs_, -offset_, Reverse(result_)); + } else { + return std::nullopt; + } +} + +TransitiveComparisonAnalyzer::Impl::Comparison +TransitiveComparisonAnalyzer::Impl::Comparison::Negated() const { + return Comparison(lhs_, rhs_, offset_, Negate(result_)); +} + +bool TransitiveComparisonAnalyzer::Impl::Comparison::Implies( + const TransitiveComparisonAnalyzer::Impl::Comparison& other) const { + ICHECK(lhs_ == other.lhs_); + ICHECK(rhs_ == other.rhs_); + ICHECK(IsNormalized()); + ICHECK(other.IsNormalized()); + + if (result_ == other.result_ && offset_ == other.offset_) { + // if c1 == c2, x != y + c1 => x != y + c2 + // if c1 == c2, x == y + c1 => x == y + c2 + return true; + } + + if (other.result_ == CompareResult::kLE && offset_ <= other.offset_) { + if (result_ == CompareResult::kEQ || result_ == CompareResult::kLE) { + // if c1 <= c2, x <= y + c1 => x <= y + c2 + // if c1 <= c2, x == y + c1 => x <= y + c2 + return true; + } + } + + if (other.result_ == 
CompareResult::kGE && offset_ >= other.offset_) { + if (result_ == CompareResult::kEQ || result_ == CompareResult::kGE) { + // if c1 >= c2, x == y + c1 => x >= y + c2 + // if c1 >= c2, x >= y + c1 => x >= y + c2 + return true; + } + } + + if (other.result_ == CompareResult::kNE) { + if (result_ == CompareResult::kEQ && offset_ != other.offset_) { + // if c1 != c2, x == y + c1 => x != y + c2 + return true; + } + + if (result_ == CompareResult::kLE && offset_ < other.offset_) { + // if c1 < c2, x <= y + c1 => x < y + c2 => x != y + c2 + return true; + } + + if (result_ == CompareResult::kGE && offset_ > other.offset_) { + // if c1 != c2, x >= y + c1 => x > y + c2 => x != y + c2 + return true; + } + } + + return false; +} + +TransitiveComparisonAnalyzer::TransitiveComparisonAnalyzer() : impl_(std::make_unique<Impl>()) {} +TransitiveComparisonAnalyzer::~TransitiveComparisonAnalyzer() {} + +CompareResult TransitiveComparisonAnalyzer::TryCompare(const PrimExpr& lhs, const PrimExpr& rhs) { + return impl_->TryCompare(lhs, rhs); +} + +void TransitiveComparisonAnalyzer::Bind(const Var& var, const PrimExpr& expr, bool allow_override) { + impl_->Bind(var, expr, allow_override); +} +void TransitiveComparisonAnalyzer::Bind(const Var& var, const Range& range, bool allow_override) { + impl_->Bind(var, range, allow_override); +} + +std::function<void()> TransitiveComparisonAnalyzer::EnterConstraint(const PrimExpr& constraint) { + return impl_->EnterConstraint(constraint); +} + +void TransitiveComparisonAnalyzer::Impl::AddKnown(const PrimExpr& expr, + std::vector<Comparison>* vec) { + for (const auto& subexpr : ExtractConstraints(expr)) { + if (tir::SideEffect(expr) <= tir::CallEffectKind::kPure) { + if (auto cmp = FromExpr(subexpr)) { + vec->push_back(cmp.value()); + } + } + } +} + +void TransitiveComparisonAnalyzer::Impl::Bind(const tir::Var& var, const Range& range, + bool allow_override) { + auto it = prev_bindings_.find(var); + if (it != prev_bindings_.end()) { + ExprDeepEqual 
expr_equal; + bool differs_from_previous = !expr_equal(range->min, (*it).second->min) || + !expr_equal(range->extent, (*it).second->extent); + if (differs_from_previous) { + ICHECK(allow_override) << "Binding of variable " << var << " as " << range + << " conflicts with previous binding as " << (*it).second; + if (auto key = ExprToPreviousKey(var)) { + knowns_.erase(std::remove_if(knowns_.begin(), knowns_.end(), + [&](const auto& known) { return known.lhs_ == key.value(); }), + knowns_.end()); + } + } + } + + prev_bindings_.Set(var, range); + + if (is_const_int(range->extent, 1)) { + AddKnown(var == range->min, &knowns_); + } else { + AddKnown(var >= range->min, &knowns_); + AddKnown(var < range->min + range->extent, &knowns_); + } +} + +void TransitiveComparisonAnalyzer::Impl::Bind(const tir::Var& var, const PrimExpr& expr, + bool allow_override) { + Bind(var, Range::FromMinExtent(expr, 1), allow_override); +} + +std::function<void()> TransitiveComparisonAnalyzer::Impl::EnterConstraint(const PrimExpr& expr) { + size_t old_literal_size = scoped_knowns_.size(); + AddKnown(expr, &scoped_knowns_); + size_t new_literal_size = scoped_knowns_.size(); + + PrimExpr temp = expr; + auto frecover = [old_literal_size, new_literal_size, this, temp]() { + ICHECK_EQ(scoped_knowns_.size(), new_literal_size); + scoped_knowns_.erase(scoped_knowns_.begin() + old_literal_size, scoped_knowns_.end()); + }; + return frecover; +} + +CompareResult TransitiveComparisonAnalyzer::Impl::TryCompare(const PrimExpr& lhs_expr, + const PrimExpr& rhs_expr) const { + // Currently only supports integer checks + if (!lhs_expr.dtype().is_int() || !rhs_expr.dtype().is_int()) { + return CompareResult::kUnknown; + } + + // Bail out early if possible. This int check should have been + // constant-folded earlier, so this check shouldn't occur. 
+ auto* x_int = lhs_expr.as<IntImmNode>(); + auto* y_int = rhs_expr.as<IntImmNode>(); + if (x_int && y_int) { + if (x_int->value < y_int->value) { + return CompareResult::kLT; + } else if (x_int->value > y_int->value) { + return CompareResult::kGT; + } else { + return CompareResult::kEQ; + } + } + + auto [lhs, rhs, offset] = ExtractOffsets(lhs_expr, rhs_expr); + auto lhs_key = ExprToPreviousKey(lhs); + auto rhs_key = ExprToPreviousKey(rhs); + + if (!lhs_key.has_value() || !rhs_key.has_value()) { + return CompareResult::kUnknown; + } + + auto from_lhs = TryCompareFromLHS(lhs_key.value(), rhs_key.value(), offset, lhs, rhs); + auto from_rhs = Reverse(TryCompareFromLHS(rhs_key.value(), lhs_key.value(), -offset, rhs, lhs)); + auto output = from_lhs & from_rhs; + + return output; +} + +CompareResult TransitiveComparisonAnalyzer::Impl::TryCompareFromLHS( Review Comment: nit: TryCompareFromLHS is a bit long ########## src/arith/transitive_comparison_analyzer.cc: ########## @@ -0,0 +1,683 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! 
+ * \file tvm/arith/transitive_comparison_analyzer.cc + */ + +#include <tvm/arith/analyzer.h> +#include <tvm/tir/analysis.h> +#include <tvm/tir/expr.h> + +#include <optional> +#include <vector> + +#include "constraint_extract.h" +#include "pattern_match.h" + +namespace tvm { +namespace arith { + +using namespace tir; + +class TransitiveComparisonAnalyzer::Impl { + public: + /* \brief Using previously specified knowns, compare the expressions provided + * + * \param lhs The left-hand side of the comparison + * + * \param rhs The right-hand side of the comparison + * + * \return The most specific result that can be proven about the + * comparison. If nothing can be proven, returns kUnknown. + */ + CompareResult TryCompare(const PrimExpr& lhs, const PrimExpr& rhs) const; + + /*! \brief Bind a variable as being equal to a known expression + * + * \param var The variable of interest. + * \param expr The bound expression + * \param allow_override Whether to allow override of existing information. + */ + void Bind(const tir::Var& var, const PrimExpr& expr, bool allow_override = false); + + /*! \brief Bind a variable as being within a specified range + * + * \param var The variable of interest. + * \param range The known range + * \param allow_override Whether to allow override of existing information. + */ + void Bind(const tir::Var& var, const Range& expr, bool allow_override = false); + + /*! + * \brief Update the internal state to enter constraint. + * \param constraint A constraint expression. + * + * \return An exit function that must be called to cleanup. May be + * `nullptr`, if no cleanup is required. 
+ */ + std::function<void()> EnterConstraint(const PrimExpr& expr); + + private: + // Utility class to avoid needing to repeatedly call ExprDeepEqual + enum class Key : size_t {}; + std::optional<Key> ExprToPreviousKey(const PrimExpr& expr) const; + Key ExprToKey(const PrimExpr& expr); + std::unordered_map<PrimExpr, Key, StructuralHash, StructuralEqual> expr_to_key; + + /*! \brief Internal representation of a comparison operator */ + struct Comparison { + /*! \brief Construct a comparison that represents `lhs OP rhs + + * offset`, where the operation is specified by the CompareResult. + */ + Comparison(Key lhs, Key rhs, int64_t offset, CompareResult result); + + /*! \brief Utility function to validate that all GT and LT results + * have been normalized out + */ + bool IsNormalized() const; + + /*! \brief Move the specified expression to the LHS. + * + * \param new_lhs The argument that should be moved to the LHS of the + * comparison. + * + * \return If possible, returns a comparison that is equivalent to + * the current comparison, but with the specified LHS. If not + * possible, returns nullopt. + */ + std::optional<Comparison> WithLHS(Key new_lhs) const; + + /*! \brief Create the negation of the current comparison */ + Comparison Negated() const; + + /*! \brief Check the this comparison implies + * + * Returns true if this comparison being true implies that the + * other comparison must also be true. Returns false if the other + * comparison cannot be shown to be true. + */ + bool Implies(const Comparison& other) const; + + // The LHS of the comparison + Key lhs_; + + // The RHS of the comparison, not including any constant offset. + Key rhs_; + + // Additive offset on rhs + int64_t offset_{0}; + + // The comparison operator. + CompareResult result_{CompareResult::kInconsistent}; + }; + + /*! \brief Generate a Comparison representing the given expression */ + std::optional<Comparison> FromExpr(const PrimExpr& expr); + + /*! 
\brief Utility function used by Bind and EnterConstraint + * + * \param expr The comparison expression, to be converted into + * internal Comparison objects. + * + * \param vec The vector to which the Comparison objects should be + * appended. + */ + void AddKnown(const PrimExpr& expr, std::vector<Comparison>* vec); + + /*! \brief Attempt to compare, starting at the lhs. + * + * Taking each available `Comparison` as a node edge, search for a + * path from lhs to rhs. For example, the priors (a<=b), (b<=c+1) + * and (c<=d-5) can be used to prove that (a<=d-4). + * + * \param lhs The left-hand side of the comparison + * + * \param rhs The right-hand side of the comparison + * + * \return The result of the comparison + */ + CompareResult TryCompareFromLHS(Key lhs_key, Key rhs_key, int64_t offset, const PrimExpr& lhs, + const PrimExpr& rhs) const; + + /*! \brief Previous Range bindings + * + * Tracked separatedly to handle the `allow_override` option used by + * all sub-analyzers when binding variables. + */ + Map<Var, Range> prev_bindings_; + + /*! \brief Known comparisons based on definitionally-true statements + * + * For example, a Let binding, or the range of an iterator. + */ + std::vector<Comparison> knowns_; + + /*! \brief Known comparisons based on of scope-based statements + * + * For example, the condition of an IfThenElse, which is known to be + * true while within the if scope. + */ + std::vector<Comparison> scoped_knowns_; +}; + +namespace { + +// Internal utility, return the CompareResult resulting from swapping +// the left-hand side with the right-hand side. 
+CompareResult Reverse(CompareResult res) {
+  switch (res) {
+    // Results that are symmetric in their operands are unaffected by the swap.
+    case CompareResult::kInconsistent:
+    case CompareResult::kEQ:
+    case CompareResult::kNE:
+    case CompareResult::kUnknown:
+      return res;
+    // Directional results flip to their mirror image.
+    case CompareResult::kLT:
+      return CompareResult::kGT;
+    case CompareResult::kGT:
+      return CompareResult::kLT;
+    case CompareResult::kLE:
+      return CompareResult::kGE;
+    case CompareResult::kGE:
+      return CompareResult::kLE;
+    default:
+      LOG(FATAL) << "Invalid CompareResult: " << static_cast<int>(res);
+      return CompareResult::kInconsistent;
+  }
+}
+
+// Internal utility, return the CompareResult resulting from negating
+// the comparison.  kInconsistent and kUnknown are returned unchanged;
+// every other result is bitwise-complemented within the bits of kUnknown.
+CompareResult Negate(CompareResult res) {
+  if (res == CompareResult::kInconsistent || res == CompareResult::kUnknown) {
+    return res;
+  }
+  const int mask = static_cast<int>(CompareResult::kUnknown);
+  return CompareResult(mask & ~static_cast<int>(res));
+}
+
+// Internal utility, extract constant offsets out of the two sides of
+// a comparison.  Given lhs and rhs, return a tuple of three elements
+// (lhs_inner, rhs_inner, offset), such that (lhs OP rhs) and
+// (lhs_inner OP rhs_inner + offset) are equivalent.
+std::tuple<PrimExpr, PrimExpr, int64_t> ExtractOffsets(const PrimExpr& lhs, const PrimExpr& rhs) {
+  // Computes a - b, or returns nullopt if the result does not fit in an
+  // int64_t.  Signed overflow is undefined behavior, so the subtraction
+  // is performed in (wrapping) unsigned arithmetic and overflow is then
+  // detected from the signs of the operands and of the result.
+  auto checked_sub = [](int64_t a, int64_t b) -> std::optional<int64_t> {
+    int64_t diff = static_cast<int64_t>(static_cast<uint64_t>(a) - static_cast<uint64_t>(b));
+    // Overflow occurred iff a and b differ in sign and the result's sign differs from a's.
+    if (((a ^ b) & (a ^ diff)) < 0) {
+      return std::nullopt;
+    }
+    return diff;
+  };
+
+  // Split `expr` into (inner, offset) such that expr == inner + offset.
+  auto extract_offset = [&checked_sub](const PrimExpr& expr) -> std::pair<PrimExpr, int64_t> {
+    PVar<PrimExpr> x;
+    PVar<IntImm> c;
+    if ((x + c).Match(expr)) {
+      return {x.Eval(), c.Eval()->value};
+    } else if ((x - c).Match(expr)) {
+      // `x - c` is `x + (-c)`, unless -c is not representable (c == INT64_MIN),
+      // in which case the expression is kept whole.
+      if (auto neg_c = checked_sub(0, c.Eval()->value)) {
+        return {x.Eval(), neg_c.value()};
+      }
+      return {expr, 0};
+    } else if (c.Match(expr)) {
+      return {0, c.Eval()->value};
+    } else {
+      return {expr, 0};
+    }
+  };
+
+  auto [lhs_inner, lhs_offset] = extract_offset(lhs);
+  auto [rhs_inner, rhs_offset] = extract_offset(rhs);
+  if (auto offset = checked_sub(rhs_offset, lhs_offset)) {
+    return {lhs_inner, rhs_inner, offset.value()};
+  }
+  // The combined offset is not representable.  Comparing the original
+  // expressions with a zero offset is still an equivalent comparison.
+  return {lhs, rhs, 0};
+}
+
+}  // namespace
+
+std::optional<TransitiveComparisonAnalyzer::Impl::Comparison>
+TransitiveComparisonAnalyzer::Impl::FromExpr(const PrimExpr& expr) {
+  CompareResult res;
+  PVar<PrimExpr> x, y;
+  if ((x <= y).Match(expr)) {
+    res = CompareResult::kLE;
+  } else if ((x >= y).Match(expr)) {
+    res = CompareResult::kGE;
+  } else if ((x < y).Match(expr)) {
+    res = CompareResult::kLT;
+  } else if ((x > y).Match(expr)) {
+    res = CompareResult::kGT;
+  } else if ((x == y).Match(expr)) {
+    res = CompareResult::kEQ;
+  } else if ((x != y).Match(expr)) {
+    res = CompareResult::kNE;
+  } else {
+    return std::nullopt;
+  }
+
+  PrimExpr lhs_expr = x.Eval();
+  PrimExpr rhs_expr = y.Eval();
+
+  // Only integer comparisons are recorded.  TryCompare only answers
+  // integer queries, and the strict-to-non-strict rewrite performed by
+  // the Comparison constructor is only valid for integer operands.
+  if (!lhs_expr.dtype().is_int() || !rhs_expr.dtype().is_int()) {
+    return std::nullopt;
+  }
+
+  // A comparison between two constants carries no information about
+  // any other expression, so it is not useful as a known.
+  if (lhs_expr.as<IntImmNode>() && rhs_expr.as<IntImmNode>()) {
+    return std::nullopt;
+  }
+
+  auto [lhs, rhs, offset] = ExtractOffsets(lhs_expr, rhs_expr);
+  Key lhs_key = ExprToKey(lhs);
+  Key rhs_key = ExprToKey(rhs);
+
+  return Comparison(lhs_key, rhs_key, offset, res);
+}
+
+TransitiveComparisonAnalyzer::Impl::Comparison::Comparison(Key lhs, Key rhs, int64_t offset,
+                                                           CompareResult result)
+    : lhs_(lhs), rhs_(rhs), offset_(offset), result_(result) {
+  // Normalize strict inequalities into non-strict ones.  For integer
+  // operands, `x < y + c` is equivalent to `x <= y + (c - 1)` and
+  // `x > y + c` is equivalent to `x >= y + (c + 1)`.  Chaining
+  // comparisons transitively then only needs rules for <=, >=, == and !=.
+  if (result_ == CompareResult::kLT) {
+    result_ = CompareResult::kLE;
+    offset_ -= 1;
+  }
+  if (result_ == CompareResult::kGT) {
+    result_ = CompareResult::kGE;
+    offset_ += 1;
+  }
+}
+
+std::optional<TransitiveComparisonAnalyzer::Impl::Key> +TransitiveComparisonAnalyzer::Impl::ExprToPreviousKey(const PrimExpr& expr) const { + auto it = expr_to_key.find(expr); + if (it != expr_to_key.end()) { + return it->second; + } else { + return std::nullopt; + } +} + +TransitiveComparisonAnalyzer::Impl::Key TransitiveComparisonAnalyzer::Impl::ExprToKey( + const PrimExpr& expr) { + if (auto prev = ExprToPreviousKey(expr)) { + return prev.value(); + } else { + Key new_key = Key(expr_to_key.size()); + expr_to_key[expr] = new_key; + return new_key; + } +} + +bool TransitiveComparisonAnalyzer::Impl::Comparison::IsNormalized() const { + // These < and > should be removed during normalization. + return result_ != CompareResult::kLT && result_ != CompareResult::kGT; +} + +std::optional<TransitiveComparisonAnalyzer::Impl::Comparison> +TransitiveComparisonAnalyzer::Impl::Comparison::WithLHS(Key new_lhs) const { + if (new_lhs == lhs_) { + return *this; + } else if (new_lhs == rhs_) { + return Comparison(rhs_, lhs_, -offset_, Reverse(result_)); + } else { + return std::nullopt; + } +} + +TransitiveComparisonAnalyzer::Impl::Comparison +TransitiveComparisonAnalyzer::Impl::Comparison::Negated() const { + return Comparison(lhs_, rhs_, offset_, Negate(result_)); +} + +bool TransitiveComparisonAnalyzer::Impl::Comparison::Implies( + const TransitiveComparisonAnalyzer::Impl::Comparison& other) const { + ICHECK(lhs_ == other.lhs_); + ICHECK(rhs_ == other.rhs_); + ICHECK(IsNormalized()); + ICHECK(other.IsNormalized()); + + if (result_ == other.result_ && offset_ == other.offset_) { + // if c1 == c2, x != y + c1 => x != y + c2 + // if c1 == c2, x == y + c1 => x == y + c2 + return true; + } + + if (other.result_ == CompareResult::kLE && offset_ <= other.offset_) { + if (result_ == CompareResult::kEQ || result_ == CompareResult::kLE) { + // if c1 <= c2, x <= y + c1 => x <= y + c2 + // if c1 <= c2, x == y + c1 => x <= y + c2 + return true; + } + } + + if (other.result_ == 
CompareResult::kGE && offset_ >= other.offset_) { + if (result_ == CompareResult::kEQ || result_ == CompareResult::kGE) { + // if c1 >= c2, x == y + c1 => x >= y + c2 + // if c1 >= c2, x >= y + c1 => x >= y + c2 + return true; + } + } + + if (other.result_ == CompareResult::kNE) { + if (result_ == CompareResult::kEQ && offset_ != other.offset_) { + // if c1 != c2, x == y + c1 => x != y + c2 + return true; + } + + if (result_ == CompareResult::kLE && offset_ < other.offset_) { + // if c1 < c2, x <= y + c1 => x < y + c2 => x != y + c2 + return true; + } + + if (result_ == CompareResult::kGE && offset_ > other.offset_) { + // if c1 != c2, x >= y + c1 => x > y + c2 => x != y + c2 + return true; + } + } + + return false; +} + +TransitiveComparisonAnalyzer::TransitiveComparisonAnalyzer() : impl_(std::make_unique<Impl>()) {} +TransitiveComparisonAnalyzer::~TransitiveComparisonAnalyzer() {} + +CompareResult TransitiveComparisonAnalyzer::TryCompare(const PrimExpr& lhs, const PrimExpr& rhs) { + return impl_->TryCompare(lhs, rhs); +} + +void TransitiveComparisonAnalyzer::Bind(const Var& var, const PrimExpr& expr, bool allow_override) { + impl_->Bind(var, expr, allow_override); +} +void TransitiveComparisonAnalyzer::Bind(const Var& var, const Range& range, bool allow_override) { + impl_->Bind(var, range, allow_override); +} + +std::function<void()> TransitiveComparisonAnalyzer::EnterConstraint(const PrimExpr& constraint) { + return impl_->EnterConstraint(constraint); +} + +void TransitiveComparisonAnalyzer::Impl::AddKnown(const PrimExpr& expr, + std::vector<Comparison>* vec) { + for (const auto& subexpr : ExtractConstraints(expr)) { + if (tir::SideEffect(expr) <= tir::CallEffectKind::kPure) { + if (auto cmp = FromExpr(subexpr)) { + vec->push_back(cmp.value()); + } + } + } +} + +void TransitiveComparisonAnalyzer::Impl::Bind(const tir::Var& var, const Range& range, + bool allow_override) { + auto it = prev_bindings_.find(var); + if (it != prev_bindings_.end()) { + ExprDeepEqual 
expr_equal; + bool differs_from_previous = !expr_equal(range->min, (*it).second->min) || + !expr_equal(range->extent, (*it).second->extent); + if (differs_from_previous) { + ICHECK(allow_override) << "Binding of variable " << var << " as " << range + << " conflicts with previous binding as " << (*it).second; + if (auto key = ExprToPreviousKey(var)) { + knowns_.erase(std::remove_if(knowns_.begin(), knowns_.end(), + [&](const auto& known) { return known.lhs_ == key.value(); }), + knowns_.end()); + } + } + } + + prev_bindings_.Set(var, range); + + if (is_const_int(range->extent, 1)) { + AddKnown(var == range->min, &knowns_); + } else { + AddKnown(var >= range->min, &knowns_); + AddKnown(var < range->min + range->extent, &knowns_); + } +} + +void TransitiveComparisonAnalyzer::Impl::Bind(const tir::Var& var, const PrimExpr& expr, + bool allow_override) { + Bind(var, Range::FromMinExtent(expr, 1), allow_override); +} + +std::function<void()> TransitiveComparisonAnalyzer::Impl::EnterConstraint(const PrimExpr& expr) { + size_t old_literal_size = scoped_knowns_.size(); + AddKnown(expr, &scoped_knowns_); + size_t new_literal_size = scoped_knowns_.size(); + + PrimExpr temp = expr; + auto frecover = [old_literal_size, new_literal_size, this, temp]() { + ICHECK_EQ(scoped_knowns_.size(), new_literal_size); + scoped_knowns_.erase(scoped_knowns_.begin() + old_literal_size, scoped_knowns_.end()); + }; + return frecover; +} + +CompareResult TransitiveComparisonAnalyzer::Impl::TryCompare(const PrimExpr& lhs_expr, + const PrimExpr& rhs_expr) const { + // Currently only supports integer checks + if (!lhs_expr.dtype().is_int() || !rhs_expr.dtype().is_int()) { + return CompareResult::kUnknown; + } + + // Bail out early if possible. This int check should have been + // constant-folded earlier, so this check shouldn't occur. 
+ auto* x_int = lhs_expr.as<IntImmNode>(); + auto* y_int = rhs_expr.as<IntImmNode>(); + if (x_int && y_int) { + if (x_int->value < y_int->value) { + return CompareResult::kLT; + } else if (x_int->value > y_int->value) { + return CompareResult::kGT; + } else { + return CompareResult::kEQ; + } + } + + auto [lhs, rhs, offset] = ExtractOffsets(lhs_expr, rhs_expr); + auto lhs_key = ExprToPreviousKey(lhs); + auto rhs_key = ExprToPreviousKey(rhs); + + if (!lhs_key.has_value() || !rhs_key.has_value()) { + return CompareResult::kUnknown; + } + + auto from_lhs = TryCompareFromLHS(lhs_key.value(), rhs_key.value(), offset, lhs, rhs); + auto from_rhs = Reverse(TryCompareFromLHS(rhs_key.value(), lhs_key.value(), -offset, rhs, lhs)); + auto output = from_lhs & from_rhs; + + return output; +} + +CompareResult TransitiveComparisonAnalyzer::Impl::TryCompareFromLHS( + Key lhs_key_input, Key rhs_key_input, int64_t offset_input, const PrimExpr& lhs_input, + const PrimExpr& rhs_input) const { + Key lhs_key = lhs_key_input; + Key rhs_key = rhs_key_input; + int64_t offset = offset_input; + + // Everything in `to_visit` has lhs as its lhs. + std::unordered_set<Key> seen; + std::unordered_set<Key> to_visit; + std::unordered_map<Key, std::vector<Comparison>> compared_to_x; + + // Utility function to add a new known statement + auto declare_known = [&](Comparison cmp) { + auto& prev_knowns = compared_to_x[cmp.rhs_]; + + for (auto& prev_known : prev_knowns) { + if (prev_known.Implies(cmp)) { + return; + } + } + + if (cmp.rhs_ != rhs_key && !seen.count(cmp.rhs_)) { + to_visit.insert(cmp.rhs_); + seen.insert(cmp.rhs_); + } + + for (auto& prev_known : prev_knowns) { + if (cmp.Implies(prev_known)) { + prev_known = cmp; + return; + } + } + + prev_knowns.push_back(cmp); + }; + + // Initialize the search based on any known (in)equalities that use + // the LHS of the comparison. 
+ for (const auto& known : knowns_) { + if (auto normalized = known.WithLHS(lhs_key)) { + declare_known(normalized.value()); + } + } + for (const auto& known : scoped_knowns_) { + if (auto normalized = known.WithLHS(lhs_key)) { + declare_known(normalized.value()); + } + } + + // Walk through the space of all comparisons that can be made with + // LHS. + while (to_visit.size()) { + Key middle_key = *to_visit.begin(); + to_visit.erase(to_visit.begin()); + + std::vector<Comparison>& prev_knowns_using_middle = compared_to_x.at(middle_key); + ICHECK(compared_to_x.count(middle_key)); + + std::vector<Comparison> new_knowns_using_lhs; + + auto attempt_transitive = [&](Comparison cmp) { + ICHECK(cmp.IsNormalized()); + + Key right_key = cmp.rhs_; + + if (right_key == lhs_key) { + return; + } + + for (const auto& prev : prev_knowns_using_middle) { + CompareResult new_result = CompareResult::kUnknown; + int64_t new_offset = prev.offset_ + cmp.offset_; + + if (prev.result_ == CompareResult::kEQ) { + // x == y + c1 && y OP z + c2, x OP z + (c1 + c2) + new_result = cmp.result_; + } else if (cmp.result_ == CompareResult::kEQ) { + // x OP y + c1 && y == z + c2, x OP z + (c1 + c2) + new_result = prev.result_; + } else if (prev.result_ == cmp.result_ && + (prev.result_ == CompareResult::kLE || prev.result_ == CompareResult::kGE)) { + // x <= y + c1 && y <= z + c2, x <= z + (c1 + c2) + // x >= y + c1 && y >= z + c2, x >= z + (c1 + c2) + // + // This condition is much simpler to write than the + // equivalent handling of < or of >, which is why the + // inequalities are normalized to <= and to >=. Review Comment: Ah, here is the reasoning for such normalization! Some discussion of this in the docstring brief for IsNormalized would be great :). -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. 
To unsubscribe, e-mail: commits-unsubscr...@tvm.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org