[clang] [alpha.webkit.UncountedLocalVarsChecker] Allow uncounted object references within trivial statements (PR #82229)

Ryosuke Niwa via cfe-commits cfe-commits at lists.llvm.org
Fri Mar 1 16:49:32 PST 2024


https://github.com/rniwa updated https://github.com/llvm/llvm-project/pull/82229

>From 234e301ab2721ddb2f4b43589785015a7d0aa304 Mon Sep 17 00:00:00 2001
From: Ryosuke Niwa <rniwa at webkit.org>
Date: Mon, 19 Feb 2024 01:07:13 -0800
Subject: [PATCH 1/5] [alpha.webkit.UncountedLocalVarsChecker] Allow uncounted
 object references within trivial statements

This PR makes alpha.webkit.UncountedLocalVarsChecker ignore raw references and pointers to
a ref-counted type that appear within "trivial" statements. To do this, this PR extends
TrivialFunctionAnalysis so that it can analyze the "triviality" of statements as well as
that of functions. Each Visit* function is now wrapped with withCachedResult, which is
responsible for looking up and updating the cache for each Visit* function.

As this PR dramatically reduces the false positive rate of the checker, it also deletes
the code that ignored raw pointers and references declared within if and for statements.
---
 .../Checkers/WebKit/PtrTypesSemantics.cpp     | 222 ++++++++++++------
 .../Checkers/WebKit/PtrTypesSemantics.h       |  21 +-
 .../WebKit/UncountedLocalVarsChecker.cpp      |  69 +++---
 .../Analysis/Checkers/WebKit/mock-types.h     |   2 +
 .../Checkers/WebKit/uncounted-local-vars.cpp  |  92 +++++++-
 5 files changed, 285 insertions(+), 121 deletions(-)

diff --git a/clang/lib/StaticAnalyzer/Checkers/WebKit/PtrTypesSemantics.cpp b/clang/lib/StaticAnalyzer/Checkers/WebKit/PtrTypesSemantics.cpp
index 01b191ab0eeaf4..6c9a8aedb38a4c 100644
--- a/clang/lib/StaticAnalyzer/Checkers/WebKit/PtrTypesSemantics.cpp
+++ b/clang/lib/StaticAnalyzer/Checkers/WebKit/PtrTypesSemantics.cpp
@@ -245,18 +245,41 @@ class TrivialFunctionAnalysisVisitor
 
   // Returns false if at least one child is non-trivial.
   bool VisitChildren(const Stmt *S) {
-    for (const Stmt *Child : S->children()) {
-      if (Child && !Visit(Child))
+    return withCachedResult(S, [&]() {
+      for (const Stmt *Child : S->children()) {
+        if (Child && !Visit(Child))
+          return false;
+      }
+      return true;
+    });
+  }
+
+  bool VisitSubExpr(const Expr *Parent, const Expr *E) {
+    return withCachedResult(Parent, [&]() {
+      if (!Visit(E))
         return false;
-    }
+      return true;
+    });
+  }
 
-    return true;
+  template <typename StmtType, typename CheckFunction>
+  bool withCachedResult(const StmtType *S, CheckFunction Function) {
+    // Insert false to the cache first to avoid infinite recursion.
+    auto [It, IsNew] = StatementCache.insert(std::make_pair(S, false));
+    if (!IsNew)
+      return It->second;
+    bool Result = Function(); // May rehash StatementCache, invalidating It.
+    StatementCache[S] = Result;
+    return Result;
   }
 
 public:
-  using CacheTy = TrivialFunctionAnalysis::CacheTy;
+  using FunctionCacheTy = TrivialFunctionAnalysis::FunctionCacheTy;
+  using StatementCacheTy = TrivialFunctionAnalysis::StatementCacheTy;
 
-  TrivialFunctionAnalysisVisitor(CacheTy &Cache) : Cache(Cache) {}
+  TrivialFunctionAnalysisVisitor(FunctionCacheTy &FunctionCache,
+                                 StatementCacheTy &StatementCache)
+      : FunctionCache(FunctionCache), StatementCache(StatementCache) {}
 
   bool VisitStmt(const Stmt *S) {
     // All statements are non-trivial unless overriden later.
@@ -272,13 +295,21 @@ class TrivialFunctionAnalysisVisitor
 
   bool VisitReturnStmt(const ReturnStmt *RS) {
     // A return statement is allowed as long as the return value is trivial.
-    if (auto *RV = RS->getRetValue())
-      return Visit(RV);
-    return true;
+    return withCachedResult(RS, [&]() {
+      if (auto *RV = RS->getRetValue())
+        return Visit(RV);
+      return true;
+    });
+  }
+
+  bool VisitCXXForRangeStmt(const CXXForRangeStmt *FS) {
+    return VisitChildren(FS);
   }
 
   bool VisitDeclStmt(const DeclStmt *DS) { return VisitChildren(DS); }
   bool VisitDoStmt(const DoStmt *DS) { return VisitChildren(DS); }
+  bool VisitForStmt(const ForStmt *FS) { return VisitChildren(FS); }
+  bool VisitWhileStmt(const WhileStmt *WS) { return VisitChildren(WS); }
   bool VisitIfStmt(const IfStmt *IS) { return VisitChildren(IS); }
   bool VisitSwitchStmt(const SwitchStmt *SS) { return VisitChildren(SS); }
   bool VisitCaseStmt(const CaseStmt *CS) { return VisitChildren(CS); }
@@ -286,17 +317,26 @@ class TrivialFunctionAnalysisVisitor
 
   bool VisitUnaryOperator(const UnaryOperator *UO) {
     // Operator '*' and '!' are allowed as long as the operand is trivial.
-    if (UO->getOpcode() == UO_Deref || UO->getOpcode() == UO_AddrOf ||
-        UO->getOpcode() == UO_LNot)
-      return Visit(UO->getSubExpr());
-
-    // Other operators are non-trivial.
-    return false;
+    return withCachedResult(UO, [&]() {
+      auto op = UO->getOpcode();
+      if (op == UO_Deref || op == UO_AddrOf || op == UO_LNot)
+        return Visit(UO->getSubExpr());
+      if (UO->isIncrementOp() || UO->isDecrementOp()) {
+        if (auto *RefExpr = dyn_cast<DeclRefExpr>(UO->getSubExpr())) {
+          if (auto *Decl = dyn_cast<VarDecl>(RefExpr->getDecl()))
+            return Decl->isLocalVarDeclOrParm() &&
+                   Decl->getType().isPODType(Decl->getASTContext());
+        }
+      }
+      // Other operators are non-trivial.
+      return false;
+    });
   }
 
   bool VisitBinaryOperator(const BinaryOperator *BO) {
     // Binary operators are trivial if their operands are trivial.
-    return Visit(BO->getLHS()) && Visit(BO->getRHS());
+    return withCachedResult(
+        BO, [&]() { return Visit(BO->getLHS()) && Visit(BO->getRHS()); });
   }
 
   bool VisitConditionalOperator(const ConditionalOperator *CO) {
@@ -305,19 +345,21 @@ class TrivialFunctionAnalysisVisitor
   }
 
   bool VisitDeclRefExpr(const DeclRefExpr *DRE) {
-    if (auto *decl = DRE->getDecl()) {
-      if (isa<ParmVarDecl>(decl))
-        return true;
-      if (isa<EnumConstantDecl>(decl))
-        return true;
-      if (auto *VD = dyn_cast<VarDecl>(decl)) {
-        if (VD->hasConstantInitialization() && VD->getEvaluatedValue())
+    return withCachedResult(DRE, [&]() {
+      if (auto *decl = DRE->getDecl()) {
+        if (isa<ParmVarDecl>(decl))
+          return true;
+        if (isa<EnumConstantDecl>(decl))
           return true;
-        auto *Init = VD->getInit();
-        return !Init || Visit(Init);
+        if (auto *VD = dyn_cast<VarDecl>(decl)) {
+          if (VD->hasConstantInitialization() && VD->getEvaluatedValue())
+            return true;
+          auto *Init = VD->getInit();
+          return !Init || Visit(Init);
+        }
       }
-    }
-    return false;
+      return false;
+    });
   }
 
   bool VisitAtomicExpr(const AtomicExpr *E) { return VisitChildren(E); }
@@ -328,20 +370,23 @@ class TrivialFunctionAnalysisVisitor
   }
 
   bool VisitCallExpr(const CallExpr *CE) {
-    if (!checkArguments(CE))
-      return false;
+    return withCachedResult(CE, [&]() {
+      if (!checkArguments(CE))
+        return false;
 
-    auto *Callee = CE->getDirectCallee();
-    if (!Callee)
-      return false;
-    const auto &Name = safeGetName(Callee);
+      auto *Callee = CE->getDirectCallee();
+      if (!Callee)
+        return false;
+      const auto &Name = safeGetName(Callee);
 
-    if (Name == "WTFCrashWithInfo" || Name == "WTFBreakpointTrap" ||
-        Name == "WTFReportAssertionFailure" ||
-        Name == "compilerFenceForCrash" || Name == "__builtin_unreachable")
-      return true;
+      if (Name == "WTFCrashWithInfo" || Name == "WTFBreakpointTrap" ||
+          Name == "WTFReportAssertionFailure" ||
+          Name == "compilerFenceForCrash" || Name == "__builtin_unreachable")
+        return true;
 
-    return TrivialFunctionAnalysis::isTrivialImpl(Callee, Cache);
+      return TrivialFunctionAnalysis::isTrivialImpl(Callee, FunctionCache,
+                                                    StatementCache);
+    });
   }
 
   bool VisitPredefinedExpr(const PredefinedExpr *E) {
@@ -350,23 +395,26 @@ class TrivialFunctionAnalysisVisitor
   }
 
   bool VisitCXXMemberCallExpr(const CXXMemberCallExpr *MCE) {
-    if (!checkArguments(MCE))
-      return false;
+    return withCachedResult(MCE, [&]() {
+      if (!checkArguments(MCE))
+        return false;
 
-    bool TrivialThis = Visit(MCE->getImplicitObjectArgument());
-    if (!TrivialThis)
-      return false;
+      bool TrivialThis = Visit(MCE->getImplicitObjectArgument());
+      if (!TrivialThis)
+        return false;
 
-    auto *Callee = MCE->getMethodDecl();
-    if (!Callee)
-      return false;
+      auto *Callee = MCE->getMethodDecl();
+      if (!Callee)
+        return false;
 
-    std::optional<bool> IsGetterOfRefCounted = isGetterOfRefCounted(Callee);
-    if (IsGetterOfRefCounted && *IsGetterOfRefCounted)
-      return true;
+      std::optional<bool> IsGetterOfRefCounted = isGetterOfRefCounted(Callee);
+      if (IsGetterOfRefCounted && *IsGetterOfRefCounted)
+        return true;
 
-    // Recursively descend into the callee to confirm that it's trivial as well.
-    return TrivialFunctionAnalysis::isTrivialImpl(Callee, Cache);
+      // Recursively descend into the callee to confirm it's trivial as well.
+      return TrivialFunctionAnalysis::isTrivialImpl(Callee, FunctionCache,
+                                                    StatementCache);
+    });
   }
 
   bool VisitCXXDefaultArgExpr(const CXXDefaultArgExpr *E) {
@@ -386,44 +434,51 @@ class TrivialFunctionAnalysisVisitor
   }
 
   bool VisitCXXConstructExpr(const CXXConstructExpr *CE) {
-    for (const Expr *Arg : CE->arguments()) {
-      if (Arg && !Visit(Arg))
-        return false;
-    }
+    return withCachedResult(CE, [&]() {
+      for (const Expr *Arg : CE->arguments()) {
+        if (Arg && !Visit(Arg))
+          return false;
+      }
 
-    // Recursively descend into the callee to confirm that it's trivial.
-    return TrivialFunctionAnalysis::isTrivialImpl(CE->getConstructor(), Cache);
+      // Recursively descend into the callee to confirm that it's trivial.
+      return TrivialFunctionAnalysis::isTrivialImpl(
+          CE->getConstructor(), FunctionCache, StatementCache);
+    });
   }
 
   bool VisitImplicitCastExpr(const ImplicitCastExpr *ICE) {
-    return Visit(ICE->getSubExpr());
+    return VisitSubExpr(ICE, ICE->getSubExpr());
   }
 
   bool VisitExplicitCastExpr(const ExplicitCastExpr *ECE) {
-    return Visit(ECE->getSubExpr());
+    return VisitSubExpr(ECE, ECE->getSubExpr());
   }
 
   bool VisitMaterializeTemporaryExpr(const MaterializeTemporaryExpr *VMT) {
-    return Visit(VMT->getSubExpr());
+    return VisitSubExpr(VMT, VMT->getSubExpr());
   }
 
   bool VisitExprWithCleanups(const ExprWithCleanups *EWC) {
-    return Visit(EWC->getSubExpr());
+    return VisitSubExpr(EWC, EWC->getSubExpr());
   }
 
-  bool VisitParenExpr(const ParenExpr *PE) { return Visit(PE->getSubExpr()); }
+  bool VisitParenExpr(const ParenExpr *PE) {
+    return VisitSubExpr(PE, PE->getSubExpr());
+  }
 
   bool VisitInitListExpr(const InitListExpr *ILE) {
-    for (const Expr *Child : ILE->inits()) {
-      if (Child && !Visit(Child))
-        return false;
-    }
-    return true;
+    return withCachedResult(ILE, [&]() {
+      for (const Expr *Child : ILE->inits()) {
+        if (Child && !Visit(Child))
+          return false;
+      }
+      return true;
+    });
   }
 
   bool VisitMemberExpr(const MemberExpr *ME) {
     // Field access is allowed but the base pointer may itself be non-trivial.
-    return Visit(ME->getBase());
+    return VisitSubExpr(ME, ME->getBase());
   }
 
   bool VisitCXXThisExpr(const CXXThisExpr *CTE) {
@@ -449,16 +504,18 @@ class TrivialFunctionAnalysisVisitor
   }
 
 private:
-  CacheTy Cache;
+  FunctionCacheTy &FunctionCache;   // Borrowed; shared by nested visitors.
+  StatementCacheTy &StatementCache; // Borrowed; shared by nested visitors.
 };
 
 bool TrivialFunctionAnalysis::isTrivialImpl(
-    const Decl *D, TrivialFunctionAnalysis::CacheTy &Cache) {
+    const Decl *D, TrivialFunctionAnalysis::FunctionCacheTy &FunctionCache,
+    TrivialFunctionAnalysis::StatementCacheTy &StatementCache) {
   // If the function isn't in the cache, conservatively assume that
   // it's not trivial until analysis completes. This makes every recursive
   // function non-trivial. This also guarantees that each function
   // will be scanned at most once.
-  auto [It, IsNew] = Cache.insert(std::make_pair(D, false));
+  auto [It, IsNew] = FunctionCache.insert(std::make_pair(D, false));
   if (!IsNew)
     return It->second;
 
@@ -466,10 +523,29 @@ bool TrivialFunctionAnalysis::isTrivialImpl(
   if (!Body)
     return false;
 
-  TrivialFunctionAnalysisVisitor V(Cache);
+  TrivialFunctionAnalysisVisitor V(FunctionCache, StatementCache);
   bool Result = V.Visit(Body);
   if (Result)
-    Cache[D] = true;
+    FunctionCache[D] = true;
+
+  return Result;
+}
+
+bool TrivialFunctionAnalysis::isTrivialImpl(
+    const Stmt *S, TrivialFunctionAnalysis::FunctionCacheTy &FunctionCache,
+    TrivialFunctionAnalysis::StatementCacheTy &StatementCache) {
+  // If the statement isn't in the cache, conservatively assume that
+  // it's not trivial until analysis completes. Unlike a function case,
+  // we don't insert an entry into the cache until Visit returns
+  // since Visit* functions themselves make use of the cache.
+
+  auto It = StatementCache.find(S);
+  if (It != StatementCache.end())
+    return It->second;
+
+  TrivialFunctionAnalysisVisitor V(FunctionCache, StatementCache);
+  bool Result = V.Visit(S);
+  StatementCache[S] = Result;
 
   return Result;
 }
diff --git a/clang/lib/StaticAnalyzer/Checkers/WebKit/PtrTypesSemantics.h b/clang/lib/StaticAnalyzer/Checkers/WebKit/PtrTypesSemantics.h
index e07cd31395747d..3f4cdd1f2ffb02 100644
--- a/clang/lib/StaticAnalyzer/Checkers/WebKit/PtrTypesSemantics.h
+++ b/clang/lib/StaticAnalyzer/Checkers/WebKit/PtrTypesSemantics.h
@@ -19,6 +19,7 @@ class CXXMethodDecl;
 class CXXRecordDecl;
 class Decl;
 class FunctionDecl;
+class Stmt;
 class Type;
 
 // Ref-countability of a type is implicitly defined by Ref<T> and RefPtr<T>
@@ -70,15 +71,27 @@ bool isSingleton(const FunctionDecl *F);
 class TrivialFunctionAnalysis {
 public:
   /// \returns true if \p D is a "trivial" function.
-  bool isTrivial(const Decl *D) const { return isTrivialImpl(D, TheCache); }
+  bool isTrivial(const Decl *D) const {
+    return isTrivialImpl(D, TheFunctionCache, TheStatementCache);
+  }
+
+  bool isTrivial(const Stmt *S) const {
+    return isTrivialImpl(S, TheFunctionCache, TheStatementCache);
+  }
 
 private:
   friend class TrivialFunctionAnalysisVisitor;
 
-  using CacheTy = llvm::DenseMap<const Decl *, bool>;
-  mutable CacheTy TheCache{};
+  using FunctionCacheTy = llvm::DenseMap<const Decl *, bool>;
+  mutable FunctionCacheTy TheFunctionCache{};
+
+  using StatementCacheTy = llvm::DenseMap<const Stmt *, bool>;
+  mutable StatementCacheTy TheStatementCache{};
 
-  static bool isTrivialImpl(const Decl *D, CacheTy &Cache);
+  static bool isTrivialImpl(const Decl *D, FunctionCacheTy &FunctionCache,
+                            StatementCacheTy &StatementCache);
+  static bool isTrivialImpl(const Stmt *S, FunctionCacheTy &FunctionCache,
+                            StatementCacheTy &StatementCache);
 };
 
 } // namespace clang
diff --git a/clang/lib/StaticAnalyzer/Checkers/WebKit/UncountedLocalVarsChecker.cpp b/clang/lib/StaticAnalyzer/Checkers/WebKit/UncountedLocalVarsChecker.cpp
index 5a72f53b12edaa..4068b472cc5fcd 100644
--- a/clang/lib/StaticAnalyzer/Checkers/WebKit/UncountedLocalVarsChecker.cpp
+++ b/clang/lib/StaticAnalyzer/Checkers/WebKit/UncountedLocalVarsChecker.cpp
@@ -26,28 +26,6 @@ using namespace ento;
 
 namespace {
 
-// for ( int a = ...) ... true
-// for ( int a : ...) ... true
-// if ( int* a = ) ... true
-// anything else ... false
-bool isDeclaredInForOrIf(const VarDecl *Var) {
-  assert(Var);
-  auto &ASTCtx = Var->getASTContext();
-  auto parent = ASTCtx.getParents(*Var);
-
-  if (parent.size() == 1) {
-    if (auto *DS = parent.begin()->get<DeclStmt>()) {
-      DynTypedNodeList grandParent = ASTCtx.getParents(*DS);
-      if (grandParent.size() == 1) {
-        return grandParent.begin()->get<ForStmt>() ||
-               grandParent.begin()->get<IfStmt>() ||
-               grandParent.begin()->get<CXXForRangeStmt>();
-      }
-    }
-  }
-  return false;
-}
-
 // FIXME: should be defined by anotations in the future
 bool isRefcountedStringsHack(const VarDecl *V) {
   assert(V);
@@ -133,6 +111,8 @@ class UncountedLocalVarsChecker
               "WebKit coding guidelines"};
   mutable BugReporter *BR;
 
+  TrivialFunctionAnalysis TFA;
+
 public:
   void checkASTDecl(const TranslationUnitDecl *TUD, AnalysisManager &MGR,
                     BugReporter &BRArg) const {
@@ -171,6 +151,24 @@ class UncountedLocalVarsChecker
 
     std::optional<bool> IsUncountedPtr = isUncountedPtr(ArgType);
     if (IsUncountedPtr && *IsUncountedPtr) {
+
+      ASTContext &ctx = V->getASTContext();
+      for (DynTypedNodeList ancestors = ctx.getParents(*V); !ancestors.empty();
+           ancestors = ctx.getParents(*ancestors.begin())) {
+        for (auto &ancestor : ancestors) {
+          if (auto *S = ancestor.get<IfStmt>(); S && TFA.isTrivial(S))
+            return;
+          if (auto *S = ancestor.get<ForStmt>(); S && TFA.isTrivial(S))
+            return;
+          if (auto *S = ancestor.get<CXXForRangeStmt>(); S && TFA.isTrivial(S))
+            return;
+          if (auto *S = ancestor.get<WhileStmt>(); S && TFA.isTrivial(S))
+            return;
+          if (auto *S = ancestor.get<CompoundStmt>(); S && TFA.isTrivial(S))
+            return;
+        }
+      }
+
       const Expr *const InitExpr = V->getInit();
       if (!InitExpr)
         return; // FIXME: later on we might warn on uninitialized vars too
@@ -178,6 +176,7 @@ class UncountedLocalVarsChecker
       const clang::Expr *const InitArgOrigin =
           tryToFindPtrOrigin(InitExpr, /*StopAtFirstRefCountedObj=*/false)
               .first;
+
       if (!InitArgOrigin)
         return;
 
@@ -189,20 +188,17 @@ class UncountedLocalVarsChecker
                 dyn_cast_or_null<VarDecl>(Ref->getFoundDecl())) {
           const auto *MaybeGuardianArgType =
               MaybeGuardian->getType().getTypePtr();
-          if (!MaybeGuardianArgType)
-            return;
-          const CXXRecordDecl *const MaybeGuardianArgCXXRecord =
-              MaybeGuardianArgType->getAsCXXRecordDecl();
-          if (!MaybeGuardianArgCXXRecord)
-            return;
-
-          if (MaybeGuardian->isLocalVarDecl() &&
-              (isRefCounted(MaybeGuardianArgCXXRecord) ||
-               isRefcountedStringsHack(MaybeGuardian)) &&
-              isGuardedScopeEmbeddedInGuardianScope(V, MaybeGuardian)) {
-            return;
+          if (MaybeGuardianArgType) {
+            const CXXRecordDecl *const MaybeGuardianArgCXXRecord =
+                MaybeGuardianArgType->getAsCXXRecordDecl();
+            if (MaybeGuardianArgCXXRecord) {
+              if (MaybeGuardian->isLocalVarDecl() &&
+                  (isRefCounted(MaybeGuardianArgCXXRecord) ||
+                   isRefcountedStringsHack(MaybeGuardian)) &&
+                  isGuardedScopeEmbeddedInGuardianScope(V, MaybeGuardian))
+                return;
+            }
           }
-
           // Parameters are guaranteed to be safe for the duration of the call
           // by another checker.
           if (isa<ParmVarDecl>(MaybeGuardian))
@@ -219,9 +215,6 @@ class UncountedLocalVarsChecker
     if (!V->isLocalVarDecl())
       return true;
 
-    if (isDeclaredInForOrIf(V))
-      return true;
-
     return false;
   }
 
diff --git a/clang/test/Analysis/Checkers/WebKit/mock-types.h b/clang/test/Analysis/Checkers/WebKit/mock-types.h
index e2b3401d407392..aab99197dfa49e 100644
--- a/clang/test/Analysis/Checkers/WebKit/mock-types.h
+++ b/clang/test/Analysis/Checkers/WebKit/mock-types.h
@@ -62,6 +62,8 @@ struct RefCountable {
   static Ref<RefCountable> create();
   void ref() {}
   void deref() {}
+  void method();
+  int trivial() { return 123; }
 };
 
 template <typename T> T *downcast(T *t) { return t; }
diff --git a/clang/test/Analysis/Checkers/WebKit/uncounted-local-vars.cpp b/clang/test/Analysis/Checkers/WebKit/uncounted-local-vars.cpp
index 0fcd3b21376caf..3fe04f775fbbcb 100644
--- a/clang/test/Analysis/Checkers/WebKit/uncounted-local-vars.cpp
+++ b/clang/test/Analysis/Checkers/WebKit/uncounted-local-vars.cpp
@@ -2,6 +2,8 @@
 
 #include "mock-types.h"
 
+void someFunction();
+
 namespace raw_ptr {
 void foo() {
   RefCountable *bar;
@@ -16,6 +18,13 @@ void foo_ref() {
   RefCountable automatic;
   RefCountable &bar = automatic;
   // expected-warning at -1{{Local variable 'bar' is uncounted and unsafe [alpha.webkit.UncountedLocalVarsChecker]}}
+  someFunction();
+  bar.method();
+}
+
+void foo_ref_trivial() {
+  RefCountable automatic;
+  RefCountable &bar = automatic;
 }
 
 void bar_ref(RefCountable &) {}
@@ -32,6 +41,8 @@ void foo2() {
   // missing embedded scope here
   RefCountable *bar = foo.get();
   // expected-warning at -1{{Local variable 'bar' is uncounted and unsafe [alpha.webkit.UncountedLocalVarsChecker]}}
+  someFunction();
+  bar->method();
 }
 
 void foo3() {
@@ -47,11 +58,25 @@ void foo4() {
     { RefCountable *bar = foo.get(); }
   }
 }
+
+void foo5() {
+  RefPtr<RefCountable> foo;
+  auto* bar = foo.get();
+  bar->trivial();
+}
+
+void foo6() {
+  RefPtr<RefCountable> foo;
+  auto* bar = foo.get();
+  // expected-warning at -1{{Local variable 'bar' is uncounted and unsafe [alpha.webkit.UncountedLocalVarsChecker]}}
+  bar->method();
+}
+
 } // namespace guardian_scopes
 
 namespace auto_keyword {
 class Foo {
-  RefCountable *provide_ref_ctnbl() { return nullptr; }
+  RefCountable *provide_ref_ctnbl();
 
   void evil_func() {
     RefCountable *bar = provide_ref_ctnbl();
@@ -62,13 +87,24 @@ class Foo {
     // expected-warning at -1{{Local variable 'baz2' is uncounted and unsafe [alpha.webkit.UncountedLocalVarsChecker]}}
     [[clang::suppress]] auto *baz_suppressed = provide_ref_ctnbl(); // no-warning
   }
+
+  void func() {
+    RefCountable *bar = provide_ref_ctnbl();
+    // expected-warning at -1{{Local variable 'bar' is uncounted and unsafe [alpha.webkit.UncountedLocalVarsChecker]}}
+    if (bar)
+      bar->method();
+  }
 };
 } // namespace auto_keyword
 
 namespace guardian_casts {
 void foo1() {
   RefPtr<RefCountable> foo;
-  { RefCountable *bar = downcast<RefCountable>(foo.get()); }
+  {
+    RefCountable *bar = downcast<RefCountable>(foo.get());
+    bar->method();
+  }
+  foo->method();
 }
 
 void foo2() {
@@ -76,6 +112,7 @@ void foo2() {
   {
     RefCountable *bar =
         static_cast<RefCountable *>(downcast<RefCountable>(foo.get()));
+    someFunction();
   }
 }
 } // namespace guardian_casts
@@ -83,7 +120,11 @@ void foo2() {
 namespace guardian_ref_conversion_operator {
 void foo() {
   Ref<RefCountable> rc;
-  { RefCountable &rr = rc; }
+  {
+    RefCountable &rr = rc;
+    rr.method();
+    someFunction();
+  }
 }
 } // namespace guardian_ref_conversion_operator
 
@@ -92,9 +133,48 @@ RefCountable *provide_ref_ctnbl() { return nullptr; }
 
 void foo() {
   // no warnings
-  if (RefCountable *a = provide_ref_ctnbl()) { }
-  for (RefCountable *a = provide_ref_ctnbl(); a != nullptr;) { }
+  if (RefCountable *a = provide_ref_ctnbl())
+    a->trivial();
+  for (RefCountable *b = provide_ref_ctnbl(); b != nullptr;)
+    b->trivial();
   RefCountable *array[1];
-  for (RefCountable *a : array) { }
+  for (RefCountable *c : array)
+    c->trivial();
+  while (RefCountable *d = provide_ref_ctnbl())
+    d->trivial();
+  do {
+    RefCountable *e = provide_ref_ctnbl();
+    e->trivial();
+  } while (1);
+  someFunction();
 }
+
+void bar() {
+  // no warnings
+  if (RefCountable *a = provide_ref_ctnbl()) {
+    // expected-warning at -1{{Local variable 'a' is uncounted and unsafe [alpha.webkit.UncountedLocalVarsChecker]}}
+    a->method();    
+  }
+  for (RefCountable *b = provide_ref_ctnbl(); b != nullptr;) {
+    // expected-warning at -1{{Local variable 'b' is uncounted and unsafe [alpha.webkit.UncountedLocalVarsChecker]}}
+    b->method();
+  }
+  RefCountable *array[1];
+  for (RefCountable *c : array) {
+    // expected-warning at -1{{Local variable 'c' is uncounted and unsafe [alpha.webkit.UncountedLocalVarsChecker]}}
+    c->method();
+  }
+
+  while (RefCountable *d = provide_ref_ctnbl()) {
+    // expected-warning at -1{{Local variable 'd' is uncounted and unsafe [alpha.webkit.UncountedLocalVarsChecker]}}
+    d->method();
+  }
+  do {
+    RefCountable *e = provide_ref_ctnbl();
+    // expected-warning at -1{{Local variable 'e' is uncounted and unsafe [alpha.webkit.UncountedLocalVarsChecker]}}
+    e->method();
+  } while (1);
+  someFunction();
+}
+
 } // namespace ignore_for_if

>From 7140c6265f0431bdfc12047c6e247821ae680b23 Mon Sep 17 00:00:00 2001
From: Ryosuke Niwa <rniwa at webkit.org>
Date: Wed, 28 Feb 2024 01:27:04 -0800
Subject: [PATCH 2/5] Limit the cache to if, for, while, and compound
 statements.

---
 .../Checkers/WebKit/PtrTypesSemantics.cpp     | 198 ++++++++----------
 .../WebKit/UncountedLocalVarsChecker.cpp      |   2 +-
 2 files changed, 90 insertions(+), 110 deletions(-)

diff --git a/clang/lib/StaticAnalyzer/Checkers/WebKit/PtrTypesSemantics.cpp b/clang/lib/StaticAnalyzer/Checkers/WebKit/PtrTypesSemantics.cpp
index 6c9a8aedb38a4c..709387e343b975 100644
--- a/clang/lib/StaticAnalyzer/Checkers/WebKit/PtrTypesSemantics.cpp
+++ b/clang/lib/StaticAnalyzer/Checkers/WebKit/PtrTypesSemantics.cpp
@@ -245,21 +245,12 @@ class TrivialFunctionAnalysisVisitor
 
   // Returns false if at least one child is non-trivial.
   bool VisitChildren(const Stmt *S) {
-    return withCachedResult(S, [&]() {
-      for (const Stmt *Child : S->children()) {
-        if (Child && !Visit(Child))
-          return false;
-      }
-      return true;
-    });
-  }
-
-  bool VisitSubExpr(const Expr *Parent, const Expr *E) {
-    return withCachedResult(Parent, [&]() {
-      if (!Visit(E))
+    for (const Stmt *Child : S->children()) {
+      if (Child && !Visit(Child))
         return false;
-      return true;
-    });
+    }
+
+    return true;
   }
 
   template <typename StmtType, typename CheckFunction>
@@ -290,53 +281,54 @@ class TrivialFunctionAnalysisVisitor
   bool VisitCompoundStmt(const CompoundStmt *CS) {
     // A compound statement is allowed as long each individual sub-statement
     // is trivial.
-    return VisitChildren(CS);
+    return withCachedResult(CS, [&]() { return VisitChildren(CS); });
   }
 
   bool VisitReturnStmt(const ReturnStmt *RS) {
     // A return statement is allowed as long as the return value is trivial.
-    return withCachedResult(RS, [&]() {
-      if (auto *RV = RS->getRetValue())
-        return Visit(RV);
-      return true;
-    });
-  }
-
-  bool VisitCXXForRangeStmt(const CXXForRangeStmt *FS) {
-    return VisitChildren(FS);
+    if (auto *RV = RS->getRetValue())
+      return Visit(RV);
+    return true;
   }
 
   bool VisitDeclStmt(const DeclStmt *DS) { return VisitChildren(DS); }
   bool VisitDoStmt(const DoStmt *DS) { return VisitChildren(DS); }
-  bool VisitForStmt(const ForStmt *FS) { return VisitChildren(FS); }
-  bool VisitWhileStmt(const WhileStmt *WS) { return VisitChildren(WS); }
-  bool VisitIfStmt(const IfStmt *IS) { return VisitChildren(IS); }
+  bool VisitIfStmt(const IfStmt *IS) {
+    return withCachedResult(IS, [&]() { return VisitChildren(IS); });
+  }
+  bool VisitForStmt(const ForStmt *FS) {
+    return withCachedResult(FS, [&]() { return VisitChildren(FS); });
+  }
+  bool VisitCXXForRangeStmt(const CXXForRangeStmt *FS) {
+    return withCachedResult(FS, [&]() { return VisitChildren(FS); });
+  }
+  bool VisitWhileStmt(const WhileStmt *WS) {
+    return withCachedResult(WS, [&]() { return VisitChildren(WS); });
+  }
   bool VisitSwitchStmt(const SwitchStmt *SS) { return VisitChildren(SS); }
   bool VisitCaseStmt(const CaseStmt *CS) { return VisitChildren(CS); }
   bool VisitDefaultStmt(const DefaultStmt *DS) { return VisitChildren(DS); }
 
   bool VisitUnaryOperator(const UnaryOperator *UO) {
     // Operator '*' and '!' are allowed as long as the operand is trivial.
-    return withCachedResult(UO, [&]() {
-      auto op = UO->getOpcode();
-      if (op == UO_Deref || op == UO_AddrOf || op == UO_LNot)
-        return Visit(UO->getSubExpr());
-      if (UO->isIncrementOp() || UO->isDecrementOp()) {
-        if (auto *RefExpr = dyn_cast<DeclRefExpr>(UO->getSubExpr())) {
-          if (auto *Decl = dyn_cast<VarDecl>(RefExpr->getDecl()))
-            return Decl->isLocalVarDeclOrParm() &&
-                   Decl->getType().isPODType(Decl->getASTContext());
-        }
+    auto op = UO->getOpcode();
+    if (op == UO_Deref || op == UO_AddrOf || op == UO_LNot)
+      return Visit(UO->getSubExpr());
+    if (UO->isIncrementOp() || UO->isDecrementOp()) {
+      // Allow increment or decrement of a POD type.
+      if (auto *RefExpr = dyn_cast<DeclRefExpr>(UO->getSubExpr())) {
+        if (auto *Decl = dyn_cast<VarDecl>(RefExpr->getDecl()))
+          return Decl->isLocalVarDeclOrParm() &&
+                 Decl->getType().isPODType(Decl->getASTContext());
       }
-      // Other operators are non-trivial.
-      return false;
-    });
+    }
+    // Other operators are non-trivial.
+    return false;
   }
 
   bool VisitBinaryOperator(const BinaryOperator *BO) {
     // Binary operators are trivial if their operands are trivial.
-    return withCachedResult(
-        BO, [&]() { return Visit(BO->getLHS()) && Visit(BO->getRHS()); });
+    return Visit(BO->getLHS()) && Visit(BO->getRHS());
   }
 
   bool VisitConditionalOperator(const ConditionalOperator *CO) {
@@ -345,21 +337,19 @@ class TrivialFunctionAnalysisVisitor
   }
 
   bool VisitDeclRefExpr(const DeclRefExpr *DRE) {
-    return withCachedResult(DRE, [&]() {
-      if (auto *decl = DRE->getDecl()) {
-        if (isa<ParmVarDecl>(decl))
-          return true;
-        if (isa<EnumConstantDecl>(decl))
+    if (auto *decl = DRE->getDecl()) {
+      if (isa<ParmVarDecl>(decl))
+        return true;
+      if (isa<EnumConstantDecl>(decl))
+        return true;
+      if (auto *VD = dyn_cast<VarDecl>(decl)) {
+        if (VD->hasConstantInitialization() && VD->getEvaluatedValue())
           return true;
-        if (auto *VD = dyn_cast<VarDecl>(decl)) {
-          if (VD->hasConstantInitialization() && VD->getEvaluatedValue())
-            return true;
-          auto *Init = VD->getInit();
-          return !Init || Visit(Init);
-        }
+        auto *Init = VD->getInit();
+        return !Init || Visit(Init);
       }
-      return false;
-    });
+    }
+    return false;
   }
 
   bool VisitAtomicExpr(const AtomicExpr *E) { return VisitChildren(E); }
@@ -370,23 +360,21 @@ class TrivialFunctionAnalysisVisitor
   }
 
   bool VisitCallExpr(const CallExpr *CE) {
-    return withCachedResult(CE, [&]() {
-      if (!checkArguments(CE))
-        return false;
+    if (!checkArguments(CE))
+      return false;
 
-      auto *Callee = CE->getDirectCallee();
-      if (!Callee)
-        return false;
-      const auto &Name = safeGetName(Callee);
+    auto *Callee = CE->getDirectCallee();
+    if (!Callee)
+      return false;
+    const auto &Name = safeGetName(Callee);
 
-      if (Name == "WTFCrashWithInfo" || Name == "WTFBreakpointTrap" ||
-          Name == "WTFReportAssertionFailure" ||
-          Name == "compilerFenceForCrash" || Name == "__builtin_unreachable")
-        return true;
+    if (Name == "WTFCrashWithInfo" || Name == "WTFBreakpointTrap" ||
+        Name == "WTFReportAssertionFailure" ||
+        Name == "compilerFenceForCrash" || Name == "__builtin_unreachable")
+      return true;
 
-      return TrivialFunctionAnalysis::isTrivialImpl(Callee, FunctionCache,
-                                                    StatementCache);
-    });
+    return TrivialFunctionAnalysis::isTrivialImpl(Callee, FunctionCache,
+                                                  StatementCache);
   }
 
   bool VisitPredefinedExpr(const PredefinedExpr *E) {
@@ -395,26 +383,24 @@ class TrivialFunctionAnalysisVisitor
   }
 
   bool VisitCXXMemberCallExpr(const CXXMemberCallExpr *MCE) {
-    return withCachedResult(MCE, [&]() {
-      if (!checkArguments(MCE))
-        return false;
+    if (!checkArguments(MCE))
+      return false;
 
-      bool TrivialThis = Visit(MCE->getImplicitObjectArgument());
-      if (!TrivialThis)
-        return false;
+    bool TrivialThis = Visit(MCE->getImplicitObjectArgument());
+    if (!TrivialThis)
+      return false;
 
-      auto *Callee = MCE->getMethodDecl();
-      if (!Callee)
-        return false;
+    auto *Callee = MCE->getMethodDecl();
+    if (!Callee)
+      return false;
 
-      std::optional<bool> IsGetterOfRefCounted = isGetterOfRefCounted(Callee);
-      if (IsGetterOfRefCounted && *IsGetterOfRefCounted)
-        return true;
+    std::optional<bool> IsGetterOfRefCounted = isGetterOfRefCounted(Callee);
+    if (IsGetterOfRefCounted && *IsGetterOfRefCounted)
+      return true;
 
-      // Recursively descend into the callee to confirm it's trivial as well.
-      return TrivialFunctionAnalysis::isTrivialImpl(Callee, FunctionCache,
-                                                    StatementCache);
-    });
+    // Recursively descend into the callee to confirm it's trivial as well.
+    return TrivialFunctionAnalysis::isTrivialImpl(Callee, FunctionCache,
+                                                  StatementCache);
   }
 
   bool VisitCXXDefaultArgExpr(const CXXDefaultArgExpr *E) {
@@ -434,51 +420,45 @@ class TrivialFunctionAnalysisVisitor
   }
 
   bool VisitCXXConstructExpr(const CXXConstructExpr *CE) {
-    return withCachedResult(CE, [&]() {
-      for (const Expr *Arg : CE->arguments()) {
-        if (Arg && !Visit(Arg))
-          return false;
-      }
+    for (const Expr *Arg : CE->arguments()) {
+      if (Arg && !Visit(Arg))
+        return false;
+    }
 
-      // Recursively descend into the callee to confirm that it's trivial.
-      return TrivialFunctionAnalysis::isTrivialImpl(
-          CE->getConstructor(), FunctionCache, StatementCache);
-    });
+    // Recursively descend into the callee to confirm that it's trivial.
+    return TrivialFunctionAnalysis::isTrivialImpl(
+        CE->getConstructor(), FunctionCache, StatementCache);
   }
 
   bool VisitImplicitCastExpr(const ImplicitCastExpr *ICE) {
-    return VisitSubExpr(ICE, ICE->getSubExpr());
+    return Visit(ICE->getSubExpr());
   }
 
   bool VisitExplicitCastExpr(const ExplicitCastExpr *ECE) {
-    return VisitSubExpr(ECE, ECE->getSubExpr());
+    return Visit(ECE->getSubExpr());
   }
 
   bool VisitMaterializeTemporaryExpr(const MaterializeTemporaryExpr *VMT) {
-    return VisitSubExpr(VMT, VMT->getSubExpr());
+    return Visit(VMT->getSubExpr());
   }
 
   bool VisitExprWithCleanups(const ExprWithCleanups *EWC) {
-    return VisitSubExpr(EWC, EWC->getSubExpr());
+    return Visit(EWC->getSubExpr());
   }
 
-  bool VisitParenExpr(const ParenExpr *PE) {
-    return VisitSubExpr(PE, PE->getSubExpr());
-  }
+  bool VisitParenExpr(const ParenExpr *PE) { return Visit(PE->getSubExpr()); }
 
   bool VisitInitListExpr(const InitListExpr *ILE) {
-    return withCachedResult(ILE, [&]() {
-      for (const Expr *Child : ILE->inits()) {
-        if (Child && !Visit(Child))
-          return false;
-      }
-      return true;
-    });
+    for (const Expr *Child : ILE->inits()) {
+      if (Child && !Visit(Child))
+        return false;
+    }
+    return true;
   }
 
   bool VisitMemberExpr(const MemberExpr *ME) {
     // Field access is allowed but the base pointer may itself be non-trivial.
-    return VisitSubExpr(ME, ME->getBase());
+    return Visit(ME->getBase());
   }
 
   bool VisitCXXThisExpr(const CXXThisExpr *CTE) {
diff --git a/clang/lib/StaticAnalyzer/Checkers/WebKit/UncountedLocalVarsChecker.cpp b/clang/lib/StaticAnalyzer/Checkers/WebKit/UncountedLocalVarsChecker.cpp
index 4068b472cc5fcd..b9465a2668605c 100644
--- a/clang/lib/StaticAnalyzer/Checkers/WebKit/UncountedLocalVarsChecker.cpp
+++ b/clang/lib/StaticAnalyzer/Checkers/WebKit/UncountedLocalVarsChecker.cpp
@@ -176,7 +176,6 @@ class UncountedLocalVarsChecker
       const clang::Expr *const InitArgOrigin =
           tryToFindPtrOrigin(InitExpr, /*StopAtFirstRefCountedObj=*/false)
               .first;
-
       if (!InitArgOrigin)
         return;
 
@@ -199,6 +198,7 @@ class UncountedLocalVarsChecker
                 return;
             }
           }
+
           // Parameters are guaranteed to be safe for the duration of the call
           // by another checker.
           if (isa<ParmVarDecl>(MaybeGuardian))

>From 7b54a4074736249e6c7304de066d6fbf9b462080 Mon Sep 17 00:00:00 2001
From: Ryosuke Niwa <rniwa at webkit.org>
Date: Thu, 29 Feb 2024 23:17:38 -0800
Subject: [PATCH 3/5] Merge function and statement caches.

---
 .../Checkers/WebKit/PtrTypesSemantics.cpp     | 44 ++++++++-----------
 .../Checkers/WebKit/PtrTypesSemantics.h       | 19 ++++----
 2 files changed, 26 insertions(+), 37 deletions(-)

diff --git a/clang/lib/StaticAnalyzer/Checkers/WebKit/PtrTypesSemantics.cpp b/clang/lib/StaticAnalyzer/Checkers/WebKit/PtrTypesSemantics.cpp
index 709387e343b975..3202f8eb8bf131 100644
--- a/clang/lib/StaticAnalyzer/Checkers/WebKit/PtrTypesSemantics.cpp
+++ b/clang/lib/StaticAnalyzer/Checkers/WebKit/PtrTypesSemantics.cpp
@@ -256,7 +256,7 @@ class TrivialFunctionAnalysisVisitor
   template <typename StmtType, typename CheckFunction>
   bool withCachedResult(const StmtType *S, CheckFunction Function) {
     // Insert false to the cache first to avoid infinite recursion.
-    auto [It, IsNew] = StatementCache.insert(std::make_pair(S, false));
+    auto [It, IsNew] = Cache.insert(std::make_pair(S, false));
     if (!IsNew)
       return It->second;
     bool Result = Function();
@@ -265,12 +265,9 @@ class TrivialFunctionAnalysisVisitor
   }
 
 public:
-  using FunctionCacheTy = TrivialFunctionAnalysis::FunctionCacheTy;
-  using StatementCacheTy = TrivialFunctionAnalysis::StatementCacheTy;
+  using CacheTy = TrivialFunctionAnalysis::CacheTy;
 
-  TrivialFunctionAnalysisVisitor(FunctionCacheTy &FunctionCache,
-                                 StatementCacheTy &StatementCache)
-      : FunctionCache(FunctionCache), StatementCache(StatementCache) {}
+  TrivialFunctionAnalysisVisitor(CacheTy &Cache) : Cache(Cache) {}
 
   bool VisitStmt(const Stmt *S) {
     // All statements are non-trivial unless overriden later.
@@ -314,6 +311,7 @@ class TrivialFunctionAnalysisVisitor
     auto op = UO->getOpcode();
     if (op == UO_Deref || op == UO_AddrOf || op == UO_LNot)
       return Visit(UO->getSubExpr());
+
     if (UO->isIncrementOp() || UO->isDecrementOp()) {
       // Allow increment or decrement of a POD type.
       if (auto *RefExpr = dyn_cast<DeclRefExpr>(UO->getSubExpr())) {
@@ -373,8 +371,7 @@ class TrivialFunctionAnalysisVisitor
         Name == "compilerFenceForCrash" || Name == "__builtin_unreachable")
       return true;
 
-    return TrivialFunctionAnalysis::isTrivialImpl(Callee, FunctionCache,
-                                                  StatementCache);
+    return TrivialFunctionAnalysis::isTrivialImpl(Callee, Cache);
   }
 
   bool VisitPredefinedExpr(const PredefinedExpr *E) {
@@ -398,9 +395,8 @@ class TrivialFunctionAnalysisVisitor
     if (IsGetterOfRefCounted && *IsGetterOfRefCounted)
       return true;
 
-    // Recursively descend into the callee to confirm it's trivial as well.
-    return TrivialFunctionAnalysis::isTrivialImpl(Callee, FunctionCache,
-                                                  StatementCache);
+    // Recursively descend into the callee to confirm that it's trivial as well.
+    return TrivialFunctionAnalysis::isTrivialImpl(Callee, Cache);
   }
 
   bool VisitCXXDefaultArgExpr(const CXXDefaultArgExpr *E) {
@@ -426,8 +422,7 @@ class TrivialFunctionAnalysisVisitor
     }
 
     // Recursively descend into the callee to confirm that it's trivial.
-    return TrivialFunctionAnalysis::isTrivialImpl(
-        CE->getConstructor(), FunctionCache, StatementCache);
+    return TrivialFunctionAnalysis::isTrivialImpl(CE->getConstructor(), Cache);
   }
 
   bool VisitImplicitCastExpr(const ImplicitCastExpr *ICE) {
@@ -484,18 +479,16 @@ class TrivialFunctionAnalysisVisitor
   }
 
 private:
-  FunctionCacheTy FunctionCache;
-  StatementCacheTy StatementCache;
+  CacheTy Cache;
 };
 
 bool TrivialFunctionAnalysis::isTrivialImpl(
-    const Decl *D, TrivialFunctionAnalysis::FunctionCacheTy &FunctionCache,
-    TrivialFunctionAnalysis::StatementCacheTy &StatementCache) {
+    const Decl *D, TrivialFunctionAnalysis::CacheTy &Cache) {
   // If the function isn't in the cache, conservatively assume that
   // it's not trivial until analysis completes. This makes every recursive
   // function non-trivial. This also guarantees that each function
   // will be scanned at most once.
-  auto [It, IsNew] = FunctionCache.insert(std::make_pair(D, false));
+  auto [It, IsNew] = Cache.insert(std::make_pair(D, false));
   if (!IsNew)
     return It->second;
 
@@ -503,29 +496,28 @@ bool TrivialFunctionAnalysis::isTrivialImpl(
   if (!Body)
     return false;
 
-  TrivialFunctionAnalysisVisitor V(FunctionCache, StatementCache);
+  TrivialFunctionAnalysisVisitor V(Cache);
   bool Result = V.Visit(Body);
   if (Result)
-    FunctionCache[D] = true;
+    Cache[D] = true;
 
   return Result;
 }
 
 bool TrivialFunctionAnalysis::isTrivialImpl(
-    const Stmt *S, TrivialFunctionAnalysis::FunctionCacheTy &FunctionCache,
-    TrivialFunctionAnalysis::StatementCacheTy &StatementCache) {
+    const Stmt *S, TrivialFunctionAnalysis::CacheTy &Cache) {
   // If the statement isn't in the cache, conservatively assume that
   // it's not trivial until analysis completes. Unlike a function case,
   // we don't insert an entry into the cache until Visit returns
   // since Visit* functions themselves make use of the cache.
 
-  auto It = StatementCache.find(S);
-  if (It != StatementCache.end())
+  auto It = Cache.find(S);
+  if (It != Cache.end())
     return It->second;
 
-  TrivialFunctionAnalysisVisitor V(FunctionCache, StatementCache);
+  TrivialFunctionAnalysisVisitor V(Cache);
   bool Result = V.Visit(S);
-  StatementCache[S] = Result;
+  Cache[S] = Result;
 
   return Result;
 }
diff --git a/clang/lib/StaticAnalyzer/Checkers/WebKit/PtrTypesSemantics.h b/clang/lib/StaticAnalyzer/Checkers/WebKit/PtrTypesSemantics.h
index 3f4cdd1f2ffb02..e503bff73f4b9f 100644
--- a/clang/lib/StaticAnalyzer/Checkers/WebKit/PtrTypesSemantics.h
+++ b/clang/lib/StaticAnalyzer/Checkers/WebKit/PtrTypesSemantics.h
@@ -11,6 +11,7 @@
 
 #include "llvm/ADT/APInt.h"
 #include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/PointerUnion.h"
 #include <optional>
 
 namespace clang {
@@ -72,26 +73,22 @@ class TrivialFunctionAnalysis {
 public:
   /// \returns true if \p D is a "trivial" function.
   bool isTrivial(const Decl *D) const {
-    return isTrivialImpl(D, TheFunctionCache, TheStatementCache);
+    return isTrivialImpl(D, TheCache);
   }
 
   bool isTrivial(const Stmt *S) const {
-    return isTrivialImpl(S, TheFunctionCache, TheStatementCache);
+    return isTrivialImpl(S, TheCache);
   }
 
 private:
   friend class TrivialFunctionAnalysisVisitor;
 
-  using FunctionCacheTy = llvm::DenseMap<const Decl *, bool>;
-  mutable FunctionCacheTy TheFunctionCache{};
+  using CacheTy = llvm::DenseMap<llvm::PointerUnion<const Decl *,
+      const Stmt *>, bool>;
+  mutable CacheTy TheCache{};
 
-  using StatementCacheTy = llvm::DenseMap<const Stmt *, bool>;
-  mutable StatementCacheTy TheStatementCache{};
-
-  static bool isTrivialImpl(const Decl *D, FunctionCacheTy &FunctionCache,
-                            StatementCacheTy &StatementCache);
-  static bool isTrivialImpl(const Stmt *S, FunctionCacheTy &FunctionCache,
-                            StatementCacheTy &StatementCache);
+  static bool isTrivialImpl(const Decl *D, CacheTy &Cache);
+  static bool isTrivialImpl(const Stmt *S, CacheTy &Cache);
 };
 
 } // namespace clang

>From 285f64eb9f1a55b8cd4f17f3504a3d57353d3a15 Mon Sep 17 00:00:00 2001
From: Ryosuke Niwa <rniwa at webkit.org>
Date: Thu, 29 Feb 2024 23:53:45 -0800
Subject: [PATCH 4/5] Intercept Traverse* functions to find trivial statements
 Instead of traversing the AST context up to find trivial statements.

---
 .../WebKit/UncountedLocalVarsChecker.cpp      | 56 ++++++++++++-------
 .../Checkers/WebKit/uncounted-local-vars.cpp  |  1 -
 2 files changed, 36 insertions(+), 21 deletions(-)

diff --git a/clang/lib/StaticAnalyzer/Checkers/WebKit/UncountedLocalVarsChecker.cpp b/clang/lib/StaticAnalyzer/Checkers/WebKit/UncountedLocalVarsChecker.cpp
index b9465a2668605c..5dc406c79c10fa 100644
--- a/clang/lib/StaticAnalyzer/Checkers/WebKit/UncountedLocalVarsChecker.cpp
+++ b/clang/lib/StaticAnalyzer/Checkers/WebKit/UncountedLocalVarsChecker.cpp
@@ -111,8 +111,6 @@ class UncountedLocalVarsChecker
               "WebKit coding guidelines"};
   mutable BugReporter *BR;
 
-  TrivialFunctionAnalysis TFA;
-
 public:
   void checkASTDecl(const TranslationUnitDecl *TUD, AnalysisManager &MGR,
                     BugReporter &BRArg) const {
@@ -123,6 +121,11 @@ class UncountedLocalVarsChecker
     // want to visit those, so we make our own RecursiveASTVisitor.
     struct LocalVisitor : public RecursiveASTVisitor<LocalVisitor> {
       const UncountedLocalVarsChecker *Checker;
+
+      TrivialFunctionAnalysis TFA;
+
+      using Base = RecursiveASTVisitor<LocalVisitor>;
+
       explicit LocalVisitor(const UncountedLocalVarsChecker *Checker)
           : Checker(Checker) {
         assert(Checker);
@@ -135,6 +138,37 @@ class UncountedLocalVarsChecker
         Checker->visitVarDecl(V);
         return true;
       }
+
+      bool TraverseIfStmt(IfStmt *IS) {
+        if (!TFA.isTrivial(IS))
+          return Base::TraverseIfStmt(IS);
+        return true;
+      }
+
+      bool TraverseForStmt(ForStmt *FS) {
+        if (!TFA.isTrivial(FS))
+          return Base::TraverseForStmt(FS);
+        return true;
+      }
+
+      bool TraverseCXXForRangeStmt(CXXForRangeStmt *FRS) {
+        if (!TFA.isTrivial(FRS))
+          return Base::TraverseCXXForRangeStmt(FRS);
+        return true;
+      }
+
+      bool TraverseWhileStmt(WhileStmt *WS) {
+        if (!TFA.isTrivial(WS))
+          return Base::TraverseWhileStmt(WS);
+        return true;
+      }
+
+      bool TraverseCompoundStmt(CompoundStmt *CS) {
+        if (!TFA.isTrivial(CS))
+          return Base::TraverseCompoundStmt(CS);
+        return true;
+      }
+
     };
 
     LocalVisitor visitor(this);
@@ -151,24 +185,6 @@ class UncountedLocalVarsChecker
 
     std::optional<bool> IsUncountedPtr = isUncountedPtr(ArgType);
     if (IsUncountedPtr && *IsUncountedPtr) {
-
-      ASTContext &ctx = V->getASTContext();
-      for (DynTypedNodeList ancestors = ctx.getParents(*V); !ancestors.empty();
-           ancestors = ctx.getParents(*ancestors.begin())) {
-        for (auto &ancestor : ancestors) {
-          if (auto *S = ancestor.get<IfStmt>(); S && TFA.isTrivial(S))
-            return;
-          if (auto *S = ancestor.get<ForStmt>(); S && TFA.isTrivial(S))
-            return;
-          if (auto *S = ancestor.get<CXXForRangeStmt>(); S && TFA.isTrivial(S))
-            return;
-          if (auto *S = ancestor.get<WhileStmt>(); S && TFA.isTrivial(S))
-            return;
-          if (auto *S = ancestor.get<CompoundStmt>(); S && TFA.isTrivial(S))
-            return;
-        }
-      }
-
       const Expr *const InitExpr = V->getInit();
       if (!InitExpr)
         return; // FIXME: later on we might warn on uninitialized vars too
diff --git a/clang/test/Analysis/Checkers/WebKit/uncounted-local-vars.cpp b/clang/test/Analysis/Checkers/WebKit/uncounted-local-vars.cpp
index 3fe04f775fbbcb..d58f0bf04921e0 100644
--- a/clang/test/Analysis/Checkers/WebKit/uncounted-local-vars.cpp
+++ b/clang/test/Analysis/Checkers/WebKit/uncounted-local-vars.cpp
@@ -150,7 +150,6 @@ void foo() {
 }
 
 void bar() {
-  // no warnings
   if (RefCountable *a = provide_ref_ctnbl()) {
     // expected-warning at -1{{Local variable 'a' is uncounted and unsafe [alpha.webkit.UncountedLocalVarsChecker]}}
     a->method();    

>From 3a3448fd3c06da920e8f972bf2dfa1a9284aa689 Mon Sep 17 00:00:00 2001
From: Ryosuke Niwa <rniwa at webkit.org>
Date: Fri, 1 Mar 2024 00:00:56 -0800
Subject: [PATCH 5/5] Fix formatting.

---
 .../Checkers/WebKit/PtrTypesSemantics.cpp           | 12 ++++++------
 .../Checkers/WebKit/PtrTypesSemantics.h             | 13 ++++---------
 .../Checkers/WebKit/UncountedLocalVarsChecker.cpp   |  1 -
 3 files changed, 10 insertions(+), 16 deletions(-)

diff --git a/clang/lib/StaticAnalyzer/Checkers/WebKit/PtrTypesSemantics.cpp b/clang/lib/StaticAnalyzer/Checkers/WebKit/PtrTypesSemantics.cpp
index 3202f8eb8bf131..bec10409007eb6 100644
--- a/clang/lib/StaticAnalyzer/Checkers/WebKit/PtrTypesSemantics.cpp
+++ b/clang/lib/StaticAnalyzer/Checkers/WebKit/PtrTypesSemantics.cpp
@@ -254,7 +254,7 @@ class TrivialFunctionAnalysisVisitor
   }
 
   template <typename StmtType, typename CheckFunction>
-  bool withCachedResult(const StmtType *S, CheckFunction Function) {
+  bool WithCachedResult(const StmtType *S, CheckFunction Function) {
     // Insert false to the cache first to avoid infinite recursion.
     auto [It, IsNew] = Cache.insert(std::make_pair(S, false));
     if (!IsNew)
@@ -278,7 +278,7 @@ class TrivialFunctionAnalysisVisitor
   bool VisitCompoundStmt(const CompoundStmt *CS) {
     // A compound statement is allowed as long each individual sub-statement
     // is trivial.
-    return withCachedResult(CS, [&]() { return VisitChildren(CS); });
+    return WithCachedResult(CS, [&]() { return VisitChildren(CS); });
   }
 
   bool VisitReturnStmt(const ReturnStmt *RS) {
@@ -291,16 +291,16 @@ class TrivialFunctionAnalysisVisitor
   bool VisitDeclStmt(const DeclStmt *DS) { return VisitChildren(DS); }
   bool VisitDoStmt(const DoStmt *DS) { return VisitChildren(DS); }
   bool VisitIfStmt(const IfStmt *IS) {
-    return withCachedResult(IS, [&]() { return VisitChildren(IS); });
+    return WithCachedResult(IS, [&]() { return VisitChildren(IS); });
   }
   bool VisitForStmt(const ForStmt *FS) {
-    return withCachedResult(FS, [&]() { return VisitChildren(FS); });
+    return WithCachedResult(FS, [&]() { return VisitChildren(FS); });
   }
   bool VisitCXXForRangeStmt(const CXXForRangeStmt *FS) {
-    return withCachedResult(FS, [&]() { return VisitChildren(FS); });
+    return WithCachedResult(FS, [&]() { return VisitChildren(FS); });
   }
   bool VisitWhileStmt(const WhileStmt *WS) {
-    return withCachedResult(WS, [&]() { return VisitChildren(WS); });
+    return WithCachedResult(WS, [&]() { return VisitChildren(WS); });
   }
   bool VisitSwitchStmt(const SwitchStmt *SS) { return VisitChildren(SS); }
   bool VisitCaseStmt(const CaseStmt *CS) { return VisitChildren(CS); }
diff --git a/clang/lib/StaticAnalyzer/Checkers/WebKit/PtrTypesSemantics.h b/clang/lib/StaticAnalyzer/Checkers/WebKit/PtrTypesSemantics.h
index e503bff73f4b9f..9ed8e7cab6abb9 100644
--- a/clang/lib/StaticAnalyzer/Checkers/WebKit/PtrTypesSemantics.h
+++ b/clang/lib/StaticAnalyzer/Checkers/WebKit/PtrTypesSemantics.h
@@ -72,19 +72,14 @@ bool isSingleton(const FunctionDecl *F);
 class TrivialFunctionAnalysis {
 public:
   /// \returns true if \p D is a "trivial" function.
-  bool isTrivial(const Decl *D) const {
-    return isTrivialImpl(D, TheCache);
-  }
-
-  bool isTrivial(const Stmt *S) const {
-    return isTrivialImpl(S, TheCache);
-  }
+  bool isTrivial(const Decl *D) const { return isTrivialImpl(D, TheCache); }
+  bool isTrivial(const Stmt *S) const { return isTrivialImpl(S, TheCache); }
 
 private:
   friend class TrivialFunctionAnalysisVisitor;
 
-  using CacheTy = llvm::DenseMap<llvm::PointerUnion<const Decl *,
-      const Stmt *>, bool>;
+  using CacheTy =
+      llvm::DenseMap<llvm::PointerUnion<const Decl *, const Stmt *>, bool>;
   mutable CacheTy TheCache{};
 
   static bool isTrivialImpl(const Decl *D, CacheTy &Cache);
diff --git a/clang/lib/StaticAnalyzer/Checkers/WebKit/UncountedLocalVarsChecker.cpp b/clang/lib/StaticAnalyzer/Checkers/WebKit/UncountedLocalVarsChecker.cpp
index 5dc406c79c10fa..6036ad58cf253c 100644
--- a/clang/lib/StaticAnalyzer/Checkers/WebKit/UncountedLocalVarsChecker.cpp
+++ b/clang/lib/StaticAnalyzer/Checkers/WebKit/UncountedLocalVarsChecker.cpp
@@ -168,7 +168,6 @@ class UncountedLocalVarsChecker
           return Base::TraverseCompoundStmt(CS);
         return true;
       }
-
     };
 
     LocalVisitor visitor(this);



More information about the cfe-commits mailing list