[clang] [Clang] Optimize -Wunsafe-buffer-usage. (PR #125492)

via cfe-commits cfe-commits at lists.llvm.org
Mon Feb 3 05:29:02 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-clang

Author: Ivana Ivanovska (ivanaivanovska)

<details>
<summary>Changes</summary>

The Clang diagnostic `-Wunsafe-buffer-usage` was adding up to 15% to compilation time when enabled. Profiling showed that most of the overhead comes from the use of ASTMatchers.

This change replaces the ASTMatcher infrastructure with simple matching functions and keeps the functionality unchanged. It reduces the overhead added by `-Wunsafe-buffer-usage` by 87.8%, leaving a negligible additional compilation time of 1.7% when the diagnostic is used.

**Old version without -Wunsafe-buffer-usage:**
```
$ hyperfine -i -w 1 --runs 5 '/tmp/old_clang -c -std=c++20 /tmp/preprocessed.cc -ferror-limit=20' 
Benchmark 1: /tmp/old_clang -c -std=c++20 /tmp/preprocessed.cc -ferror-limit=20
  Time (mean ± σ):     231.035 s ±  3.210 s    [User: 229.134 s, System: 1.704 s]
  Range (min … max):   228.751 s … 236.682 s    5 runs
```

**Old version with -Wunsafe-buffer-usage:**
```
$ hyperfine -i -w 1 --runs 10 '/tmp/old_clang -c -std=c++20 /tmp/preprocessed.cc -ferror-limit=20 -Wunsafe-buffer-usage'
Benchmark 1: /tmp/old_clang -c -std=c++20 /tmp/preprocessed.cc -ferror-limit=20 -Wunsafe-buffer-usage
  Time (mean ± σ):     263.840 s ±  0.854 s    [User: 262.043 s, System: 1.575 s]
  Range (min … max):   262.442 s … 265.142 s    10 runs
```

**New version with -Wunsafe-buffer-usage:**
```
$ hyperfine -i -w 1 --runs 10 '/tmp/new_clang -c -std=c++20 /tmp/preprocessed.cc -ferror-limit=20 -Wunsafe-buffer-usage'
Benchmark 1: /tmp/new_clang -c -std=c++20 /tmp/preprocessed.cc -ferror-limit=20 -Wunsafe-buffer-usage
  Time (mean ± σ):     235.169 s ±  1.408 s    [User: 233.406 s, System: 1.561 s]
  Range (min … max):   232.221 s … 236.792 s    10 runs
```


---

Patch is 82.49 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/125492.diff


1 Files Affected:

- (modified) clang/lib/Analysis/UnsafeBufferUsage.cpp (+906-521) 


``````````diff
diff --git a/clang/lib/Analysis/UnsafeBufferUsage.cpp b/clang/lib/Analysis/UnsafeBufferUsage.cpp
index c064aa30e8aedc..4520d28d9e9452 100644
--- a/clang/lib/Analysis/UnsafeBufferUsage.cpp
+++ b/clang/lib/Analysis/UnsafeBufferUsage.cpp
@@ -8,30 +8,32 @@
 
 #include "clang/Analysis/Analyses/UnsafeBufferUsage.h"
 #include "clang/AST/ASTContext.h"
+#include "clang/AST/ASTTypeTraits.h"
 #include "clang/AST/Decl.h"
+#include "clang/AST/DeclCXX.h"
 #include "clang/AST/DynamicRecursiveASTVisitor.h"
 #include "clang/AST/Expr.h"
 #include "clang/AST/FormatString.h"
+#include "clang/AST/ParentMapContext.h"
 #include "clang/AST/Stmt.h"
 #include "clang/AST/StmtVisitor.h"
 #include "clang/AST/Type.h"
-#include "clang/ASTMatchers/ASTMatchFinder.h"
-#include "clang/ASTMatchers/ASTMatchers.h"
 #include "clang/Basic/SourceLocation.h"
 #include "clang/Lex/Lexer.h"
 #include "clang/Lex/Preprocessor.h"
 #include "llvm/ADT/APSInt.h"
+#include "llvm/ADT/STLFunctionalExtras.h"
+#include "llvm/ADT/SmallSet.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/Support/Casting.h"
-#include <memory>
 #include <optional>
 #include <queue>
+#include <set>
 #include <sstream>
 
 using namespace llvm;
 using namespace clang;
-using namespace ast_matchers;
 
 #ifndef NDEBUG
 namespace {
@@ -68,7 +70,7 @@ static std::string getDREAncestorString(const DeclRefExpr *DRE,
 
     if (StParents.size() > 1)
       return "unavailable due to multiple parents";
-    if (StParents.size() == 0)
+    if (StParents.empty())
       break;
     St = StParents.begin()->get<Stmt>();
     if (St)
@@ -76,10 +78,39 @@ static std::string getDREAncestorString(const DeclRefExpr *DRE,
   } while (St);
   return SS.str();
 }
+
 } // namespace
 #endif /* NDEBUG */
 
-namespace clang::ast_matchers {
+namespace {
+// Using a custom matcher instead of ASTMatchers to achieve better performance.
+class FastMatcher {
+public:
+  virtual bool matches(const DynTypedNode &DynNode, ASTContext &Ctx,
+                       const UnsafeBufferUsageHandler &Handler) = 0;
+  virtual ~FastMatcher() = default;
+};
+
+class MatchResult {
+
+public:
+  template <typename T> const T *getNodeAs(StringRef ID) const {
+    auto It = Nodes.find(std::string(ID));
+    if (It == Nodes.end()) {
+      return nullptr;
+    }
+    return It->second.get<T>();
+  }
+
+  void addNode(StringRef ID, const DynTypedNode &Node) {
+    Nodes[std::string(ID)] = Node;
+  }
+
+private:
+  llvm::StringMap<DynTypedNode> Nodes;
+};
+} // namespace
+
 // A `RecursiveASTVisitor` that traverses all descendants of a given node "n"
 // except for those belonging to a different callable of "n".
 class MatchDescendantVisitor : public DynamicRecursiveASTVisitor {
@@ -87,13 +118,11 @@ class MatchDescendantVisitor : public DynamicRecursiveASTVisitor {
   // Creates an AST visitor that matches `Matcher` on all
   // descendants of a given node "n" except for the ones
   // belonging to a different callable of "n".
-  MatchDescendantVisitor(const internal::DynTypedMatcher *Matcher,
-                         internal::ASTMatchFinder *Finder,
-                         internal::BoundNodesTreeBuilder *Builder,
-                         internal::ASTMatchFinder::BindKind Bind,
+  MatchDescendantVisitor(FastMatcher &Matcher, bool FindAll,
                          const bool ignoreUnevaluatedContext)
-      : Matcher(Matcher), Finder(Finder), Builder(Builder), Bind(Bind),
-        Matches(false), ignoreUnevaluatedContext(ignoreUnevaluatedContext) {
+      : Matcher(&Matcher), FindAll(FindAll), Matches(false),
+        ignoreUnevaluatedContext(ignoreUnevaluatedContext),
+        ActiveASTContext(nullptr), Handler(nullptr) {
     ShouldVisitTemplateInstantiations = true;
     ShouldVisitImplicitCode = false; // TODO: let's ignore implicit code for now
   }
@@ -104,7 +133,6 @@ class MatchDescendantVisitor : public DynamicRecursiveASTVisitor {
     Matches = false;
     if (const Stmt *StmtNode = DynNode.get<Stmt>()) {
       TraverseStmt(const_cast<Stmt *>(StmtNode));
-      *Builder = ResultBindings;
       return Matches;
     }
     return false;
@@ -186,106 +214,212 @@ class MatchDescendantVisitor : public DynamicRecursiveASTVisitor {
     return DynamicRecursiveASTVisitor::TraverseStmt(Node);
   }
 
+  void setASTContext(ASTContext &Context) { ActiveASTContext = &Context; }
+
+  void setHandler(const UnsafeBufferUsageHandler &NewHandler) {
+    Handler = &NewHandler;
+  }
+
 private:
   // Sets 'Matched' to true if 'Matcher' matches 'Node'
   //
   // Returns 'true' if traversal should continue after this function
   // returns, i.e. if no match is found or 'Bind' is 'BK_All'.
   template <typename T> bool match(const T &Node) {
-    internal::BoundNodesTreeBuilder RecursiveBuilder(*Builder);
-
-    if (Matcher->matches(DynTypedNode::create(Node), Finder,
-                         &RecursiveBuilder)) {
-      ResultBindings.addMatch(RecursiveBuilder);
+    if (Matcher->matches(DynTypedNode::create(Node), *ActiveASTContext,
+                         *Handler)) {
       Matches = true;
-      if (Bind != internal::ASTMatchFinder::BK_All)
+      if (!FindAll)
         return false; // Abort as soon as a match is found.
     }
     return true;
   }
 
-  const internal::DynTypedMatcher *const Matcher;
-  internal::ASTMatchFinder *const Finder;
-  internal::BoundNodesTreeBuilder *const Builder;
-  internal::BoundNodesTreeBuilder ResultBindings;
-  const internal::ASTMatchFinder::BindKind Bind;
+  FastMatcher *const Matcher;
+  // When true, finds all matches. When false, finds the first match and stops.
+  const bool FindAll;
   bool Matches;
   bool ignoreUnevaluatedContext;
+  ASTContext *ActiveASTContext;
+  const UnsafeBufferUsageHandler *Handler;
 };
 
 // Because we're dealing with raw pointers, let's define what we mean by that.
-static auto hasPointerType() {
-  return hasType(hasCanonicalType(pointerType()));
+static bool hasPointerType(const Expr &E) {
+  return isa<PointerType>(E.getType().getCanonicalType());
 }
 
-static auto hasArrayType() { return hasType(hasCanonicalType(arrayType())); }
-
-AST_MATCHER_P(Stmt, forEachDescendantEvaluatedStmt, internal::Matcher<Stmt>,
-              innerMatcher) {
-  const DynTypedMatcher &DTM = static_cast<DynTypedMatcher>(innerMatcher);
-
-  MatchDescendantVisitor Visitor(&DTM, Finder, Builder, ASTMatchFinder::BK_All,
-                                 true);
-  return Visitor.findMatch(DynTypedNode::create(Node));
+static bool hasArrayType(const Expr &E) {
+  return isa<ArrayType>(E.getType().getCanonicalType());
 }
 
-AST_MATCHER_P(Stmt, forEachDescendantStmt, internal::Matcher<Stmt>,
-              innerMatcher) {
-  const DynTypedMatcher &DTM = static_cast<DynTypedMatcher>(innerMatcher);
+static void
+forEachDescendantEvaluatedStmt(const Stmt *S, ASTContext &Ctx,
+                               const UnsafeBufferUsageHandler &Handler,
+                               FastMatcher &Matcher) {
+  MatchDescendantVisitor Visitor(Matcher, /* FindAll */ true,
+                                 /*ignoreUnevaluatedContext*/ true);
+  Visitor.setASTContext(Ctx);
+  Visitor.setHandler(Handler);
+  Visitor.findMatch(DynTypedNode::create(*S));
+}
 
-  MatchDescendantVisitor Visitor(&DTM, Finder, Builder, ASTMatchFinder::BK_All,
-                                 false);
-  return Visitor.findMatch(DynTypedNode::create(Node));
+static void forEachDescendantStmt(const Stmt *S, ASTContext &Ctx,
+                                  const UnsafeBufferUsageHandler &Handler,
+                                  FastMatcher &Matcher) {
+  MatchDescendantVisitor Visitor(Matcher, /* FindAll */ true,
+                                 /*ignoreUnevaluatedContext*/ false);
+  Visitor.setASTContext(Ctx);
+  Visitor.setHandler(Handler);
+  Visitor.findMatch(DynTypedNode::create(*S));
 }
 
 // Matches a `Stmt` node iff the node is in a safe-buffer opt-out region
-AST_MATCHER_P(Stmt, notInSafeBufferOptOut, const UnsafeBufferUsageHandler *,
-              Handler) {
+static bool notInSafeBufferOptOut(const Stmt &Node,
+                                  const UnsafeBufferUsageHandler *Handler) {
   return !Handler->isSafeBufferOptOut(Node.getBeginLoc());
 }
 
-AST_MATCHER_P(Stmt, ignoreUnsafeBufferInContainer,
-              const UnsafeBufferUsageHandler *, Handler) {
+static bool
+ignoreUnsafeBufferInContainer(const Stmt &Node,
+                              const UnsafeBufferUsageHandler *Handler) {
   return Handler->ignoreUnsafeBufferInContainer(Node.getBeginLoc());
 }
 
-AST_MATCHER_P(Stmt, ignoreUnsafeLibcCall, const UnsafeBufferUsageHandler *,
-              Handler) {
-  if (Finder->getASTContext().getLangOpts().CPlusPlus)
+static bool ignoreUnsafeLibcCall(const ASTContext &Ctx, const Stmt &Node,
+                                 const UnsafeBufferUsageHandler *Handler) {
+  if (Ctx.getLangOpts().CPlusPlus)
     return Handler->ignoreUnsafeBufferInLibcCall(Node.getBeginLoc());
   return true; /* Only warn about libc calls for C++ */
 }
 
-AST_MATCHER_P(CastExpr, castSubExpr, internal::Matcher<Expr>, innerMatcher) {
-  return innerMatcher.matches(*Node.getSubExpr(), Finder, Builder);
+// Finds any expression 'e' such that `OnResult`
+// matches 'e' and 'e' is in an Unspecified Lvalue Context.
+static void findStmtsInUnspecifiedLvalueContext(
+    const Stmt *S, const llvm::function_ref<void(const Expr *)> OnResult) {
+  if (const auto *CE = dyn_cast<ImplicitCastExpr>(S)) {
+    if (CE->getCastKind() != CastKind::CK_LValueToRValue)
+      return;
+    OnResult(CE->getSubExpr());
+  }
+  if (const auto *BO = dyn_cast<BinaryOperator>(S)) {
+    if (BO->getOpcode() != BO_Assign)
+      return;
+    OnResult(BO->getLHS());
+  }
 }
 
-// Matches a `UnaryOperator` whose operator is pre-increment:
-AST_MATCHER(UnaryOperator, isPreInc) {
-  return Node.getOpcode() == UnaryOperator::Opcode::UO_PreInc;
+/// Note: Copied and modified from ASTMatchers.
+/// Matches all arguments and their respective types for a \c CallExpr or
+/// \c CXXConstructExpr. It is very similar to \c forEachArgumentWithParam but
+/// it works on calls through function pointers as well.
+///
+/// The difference is, that function pointers do not provide access to a
+/// \c ParmVarDecl, but only the \c QualType for each argument.
+///
+/// Given
+/// \code
+///   void f(int i);
+///   int y;
+///   f(y);
+///   void (*f_ptr)(int) = f;
+///   f_ptr(y);
+/// \endcode
+/// callExpr(
+///   forEachArgumentWithParamType(
+///     declRefExpr(to(varDecl(hasName("y")))),
+///     qualType(isInteger()).bind("type)
+/// ))
+///   matches f(y) and f_ptr(y)
+/// with declRefExpr(...)
+///   matching int y
+/// and qualType(...)
+///   matching int
+static void forEachArgumentWithParamType(
+    const CallExpr &Node,
+    const llvm::function_ref<void(QualType /*Param*/, const Expr * /*Arg*/)>
+        OnParamAndArg) {
+  // The first argument of an overloaded member operator is the implicit object
+  // argument of the method which should not be matched against a parameter, so
+  // we skip over it here.
+  unsigned ArgIndex = 0;
+  if (const auto *CE = dyn_cast<CXXOperatorCallExpr>(&Node)) {
+    const auto *FD = CE->getDirectCallee();
+    if (FD) {
+      if (const auto *MD = dyn_cast<CXXMethodDecl>(FD);
+          MD && !MD->isExplicitObjectMemberFunction()) {
+        // This is an overloaded operator call.
+        // We need to skip the first argument, which is the implicit object
+        // argument of the method which should not be matched against a
+        // parameter.
+        ++ArgIndex;
+      }
+    }
+  }
+
+  const FunctionProtoType *FProto = nullptr;
+
+  if (const auto *Call = dyn_cast<CallExpr>(&Node)) {
+    if (const auto *Value =
+            dyn_cast_or_null<ValueDecl>(Call->getCalleeDecl())) {
+      QualType QT = Value->getType().getCanonicalType();
+
+      // This does not necessarily lead to a `FunctionProtoType`,
+      // e.g. K&R functions do not have a function prototype.
+      if (QT->isFunctionPointerType())
+        FProto = QT->getPointeeType()->getAs<FunctionProtoType>();
+
+      if (QT->isMemberFunctionPointerType()) {
+        const auto *MP = QT->getAs<MemberPointerType>();
+        assert(MP && "Must be member-pointer if its a memberfunctionpointer");
+        FProto = MP->getPointeeType()->getAs<FunctionProtoType>();
+        assert(FProto &&
+               "The call must have happened through a member function "
+               "pointer");
+      }
+    }
+  }
+
+  unsigned ParamIndex = 0;
+  unsigned NumArgs = Node.getNumArgs();
+  if (FProto && FProto->isVariadic())
+    NumArgs = std::min(NumArgs, FProto->getNumParams());
+
+  const auto GetParamType =
+      [&FProto, &Node](unsigned int ParamIndex) -> std::optional<QualType> {
+    // This test is cheaper compared to the big matcher in the next if.
+    // Therefore, please keep this order.
+    if (FProto && FProto->getNumParams() > ParamIndex) {
+      return FProto->getParamType(ParamIndex);
+    }
+    if (const auto *E = dyn_cast<Expr>(&Node)) {
+      if (const auto *CE = dyn_cast<CXXConstructExpr>(E)) {
+        if (const auto *Ctor = CE->getConstructor();
+            Ctor && Ctor->getNumParams() > ParamIndex) {
+          return CE->getArg(ParamIndex)->getType();
+        }
+      }
+      if (const auto *CE = dyn_cast<CallExpr>(E)) {
+        const auto *FD = CE->getDirectCallee();
+        if (FD && FD->getNumParams() > ParamIndex) {
+          return CE->getArg(ParamIndex)->getType();
+        }
+      }
+    }
+    return std::nullopt;
+  };
+
+  for (; ArgIndex < NumArgs; ++ArgIndex, ++ParamIndex) {
+    auto ParamType = GetParamType(ParamIndex);
+    if (ParamType)
+      OnParamAndArg(*ParamType, Node.getArg(ArgIndex)->IgnoreParenCasts());
+  }
 }
 
-// Returns a matcher that matches any expression 'e' such that `innerMatcher`
-// matches 'e' and 'e' is in an Unspecified Lvalue Context.
-static auto isInUnspecifiedLvalueContext(internal::Matcher<Expr> innerMatcher) {
-  // clang-format off
-  return
-    expr(anyOf(
-      implicitCastExpr(
-        hasCastKind(CastKind::CK_LValueToRValue),
-        castSubExpr(innerMatcher)),
-      binaryOperator(
-        hasAnyOperatorName("="),
-        hasLHS(innerMatcher)
-      )
-    ));
-  // clang-format on
-}
-
-// Returns a matcher that matches any expression `e` such that `InnerMatcher`
-// matches `e` and `e` is in an Unspecified Pointer Context (UPC).
-static internal::Matcher<Stmt>
-isInUnspecifiedPointerContext(internal::Matcher<Stmt> InnerMatcher) {
+// Finds any expression `e` such that `InnerMatcher` matches `e` and
+// `e` is in an Unspecified Pointer Context (UPC).
+static void findStmtsInUnspecifiedPointerContext(
+    const Stmt *S, llvm::function_ref<void(const Stmt *)> InnerMatcher) {
   // A UPC can be
   // 1. an argument of a function call (except the callee has [[unsafe_...]]
   //    attribute), or
@@ -294,45 +428,57 @@ isInUnspecifiedPointerContext(internal::Matcher<Stmt> InnerMatcher) {
   // 4. the operand of a pointer subtraction operation
   //    (i.e., computing the distance between two pointers); or ...
 
-  // clang-format off
-  auto CallArgMatcher = callExpr(
+  if (auto *CE = dyn_cast<CallExpr>(S)) {
+    if (const auto *FnDecl = CE->getDirectCallee();
+        FnDecl && FnDecl->hasAttr<UnsafeBufferUsageAttr>())
+      return;
     forEachArgumentWithParamType(
-      InnerMatcher,
-      isAnyPointer() /* array also decays to pointer type*/),
-    unless(callee(
-      functionDecl(hasAttr(attr::UnsafeBufferUsage)))));
-
-  auto CastOperandMatcher =
-      castExpr(anyOf(hasCastKind(CastKind::CK_PointerToIntegral),
-		     hasCastKind(CastKind::CK_PointerToBoolean)),
-	       castSubExpr(allOf(hasPointerType(), InnerMatcher)));
-
-  auto CompOperandMatcher =
-      binaryOperator(hasAnyOperatorName("!=", "==", "<", "<=", ">", ">="),
-                     eachOf(hasLHS(allOf(hasPointerType(), InnerMatcher)),
-                            hasRHS(allOf(hasPointerType(), InnerMatcher))));
-
-  // A matcher that matches pointer subtractions:
-  auto PtrSubtractionMatcher =
-      binaryOperator(hasOperatorName("-"),
-		     // Note that here we need both LHS and RHS to be
-		     // pointer. Then the inner matcher can match any of
-		     // them:
-		     allOf(hasLHS(hasPointerType()),
-			   hasRHS(hasPointerType())),
-		     eachOf(hasLHS(InnerMatcher),
-			    hasRHS(InnerMatcher)));
-  // clang-format on
-
-  return stmt(anyOf(CallArgMatcher, CastOperandMatcher, CompOperandMatcher,
-                    PtrSubtractionMatcher));
-  // FIXME: any more cases? (UPC excludes the RHS of an assignment.  For now we
-  // don't have to check that.)
-}
-
-// Returns a matcher that matches any expression 'e' such that `innerMatcher`
-// matches 'e' and 'e' is in an unspecified untyped context (i.e the expression
-// 'e' isn't evaluated to an RValue). For example, consider the following code:
+        *CE, [&InnerMatcher](QualType Type, const Expr *Arg) {
+          if (Type->isAnyPointerType())
+            InnerMatcher(Arg);
+        });
+  }
+
+  if (auto *CE = dyn_cast<CastExpr>(S)) {
+    if (CE->getCastKind() != CastKind::CK_PointerToIntegral &&
+        CE->getCastKind() != CastKind::CK_PointerToBoolean)
+      return;
+    if (!hasPointerType(*CE->getSubExpr()))
+      return;
+    InnerMatcher(CE->getSubExpr());
+  }
+
+  // Pointer comparison operator.
+  if (const auto *BO = dyn_cast<BinaryOperator>(S);
+      BO && (BO->getOpcode() == BO_EQ || BO->getOpcode() == BO_NE ||
+             BO->getOpcode() == BO_LT || BO->getOpcode() == BO_LE ||
+             BO->getOpcode() == BO_GT || BO->getOpcode() == BO_GE)) {
+    auto *LHS = BO->getLHS();
+    auto *RHS = BO->getRHS();
+    if (!hasPointerType(*LHS) || !hasPointerType(*RHS))
+      return;
+    InnerMatcher(LHS);
+    InnerMatcher(RHS);
+  }
+
+  // Pointer subtractions.
+  if (const auto *BO = dyn_cast<BinaryOperator>(S);
+      BO && BO->getOpcode() == BO_Sub && hasPointerType(*BO->getLHS()) &&
+      hasPointerType(*BO->getRHS())) {
+    // Note that here we need both LHS and RHS to be
+    // pointer. Then the inner matcher can match any of
+    // them:
+    InnerMatcher(BO->getLHS());
+    InnerMatcher(BO->getRHS());
+  }
+  // FIXME: any more cases? (UPC excludes the RHS of an assignment.  For now
+  // we don't have to check that.)
+}
+
+// Finds statements in unspecified untyped context i.e. any expression 'e' such
+// that `InnerMatcher` matches 'e' and 'e' is in an unspecified untyped context
+// (i.e the expression 'e' isn't evaluated to an RValue). For example, consider
+// the following code:
 //    int *p = new int[4];
 //    int *q = new int[4];
 //    if ((p = q)) {}
@@ -340,17 +486,23 @@ isInUnspecifiedPointerContext(internal::Matcher<Stmt> InnerMatcher) {
 // The expression `p = q` in the conditional of the `if` statement
 // `if ((p = q))` is evaluated as an RValue, whereas the expression `p = q;`
 // in the assignment statement is in an untyped context.
-static internal::Matcher<Stmt>
-isInUnspecifiedUntypedContext(internal::Matcher<Stmt> InnerMatcher) {
+static void findStmtsInUnspecifiedUntypedContext(
+    const Stmt *S, llvm::function_ref<void(const Stmt *)> InnerMatcher) {
   // An unspecified context can be
   // 1. A compound statement,
   // 2. The body of an if statement
   // 3. Body of a loop
-  auto CompStmt = compoundStmt(forEach(InnerMatcher));
-  auto IfStmtThen = ifStmt(hasThen(InnerMatcher));
-  auto IfStmtElse = ifStmt(hasElse(InnerMatcher));
+  if (auto *CS = dyn_cast<CompoundStmt>(S)) {
+    for (auto *Child : CS->body())
+      InnerMatcher(Child);
+  }
+  if (auto *IfS = dyn_cast<IfStmt>(S)) {
+    if (IfS->getThen())
+      InnerMatcher(IfS->getThen());
+    if (IfS->getElse())
+      InnerMatcher(IfS->getElse());
+  }
   // FIXME: Handle loop bodies.
-  return stmt(anyOf(CompStmt, IfStmtThen, IfStmtElse));
 }
 
 // Given a two-param std::span construct call, matches iff the call has the
@@ -362,14 +514,15 @@ isInUnspecifiedUntypedContext(internal::Matcher<Stmt> InnerMatcher) {
 //   `n`
 //   5. `std::span<T>{any, 0}`
 //   6. `std::span<T>{std::addressof(...), 1}`
-AST_MATCHER(CXXConstructExpr, isSafeSpanTwoParamConstruct) {
+static bool isSafeSpanTwoParamConstruct(const CXXConstructExpr &Node,
+                                        const ASTContext &Ctx) {
   assert(Node.getNumArgs() == 2 &&
          "expec...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/125492


More information about the cfe-commits mailing list