[clang-tools-extra] [clang-tidy] add check to suggest replacement of nested std::min or std::max with initializer lists (PR #85572)
Julian Schmidt via cfe-commits
cfe-commits at lists.llvm.org
Mon Apr 1 12:40:00 PDT 2024
================
@@ -0,0 +1,278 @@
+//===--- MinMaxUseInitializerListCheck.cpp - clang-tidy -------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "MinMaxUseInitializerListCheck.h"
+#include "../utils/ASTUtils.h"
+#include "../utils/LexerUtils.h"
+#include "clang/ASTMatchers/ASTMatchFinder.h"
+#include "clang/Frontend/CompilerInstance.h"
+#include "clang/Lex/Lexer.h"
+
+using namespace clang;
+
+namespace {
+
+/// Arguments of a single call as collected by findArgs().
+struct FindArgsResult {
+  /// First argument of the call; null until findArgs() has set it.
+  const Expr *First = nullptr;
+  /// Last value argument of the call. Equal to First when the values are
+  /// passed through one initializer list argument.
+  const Expr *Last = nullptr;
+  /// Comparison argument of the call, or null if there is none.
+  const Expr *Compare = nullptr;
+  /// The value arguments; for an initializer list argument these are the
+  /// elements of the list.
+  SmallVector<const clang::Expr *, 2> Args;
+};
+
+} // anonymous namespace
+
+using namespace clang::ast_matchers;
+
+namespace clang::tidy::modernize {
+
+/// Splits the arguments of \p Call into value arguments and an optional
+/// comparison argument.
+///
+/// Two call shapes are recognised:
+///  - the first argument is an initializer list: its elements become the
+///    value arguments, First and Last both point at the list argument, and a
+///    second argument (if present) is the comparison;
+///  - otherwise every argument is a value argument, except that the third
+///    argument of a call with three or more arguments is the comparison.
+///
+/// \returns the collected arguments; all pointers are null and Args is empty
+///          when \p Call has no arguments.
+static FindArgsResult findArgs(const CallExpr *Call) {
+  FindArgsResult Result;
+  Result.First = nullptr;
+  Result.Last = nullptr;
+  Result.Compare = nullptr;
+
+  // a call without arguments has nothing to collect; returning here also
+  // keeps the code below from dereferencing the end iterator
+  if (Call->getNumArgs() == 0)
+    return Result;
+
+  // check if the function has initializer list argument
+  if (Call->getNumArgs() < 3) {
+    auto ArgIterator = Call->arguments().begin();
+
+    const auto *InitListExpr =
+        dyn_cast<CXXStdInitializerListExpr>(*ArgIterator);
+    const auto *InitList =
+        InitListExpr != nullptr
+            ? dyn_cast<clang::InitListExpr>(
+                  InitListExpr->getSubExpr()->IgnoreImplicit())
+            : nullptr;
+
+    if (InitListExpr && InitList) {
+      Result.Args.insert(Result.Args.begin(), InitList->inits().begin(),
+                         InitList->inits().end());
+      // First == Last is how callers recognise the initializer list form
+      Result.First = *ArgIterator;
+      Result.Last = *ArgIterator;
+
+      // check if there is a comparison argument
+      std::advance(ArgIterator, 1);
+      if (ArgIterator != Call->arguments().end())
+        Result.Compare = *ArgIterator;
+
+      return Result;
+    }
+    // if it has 3 arguments then the last will be the comparison
+  } else {
+    Result.Compare = *(std::next(Call->arguments().begin(), 2));
+  }
+
+  for (const Expr *Arg : Call->arguments()) {
+    if (!Result.First)
+      Result.First = Arg;
+
+    // the comparison is not one of the values
+    if (Arg == Result.Compare)
+      continue;
+
+    Result.Args.push_back(Arg);
+    Result.Last = Arg;
+  }
+
+  return Result;
+}
+
+/// Builds the fix-it hints that merge the calls nested in \p TopCall into one
+/// call taking an initializer list.
+///
+/// \param Match   supplies the SourceManager and the ASTContext used for
+///                lexing and for comparing the comparison arguments.
+/// \param TopCall the call whose arguments are inspected.
+/// \param Result  the arguments of \p TopCall as collected by findArgs().
+/// \returns the hints in this order: if \p TopCall has no initializer list
+///          argument, the insertions of `{` and `}` come first (the recursion
+///          below relies on that to drop them again for nested calls),
+///          followed by the hints produced for each value argument. The
+///          result is empty if \p TopCall has no direct callee.
+static SmallVector<FixItHint>
+generateReplacement(const MatchFinder::MatchResult &Match,
+                    const CallExpr *TopCall, const FindArgsResult &Result) {
+  SmallVector<FixItHint> FixItHints;
+
+  // without a direct callee (e.g. a call through a function pointer) there is
+  // neither a return type nor a name to compare nested calls against
+  const auto *TopCallee = TopCall->getDirectCallee();
+  if (!TopCallee)
+    return FixItHints;
+
+  const QualType ResultType = TopCallee->getReturnType()
+                                  .getNonReferenceType()
+                                  .getUnqualifiedType()
+                                  .getCanonicalType();
+  const std::string TopCalleeName = TopCallee->getQualifiedNameAsString();
+  const auto &SourceMngr = *Match.SourceManager;
+  // bind by reference: this function recurses and the options are only read
+  const auto &LanguageOpts = Match.Context->getLangOpts();
+  const bool IsInitializerList = Result.First == Result.Last;
+
+  // add { and } if the top call doesn't have an initializer list arg
+  if (!IsInitializerList) {
+    FixItHints.push_back(
+        FixItHint::CreateInsertion(Result.First->getBeginLoc(), "{"));
+
+    if (Result.Compare)
+      // the closing brace goes right behind the last value argument so that
+      // the comparison stays outside of the list
+      FixItHints.push_back(FixItHint::CreateInsertion(
+          Lexer::getLocForEndOfToken(Result.Last->getEndLoc(), 0, SourceMngr,
+                                     LanguageOpts),
+          "}"));
+    else
+      FixItHints.push_back(
+          FixItHint::CreateInsertion(TopCall->getEndLoc(), "}"));
+  }
+
+  for (const Expr *Arg : Result.Args) {
+    const auto *InnerCall = dyn_cast<CallExpr>(Arg->IgnoreParenImpCasts());
+
+    // If the argument is not a nested call
+    if (!InnerCall) {
+      // check if typecast is required
+      const QualType ArgType = Arg->IgnoreParenImpCasts()
+                                   ->getType()
+                                   .getUnqualifiedType()
+                                   .getCanonicalType();
+
+      if (ArgType == ResultType)
+        continue;
+
+      const StringRef ArgText = Lexer::getSourceText(
+          CharSourceRange::getTokenRange(Arg->getSourceRange()), SourceMngr,
+          LanguageOpts);
+
+      // materialise the Twine within the same full expression: a Twine only
+      // refers to its operands, and the string returned by getAsString() is a
+      // temporary
+      const std::string Replacement =
+          (llvm::Twine("static_cast<") + ResultType.getAsString(LanguageOpts) +
+           ">(" + ArgText + ")")
+              .str();
+
+      FixItHints.push_back(
+          FixItHint::CreateReplacement(Arg->getSourceRange(), Replacement));
+
+      continue;
+    }
+
+    // if the nested call is not a direct call to the same function as the top
+    // call leave it alone; this is checked first so that unrelated calls are
+    // never analysed any further
+    const auto *InnerCallee = InnerCall->getDirectCallee();
+    if (!InnerCallee ||
+        InnerCallee->getQualifiedNameAsString() != TopCalleeName)
+      continue;
+
+    const auto InnerResult = findArgs(InnerCall);
+
+    // if the nested call doesn't have arguments skip it
+    if (!InnerResult.First || !InnerResult.Last)
+      continue;
+
+    // if the nested call doesn't have the same compare function
+    if ((Result.Compare || InnerResult.Compare) &&
+        !utils::areStatementsIdentical(Result.Compare, InnerResult.Compare,
+                                       *Match.Context))
+      continue;
+
+    // locate every token that has to be removed before emitting any hint, so
+    // that a failed lookup leaves the nested call completely untouched
+    const auto LParen = utils::lexer::findNextTokenSkippingComments(
+        InnerCall->getCallee()->getEndLoc(), SourceMngr, LanguageOpts);
+    if (!LParen)
+      continue;
+
+    // token following the last value argument; if the nested call has a
+    // comparison this is the comma in front of it
+    const auto Comma = utils::lexer::findNextTokenSkippingComments(
+        InnerResult.Last->getEndLoc(), SourceMngr, LanguageOpts);
+    if (InnerResult.Compare && !Comma)
+      continue;
+
+    const bool IsInnerInitializerList = InnerResult.First == InnerResult.Last;
+    const auto InnerReplacements =
+        generateReplacement(Match, InnerCall, InnerResult);
+
+    // remove the function call
+    FixItHints.push_back(
+        FixItHint::CreateRemoval(CharSourceRange::getTokenRange(
+            InnerCall->getCallee()->getSourceRange())));
+
+    // remove the parentheses
+    FixItHints.push_back(
+        FixItHint::CreateRemoval(SourceRange(LParen->getLocation())));
+    FixItHints.push_back(
+        FixItHint::CreateRemoval(SourceRange(InnerCall->getRParenLoc())));
+
+    // if the inner call has an initializer list arg
+    if (IsInnerInitializerList) {
+      // remove the initializer list braces
+      FixItHints.push_back(FixItHint::CreateRemoval(
+          CharSourceRange::getTokenRange(InnerResult.First->getBeginLoc())));
+      FixItHints.push_back(FixItHint::CreateRemoval(
+          CharSourceRange::getTokenRange(InnerResult.First->getEndLoc())));
+    }
+
+    // ignore { and } insertions for the inner call if it does not have an
+    // initializer list arg: they are the first two hints of its result
+    const unsigned SkippedBraceHints = IsInnerInitializerList ? 0 : 2;
+    FixItHints.insert(FixItHints.end(),
+                      InnerReplacements.begin() + SkippedBraceHints,
+                      InnerReplacements.end());
+
+    if (InnerResult.Compare) {
+      // if there are comments between the comma and the comparison
+      if (utils::lexer::getPreviousToken(InnerResult.Compare->getExprLoc(),
+                                         SourceMngr, LanguageOpts, false)
+              .getLocation() != Comma->getLocation()) {
+        // remove the comma and the comparison
+        FixItHints.push_back(
+            FixItHint::CreateRemoval(SourceRange(Comma->getLocation())));
+
+        FixItHints.push_back(
+            FixItHint::CreateRemoval(CharSourceRange::getTokenRange(
+                InnerResult.Compare->getSourceRange())));
+      } else {
+        // remove everything after the last argument
+        FixItHints.push_back(
+            FixItHint::CreateRemoval(CharSourceRange::getTokenRange(
+                Comma->getLocation(), InnerResult.Compare->getEndLoc())));
+      }
+    }
+  }
+
+  return FixItHints;
+}
+
+MinMaxUseInitializerListCheck::MinMaxUseInitializerListCheck( // forwards Name and Context to ClangTidyCheck and sets up the include inserter
+ StringRef Name, ClangTidyContext *Context)
+ : ClangTidyCheck(Name, Context),
+ Inserter(Options.getLocalOrGlobal("IncludeStyle", // same option key that storeOptions() writes
+ utils::IncludeSorter::IS_LLVM), // passed as the default value for the lookup
+ areDiagsSelfContained()) {}
+
+/// Writes the include style used by the inserter into \p Opts under the
+/// "IncludeStyle" key.
+void MinMaxUseInitializerListCheck::storeOptions(
+    ClangTidyOptions::OptionMap &Opts) {
+  const auto ConfiguredStyle = Inserter.getStyle();
+  Options.store(Opts, "IncludeStyle", ConfiguredStyle);
+}
+
+/// Registers one matcher for `::std::max` and one for `::std::min`; each
+/// binds the matched call as "topCall".
+void MinMaxUseInitializerListCheck::registerMatchers(MatchFinder *Finder) {
+  // Matches a call to the named function if
+  //  - its first or second argument is a call to the same function, or its
+  //    first argument is a std::initializer_list expression, and
+  //  - its parent is not itself a call to the same function.
+  const auto MakeTopCallMatcher = [](const StringRef QualifiedName) {
+    auto Target = functionDecl(hasName(QualifiedName));
+    auto CallToTarget = callExpr(callee(Target));
+
+    return callExpr(callee(Target),
+                    anyOf(hasArgument(0, CallToTarget),
+                          hasArgument(1, CallToTarget),
+                          hasArgument(0, cxxStdInitializerListExpr())),
+                    unless(hasParent(CallToTarget)))
+        .bind("topCall");
+  };
+
+  const StringRef FunctionNames[] = {"::std::max", "::std::min"};
+  for (const StringRef QualifiedName : FunctionNames)
+    Finder->addMatcher(MakeTopCallMatcher(QualifiedName), this);
+}
+
+/// Hands the preprocessor \p PP to the include inserter. The other two
+/// parameters are not used.
+void MinMaxUseInitializerListCheck::registerPPCallbacks(
+    const SourceManager & /*SM*/, Preprocessor *PP,
+    Preprocessor * /*ModuleExpanderPP*/) {
+  Inserter.registerPreprocessor(PP);
+}
+
+void MinMaxUseInitializerListCheck::check(
+ const MatchFinder::MatchResult &Match) {
+
+ const auto *TopCall = Match.Nodes.getNodeAs<CallExpr>("topCall");
+
+ const FindArgsResult Result = findArgs(TopCall);
+ const SmallVector<FixItHint> Replacement =
+ generateReplacement(Match, TopCall, Result);
+
+ if (Replacement.size() <= 2) {
+ return;
+ }
----------------
5chmidti wrote:
It's not immediately clear why `<= 2` is needed; please add a short comment explaining it. Also, a single-statement `if` branch should not have braces.
https://github.com/llvm/llvm-project/pull/85572
More information about the cfe-commits
mailing list