================
@@ -0,0 +1,242 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "MissingEndComparisonCheck.h"
+#include "../utils/OptionsUtils.h"
+#include "clang/AST/ASTContext.h"
+#include "clang/ASTMatchers/ASTMatchFinder.h"
+#include "clang/Lex/Lexer.h"
+#include "clang/Tooling/FixIt.h"
+
+using namespace clang::ast_matchers;
+
+namespace clang::tidy::bugprone {
+
+/// Fully qualified names of the std algorithms that registerMatchers() looks
+/// for as ordinary function calls. Names from the "ExtraAlgorithms" option are
+/// appended to this list there.
+static constexpr llvm::StringRef IteratorAlgorithms[] = {
+    "::std::find",
+    "::std::find_if",
+    "::std::find_if_not",
+    "::std::search",
+    "::std::search_n",
+    "::std::find_end",
+    "::std::find_first_of",
+    "::std::lower_bound",
+    "::std::upper_bound",
+    "::std::partition_point",
+    "::std::min_element",
+    "::std::max_element",
+    "::std::adjacent_find",
+    "::std::is_sorted_until",
+};
+
+/// Fully qualified names of the std::ranges entities that registerMatchers()
+/// looks for as variables (customization point objects) invoked through
+/// operator().
+static constexpr llvm::StringRef RangeAlgorithms[] = {
+    "::std::ranges::find",
+    "::std::ranges::find_if",
+    "::std::ranges::find_if_not",
+    "::std::ranges::lower_bound",
+    "::std::ranges::upper_bound",
+    "::std::ranges::min_element",
+    "::std::ranges::max_element",
+    "::std::ranges::find_first_of",
+    "::std::ranges::adjacent_find",
+    "::std::ranges::is_sorted_until",
+};
+
+/// Constructs the check and initializes ExtraAlgorithms from the
+/// "ExtraAlgorithms" option (default: empty string), parsed with
+/// utils::options::parseStringList.
+MissingEndComparisonCheck::MissingEndComparisonCheck(StringRef Name,
+                                                     ClangTidyContext *Context)
+    : ClangTidyCheck(Name, Context),
+      ExtraAlgorithms(utils::options::parseStringList(
+          Options.get("ExtraAlgorithms", ""))) {}
+
+/// Stores the "ExtraAlgorithms" option into \p Opts, using the serialized
+/// form of the list held in ExtraAlgorithms.
+void MissingEndComparisonCheck::storeOptions(
+    ClangTidyOptions::OptionMap &Opts) {
+  const auto SerializedExtras =
+      utils::options::serializeStringList(ExtraAlgorithms);
+  Options.store(Opts, "ExtraAlgorithms", SerializedExtras);
+}
+
+/// Builds the source text of the end expression for a call made through a
+/// std::ranges customization point object (CPO).
+///
+/// \param Context AST context used to read source text.
+/// \param Call    The operator() call on the CPO; its direct callee must have
+///                at least one parameter (asserted).
+/// \returns std::nullopt when the call has too few arguments for the detected
+///          overload; an empty string when the range overload is used but the
+///          range argument is not a plain DeclRefExpr/MemberExpr (or its text
+///          is empty); otherwise the text of the end expression.
+static std::optional<std::string> getRangesEndText(const ASTContext &Context,
+                                                   const CallExpr *Call) {
+  const FunctionDecl *Callee = Call->getDirectCallee();
+  assert(Callee && Callee->getNumParams() > 0 &&
+         "Matcher should ensure Callee has parameters");
+
+  const unsigned NumArgs = Call->getNumArgs();
+
+  // Range overloads take a reference (R&&), iterator overloads pass by value,
+  // so the type of the first parameter tells the two apart.
+  if (Callee->getParamDecl(0)->getType()->isReferenceType()) {
+    // CPO(Range, Val, Proj) -> Range is Arg 1.
+    if (NumArgs < 2)
+      return std::nullopt;
+
+    const Expr *Range = Call->getArg(1);
+    // Only simple references are repeated in the generated text, to avoid
+    // duplicating potential side-effects.
+    if (!isa<DeclRefExpr, MemberExpr>(Range->IgnoreParenImpCasts()))
+      return "";
+
+    const StringRef Spelling = tooling::fixit::getText(*Range, Context);
+    if (Spelling.empty())
+      return "";
+    return ("std::ranges::end(" + Spelling + ")").str();
+  }
+
+  // CPO(Iter, Sent, Val...) -> Sent is Arg 2.
+  if (NumArgs < 3)
+    return std::nullopt;
+  return tooling::fixit::getText(*Call->getArg(2), Context).str();
+}
+
+/// Builds the source text of the end expression for a plain (non-CPO)
+/// algorithm call.
+///
+/// \param Context AST context used to read source text and to evaluate
+///                null-pointer constants.
+/// \param Call    The algorithm call.
+/// \returns std::nullopt when the call has fewer than two arguments, when the
+///          selected end argument does not exist, or when that argument is a
+///          null pointer constant; otherwise the end expression text (which
+///          is whatever tooling::fixit::getText yields for the argument).
+static std::optional<std::string> getStandardEndText(ASTContext &Context,
+                                                     const CallExpr *Call) {
+  const unsigned NumArgs = Call->getNumArgs();
+  if (NumArgs < 2)
+    return std::nullopt;
+
+  const Expr *First = Call->getArg(0);
+
+  // Heuristic: with exactly two arguments, a record-typed first argument whose
+  // type differs from the second is assumed to be a range, not an iterator.
+  if (NumArgs == 2) {
+    const QualType FirstTy = First->getType().getCanonicalType();
+    const QualType SecondTy = Call->getArg(1)->getType().getCanonicalType();
+    const bool LooksLikeRange =
+        FirstTy.getNonReferenceType()->isRecordType() && FirstTy != SecondTy;
+
+    if (LooksLikeRange) {
+      const StringRef Spelling = tooling::fixit::getText(*First, Context);
+      if (!Spelling.empty())
+        return ("std::end(" + Spelling + ")").str();
+    }
+  }
+
+  // A leading argument whose record name ends in "_policy" shifts the end
+  // argument one position to the right.
+  const auto *FirstRecord =
+      First->getType().getNonReferenceType()->getAsCXXRecordDecl();
+  const bool LeadingPolicy = FirstRecord && FirstRecord->getIdentifier() &&
+                             FirstRecord->getName().ends_with("_policy");
+  const unsigned EndIdx = LeadingPolicy ? 2 : 1;
+
+  if (EndIdx >= NumArgs)
+    return std::nullopt;
+
+  const Expr *End = Call->getArg(EndIdx);
+  // Filters nullptr, we assume the intent might be a valid check against null
+  const bool EndIsNull = End->IgnoreParenCasts()->isNullPointerConstant(
+      Context, Expr::NPC_ValueDependentIsNull);
+  if (EndIsNull)
+    return std::nullopt;
+
+  return tooling::fixit::getText(*End, Context).str();
+}
+
+void MissingEndComparisonCheck::registerMatchers(MatchFinder *Finder) {
+  llvm::SmallVector<StringRef, 32> ExpandedIteratorAlgorithms;
+  ExpandedIteratorAlgorithms.append(std::begin(IteratorAlgorithms),
+                                    std::end(IteratorAlgorithms));
+  ExpandedIteratorAlgorithms.append(ExtraAlgorithms.begin(),
+                                    ExtraAlgorithms.end());
+
+  const auto StdAlgoCall = callExpr(callee(functionDecl(
+      hasAnyName(ExpandedIteratorAlgorithms), unless(parameterCountIs(0)))));
+
+  /* Overview of this function: StdAlgoCall (above) matches calls to functions
+   * named in IteratorAlgorithms or ExtraAlgorithms, skipping callees declared
+   * with zero parameters. Either kind of algorithm call is bound as
+   * "algoCall". Two matchers are registered at the end: one for boolean uses
+   * that are the operand of a logical-not (binding "Neg" and "boolOp"), and
+   * one for the remaining boolean uses (binding "boolOp").
+   *
+   * NOTE(review): in the second matcher "boolOp" is bound by the trailing
+   * .bind() of that same matcher, so it may not be bound yet when
+   * equalsBoundNode("boolOp") inside unless() is evaluated -- confirm that the
+   * exclusion really filters out negated uses.
+   *
+   * RangesCall captures customization point object calls: an operator() call
+   * whose argument 0 refers to a variable named in RangeAlgorithms (bound as
+   * "cpo") and whose callee method is not declared with zero parameters. It is
+   * only part of AnyAlgoCall when getLangOpts().CPlusPlus20 is set. */
+  const auto RangesCall = cxxOperatorCallExpr(
+      hasOverloadedOperatorName("()"),
+      callee(cxxMethodDecl(unless(parameterCountIs(0)))),
+      hasArgument(0, declRefExpr(to(
+                         varDecl(hasAnyName(RangeAlgorithms)).bind("cpo")))));
+
+  const auto AnyAlgoCall =
+      getLangOpts().CPlusPlus20
+          ? expr(anyOf(StdAlgoCall, RangesCall)).bind("algoCall")
+          : expr(StdAlgoCall).bind("algoCall");
+
+  /* Captures implicit pointer-to-bool casts and operator bool() calls whose
+   * operand is the algorithm call itself (parentheses and implicit casts are
+   * ignored). */
+  const auto IsBoolUsage = anyOf(
+      implicitCastExpr(hasCastKind(CK_PointerToBoolean),
+                       
hasSourceExpression(ignoringParenImpCasts(AnyAlgoCall))),
+      cxxMemberCallExpr(callee(cxxConversionDecl(returns(booleanType()))),
+                        on(ignoringParenImpCasts(AnyAlgoCall))));
+
+  // Captures variable usage: `auto it = std::find(...); if (it)`
+  // FIXME: This only handles variables initialized directly by the algorithm.
+  /* We may need to introduce more accurate dataflow analysis in the future.
+   *
+   * "initVar" is bound before the other sub-matchers so that the optional
+   * branch can refer back to it with equalsBoundNode("initVar"). That branch
+   * additionally binds "condVarParent" when the variable's DeclStmt is the
+   * condition variable statement of an enclosing if/while/for statement. */
+  const auto VarWithAlgoInit =
+      varDecl(decl().bind("initVar"),
+              hasInitializer(expr(ignoringParenImpCasts(AnyAlgoCall))),
+              optionally(hasParent(declStmt(
+                  hasParent(mapAnyOf(ifStmt, whileStmt, forStmt)
+                                .with(hasConditionVariableStatement(declStmt(
+                                    has(varDecl(equalsBoundNode("initVar"))))))
+                                .bind("condVarParent"))))));
+
+  const auto IsVariableBoolUsage =
+      anyOf(implicitCastExpr(hasCastKind(CK_PointerToBoolean),
+                             hasSourceExpression(ignoringParenImpCasts(
+                                 declRefExpr(to(VarWithAlgoInit))))),
+            cxxMemberCallExpr(
+                callee(cxxConversionDecl(returns(booleanType()))),
+                on(ignoringParenImpCasts(declRefExpr(to(VarWithAlgoInit))))));
+
+  const auto BoolUsage = expr(anyOf(IsBoolUsage, IsVariableBoolUsage));
+
+  Finder->addMatcher(
+      unaryOperator(hasOperatorName("!"),
+                    hasUnaryOperand(ignoringParens(BoolUsage.bind("boolOp"))))
+          .bind("Neg"),
+      this);
+
+  Finder->addMatcher(
+      expr(BoolUsage,
+           unless(hasAncestor(unaryOperator(
+               hasOperatorName("!"), hasUnaryOperand(ignoringParens(
+                                         expr(equalsBoundNode("boolOp"))))))))
+          .bind("boolOp"),
+      this);
+}
+
+void MissingEndComparisonCheck::check(const MatchFinder::MatchResult &Result) {
+  const auto *BoolOp = Result.Nodes.getNodeAs<Expr>("boolOp");
+  assert(BoolOp);
+
+  std::optional<std::string> EndExprText;
+
+  if (Result.Nodes.getNodeAs<VarDecl>("cpo")) {
+    const auto *Call = Result.Nodes.getNodeAs<CallExpr>("algoCall");
+    EndExprText = getRangesEndText(*Result.Context, Call);
+  } else if (const auto *Call = Result.Nodes.getNodeAs<CallExpr>("algoCall")) {
+    EndExprText = getStandardEndText(*Result.Context, Call);
+  } else {
+    llvm_unreachable("Matcher should bind 'algoCall' or 'cpo'");
+  }
+
+  if (!EndExprText)
+    return;
+
+  const auto *NotOp = Result.Nodes.getNodeAs<UnaryOperator>("Neg");
+
+  auto Diag = diag(BoolOp->getBeginLoc(),
+                   "result of standard algorithm used as 'bool'; did you "
+                   "mean to compare with the end iterator?");
+
+  if (EndExprText->empty())
+    return;
+
+  // Suppress fix-it if the expression is part of a variable declaration or a
+  // condition variable declaration.
+  if (const auto *InitVar = Result.Nodes.getNodeAs<VarDecl>("initVar")) {
+    if (InitVar->getType()->isBooleanType())
+      return;
+
+    if (Result.Nodes.getNodeAs<Stmt>("condVarParent"))
+      return;
----------------
vbvictor wrote:

Maybe write as one `if` with && and || ?

https://github.com/llvm/llvm-project/pull/182543
_______________________________________________
cfe-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to