[clang] [analyzer] Add std::variant checker (PR #66481)

Gábor Spaits via cfe-commits cfe-commits at lists.llvm.org
Sun Oct 22 10:10:22 PDT 2023


https://github.com/spaits updated https://github.com/llvm/llvm-project/pull/66481

>From 1948d226de16bda2899ca562276370d20ceba236 Mon Sep 17 00:00:00 2001
From: Gabor Spaits <gaborspaits1 at gmail.com>
Date: Fri, 15 Sep 2023 10:21:30 +0200
Subject: [PATCH 1/5] [analyzer] Add std::variant checker

Adding a checker that checks for bad std::variant type access.
---
 .../clang/StaticAnalyzer/Checkers/Checkers.td |   4 +
 .../StaticAnalyzer/Checkers/CMakeLists.txt    |   1 +
 .../Checkers/StdVariantChecker.cpp            | 327 ++++++++++++++++
 .../Checkers/TaggedUnionModeling.h            | 104 +++++
 .../Inputs/system-header-simulator-cxx.h      | 122 ++++++
 .../diagnostics/explicit-suppression.cpp      |   2 +-
 clang/test/Analysis/std-variant-checker.cpp   | 358 ++++++++++++++++++
 7 files changed, 917 insertions(+), 1 deletion(-)
 create mode 100644 clang/lib/StaticAnalyzer/Checkers/StdVariantChecker.cpp
 create mode 100644 clang/lib/StaticAnalyzer/Checkers/TaggedUnionModeling.h
 create mode 100644 clang/test/Analysis/std-variant-checker.cpp

diff --git a/clang/include/clang/StaticAnalyzer/Checkers/Checkers.td b/clang/include/clang/StaticAnalyzer/Checkers/Checkers.td
index be813bde8be41ea..a93e97348606f28 100644
--- a/clang/include/clang/StaticAnalyzer/Checkers/Checkers.td
+++ b/clang/include/clang/StaticAnalyzer/Checkers/Checkers.td
@@ -318,6 +318,10 @@ def C11LockChecker : Checker<"C11Lock">,
   Dependencies<[PthreadLockBase]>,
   Documentation<HasDocumentation>;
 
+def StdVariantChecker : Checker<"StdVariant">,
+  HelpText<"Check for bad type access for std::variant.">,
+  Documentation<NotDocumented>;
+
 } // end "alpha.core"
 
 //===----------------------------------------------------------------------===//
diff --git a/clang/lib/StaticAnalyzer/Checkers/CMakeLists.txt b/clang/lib/StaticAnalyzer/Checkers/CMakeLists.txt
index d849649c96a0d13..4443ffd09293881 100644
--- a/clang/lib/StaticAnalyzer/Checkers/CMakeLists.txt
+++ b/clang/lib/StaticAnalyzer/Checkers/CMakeLists.txt
@@ -108,6 +108,7 @@ add_clang_library(clangStaticAnalyzerCheckers
   SmartPtrModeling.cpp
   StackAddrEscapeChecker.cpp
   StdLibraryFunctionsChecker.cpp
+  StdVariantChecker.cpp
   STLAlgorithmModeling.cpp
   StreamChecker.cpp
   StringChecker.cpp
diff --git a/clang/lib/StaticAnalyzer/Checkers/StdVariantChecker.cpp b/clang/lib/StaticAnalyzer/Checkers/StdVariantChecker.cpp
new file mode 100644
index 000000000000000..680c5567431bbfb
--- /dev/null
+++ b/clang/lib/StaticAnalyzer/Checkers/StdVariantChecker.cpp
@@ -0,0 +1,327 @@
+//===- StdVariantChecker.cpp -------------------------------------*- C++ -*-==//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "clang/AST/Type.h"
+#include "clang/StaticAnalyzer/Checkers/BuiltinCheckerRegistration.h"
+#include "clang/StaticAnalyzer/Core/BugReporter/BugType.h"
+#include "clang/StaticAnalyzer/Core/Checker.h"
+#include "clang/StaticAnalyzer/Core/CheckerManager.h"
+#include "clang/StaticAnalyzer/Core/PathSensitive/CallDescription.h"
+#include "clang/StaticAnalyzer/Core/PathSensitive/CallEvent.h"
+#include "clang/StaticAnalyzer/Core/PathSensitive/CheckerContext.h"
+#include "clang/StaticAnalyzer/Core/PathSensitive/SVals.h"
+#include "llvm/ADT/FoldingSet.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/Casting.h"
+#include <optional>
+#include <string_view>
+
+#include "TaggedUnionModeling.h"
+
+using namespace clang;
+using namespace ento;
+using namespace tagged_union_modeling;
+
+REGISTER_MAP_WITH_PROGRAMSTATE(VariantHeldTypeMap, const MemRegion *, QualType)
+
+namespace clang {
+namespace ento {
+namespace tagged_union_modeling {
+
+// Returns the CallEvent representing the caller of the function, or null when
+// there is no caller (e.g. the call happens in the top frame). The program
+// state is needed because the CallEvent alone cannot tell who called it.
+CallEventRef<> getCaller(const CallEvent &Call, const ProgramStateRef &State) {
+  const auto *CallLocationContext = Call.getLocationContext();
+  if (!CallLocationContext || CallLocationContext->inTopFrame())
+    return nullptr;
+
+  const auto *CallStackFrameContext = CallLocationContext->getStackFrame();
+  if (!CallStackFrameContext)
+    return nullptr;
+
+  CallEventManager &CEMgr = State->getStateManager().getCallEventManager();
+  return CEMgr.getCaller(CallStackFrameContext, State);
+}
+
+const CXXConstructorDecl *
+getConstructorDeclarationForCall(const CallEvent &Call) {
+  const auto *ConstructorCall = dyn_cast<CXXConstructorCall>(&Call);
+  if (!ConstructorCall)
+    return nullptr;
+
+  return ConstructorCall->getDecl();
+}
+
+bool isCopyConstructorCall(const CallEvent &Call) {
+  if (const CXXConstructorDecl *ConstructorDecl =
+          getConstructorDeclarationForCall(Call))
+    return ConstructorDecl->isCopyConstructor();
+  return false;
+}
+
+bool isCopyAssignmentCall(const CallEvent &Call) {
+  const Decl *CopyAssignmentDecl = Call.getDecl();
+
+  if (const auto *AsMethodDecl =
+          dyn_cast_or_null<CXXMethodDecl>(CopyAssignmentDecl))
+    return AsMethodDecl->isCopyAssignmentOperator();
+  return false;
+}
+
+bool isMoveConstructorCall(const CallEvent &Call) {
+  const CXXConstructorDecl *ConstructorDecl =
+      getConstructorDeclarationForCall(Call);
+  if (!ConstructorDecl)
+    return false;
+
+  return ConstructorDecl->isMoveConstructor();
+}
+
+bool isMoveAssignmentCall(const CallEvent &Call) {
+  const Decl *CopyAssignmentDecl = Call.getDecl();
+
+  const auto *AsMethodDecl =
+      dyn_cast_or_null<CXXMethodDecl>(CopyAssignmentDecl);
+  if (!AsMethodDecl)
+    return false;
+
+  return AsMethodDecl->isMoveAssignmentOperator();
+}
+
+bool isStdType(const Type *Type, llvm::StringRef TypeName) {
+  auto *Decl = Type->getAsRecordDecl();
+  if (!Decl)
+    return false;
+  return (Decl->getName() == TypeName) && Decl->isInStdNamespace();
+}
+
+bool isStdVariant(const Type *Type) {
+  return isStdType(Type, llvm::StringLiteral("variant"));
+}
+
+bool calledFromSystemHeader(const CallEvent &Call,
+                            const ProgramStateRef &State) {
+  if (CallEventRef<> Caller = getCaller(Call, State))
+    return Caller->isInSystemHeader();
+
+  return false;
+}
+
+bool calledFromSystemHeader(const CallEvent &Call, CheckerContext &C) {
+  return calledFromSystemHeader(Call, C.getState());
+}
+
+} // end of namespace tagged_union_modeling
+} // end of namespace ento
+} // end of namespace clang
+
+static std::optional<ArrayRef<TemplateArgument>>
+getTemplateArgsFromVariant(const Type *VariantType) {
+  const auto *TempSpecType = VariantType->getAs<TemplateSpecializationType>();
+  if (!TempSpecType)
+    return {};
+
+  return TempSpecType->template_arguments();
+}
+
+// Returns the \p i-th template argument of the variant type \p varType as a
+// type, or std::nullopt if the arguments are unavailable or \p i is too big.
+static std::optional<QualType>
+getNthTemplateTypeArgFromVariant(const Type *varType, unsigned i) {
+  auto VariantTemplates = getTemplateArgsFromVariant(varType);
+  if (!VariantTemplates || i >= VariantTemplates->size())
+    return {};
+  return (*VariantTemplates)[i].getAsType();
+}
+
+// Tells whether \p a is an English vowel letter, in either case. Type names
+// often start with a capital (e.g. 'Item'), so upper-case vowels must be
+// recognized as well, not only the lower-case ones.
+static bool isVowel(char a) {
+  // '\0' (the first "character" of an empty name) is never found here.
+  static constexpr llvm::StringLiteral Vowels("aeiouAEIOU");
+  return Vowels.contains(a);
+}
+
+// Picks the indefinite article ("a" or "an") that should precede a word
+// starting with the character \p a. The choice is purely letter-based, so
+// words whose pronunciation disagrees with their spelling (e.g. 'union')
+// may still get the less natural article.
+static llvm::StringRef indefiniteArticleBasedOnVowel(char a) {
+  if (isVowel(a))
+    return "an";
+  return "a";
+}
+
+class StdVariantChecker : public Checker<eval::Call, check::RegionChanges> {
+  // Call descriptors to find relevant calls
+  CallDescription VariantConstructor{{"std", "variant", "variant"}};
+  CallDescription VariantAssignmentOperator{{"std", "variant", "operator="}};
+  CallDescription StdGet{{"std", "get"}, 1, 1};
+
+  BugType BadVariantType{this, "BadVariantType", "BadVariantType"};
+
+public:
+  ProgramStateRef checkRegionChanges(ProgramStateRef State,
+                                     const InvalidatedSymbols *,
+                                     ArrayRef<const MemRegion *>,
+                                     ArrayRef<const MemRegion *> Regions,
+                                     const LocationContext *,
+                                     const CallEvent *Call) const {
+    return removeInformationStoredForDeadInstances<VariantHeldTypeMap>(
+        Call, State, Regions);
+  }
+
+  bool evalCall(const CallEvent &Call, CheckerContext &C) const {
+    // Check if the call was not made from a system header. If it was then
+    // we do an early return because it is part of the implementation.
+    if (calledFromSystemHeader(Call, C))
+      return false;
+
+    if (StdGet.matches(Call))
+      return handleStdGetCall(Call, C);
+
+    // First check if a constructor call is happening. If it is a
+    // constructor call, check if it is an std::variant constructor call.
+    bool IsVariantConstructor =
+        isa<CXXConstructorCall>(Call) && VariantConstructor.matches(Call);
+    bool IsVariantAssignmentOperatorCall =
+        isa<CXXMemberOperatorCall>(Call) &&
+        VariantAssignmentOperator.matches(Call);
+
+    if (IsVariantConstructor || IsVariantAssignmentOperatorCall) {
+      if (Call.getNumArgs() == 0 && IsVariantConstructor) {
+        handleDefaultConstructor(cast<CXXConstructorCall>(&Call), C);
+        return true;
+      }
+
+      // FIXME Later this checker should be extended to handle constructors
+      // with multiple arguments.
+      if (Call.getNumArgs() != 1)
+        return false;
+
+      SVal ThisSVal;
+      if (IsVariantConstructor) {
+        const auto &AsConstructorCall = cast<CXXConstructorCall>(Call);
+        ThisSVal = AsConstructorCall.getCXXThisVal();
+      } else if (IsVariantAssignmentOperatorCall) {
+        const auto &AsMemberOpCall = cast<CXXMemberOperatorCall>(Call);
+        ThisSVal = AsMemberOpCall.getCXXThisVal();
+      } else {
+        return false;
+      }
+
+      handleConstructorAndAssignment<VariantHeldTypeMap>(Call, C, ThisSVal);
+      return true;
+    }
+    return false;
+  }
+
+private:
+  // The default-constructed std::variant must be handled separately:
+  // by default the std::variant is going to hold a default-constructed
+  // instance of the first of its alternative types.
+  void handleDefaultConstructor(const CXXConstructorCall *ConstructorCall,
+                                CheckerContext &C) const {
+    SVal ThisSVal = ConstructorCall->getCXXThisVal();
+
+    const auto *const ThisMemRegion = ThisSVal.getAsRegion();
+    if (!ThisMemRegion)
+      return;
+
+    std::optional<QualType> DefaultType = getNthTemplateTypeArgFromVariant(
+        ThisSVal.getType(C.getASTContext())->getPointeeType().getTypePtr(), 0);
+    if (!DefaultType)
+      return;
+
+    ProgramStateRef State = ConstructorCall->getState();
+    State = State->set<VariantHeldTypeMap>(ThisMemRegion, *DefaultType);
+    C.addTransition(State);
+  }
+
+  bool handleStdGetCall(const CallEvent &Call, CheckerContext &C) const {
+    ProgramStateRef State = Call.getState();
+
+    const auto &ArgType = Call.getArgSVal(0)
+                              .getType(C.getASTContext())
+                              ->getPointeeType()
+                              .getTypePtr();
+    // We have to make sure that the argument is an std::variant.
+    // There is another std::get with std::pair argument
+    if (!isStdVariant(ArgType))
+      return false;
+
+    // Get the mem region of the argument std::variant and look up the type
+    // information that we know about it.
+    const MemRegion *ArgMemRegion = Call.getArgSVal(0).getAsRegion();
+    const QualType *StoredType = State->get<VariantHeldTypeMap>(ArgMemRegion);
+    if (!StoredType)
+      return false;
+
+    const CallExpr *CE = cast<CallExpr>(Call.getOriginExpr());
+    const FunctionDecl *FD = CE->getDirectCallee();
+    if (FD->getTemplateSpecializationArgs()->size() < 1)
+      return false;
+
+    const auto &TypeOut = FD->getTemplateSpecializationArgs()->asArray()[0];
+    // std::get's first template parameter can be the type we want to get
+    // out of the std::variant or a natural number which is the position of
+    // the requested type in the argument type list of the std::variant's
+    // argument.
+    QualType RetrievedType;
+    switch (TypeOut.getKind()) {
+    case TemplateArgument::ArgKind::Type:
+      RetrievedType = TypeOut.getAsType();
+      break;
+    case TemplateArgument::ArgKind::Integral:
+      // In the natural number case we look up which type corresponds to the
+      // number.
+      if (std::optional<QualType> NthTemplate =
+              getNthTemplateTypeArgFromVariant(
+                  ArgType, TypeOut.getAsIntegral().getSExtValue())) {
+        RetrievedType = *NthTemplate;
+        break;
+      }
+      [[fallthrough]];
+    default:
+      return false;
+    }
+
+    QualType RetrievedCanonicalType = RetrievedType.getCanonicalType();
+    QualType StoredCanonicalType = StoredType->getCanonicalType();
+    if (RetrievedCanonicalType == StoredCanonicalType)
+      return true;
+
+    ExplodedNode *ErrNode = C.generateNonFatalErrorNode();
+    if (!ErrNode)
+      return false;
+    llvm::SmallString<128> Str;
+    llvm::raw_svector_ostream OS(Str);
+    std::string StoredTypeName = StoredType->getAsString();
+    std::string RetrievedTypeName = RetrievedType.getAsString();
+    OS << "std::variant " << ArgMemRegion->getDescriptiveName() << " held "
+       << indefiniteArticleBasedOnVowel(StoredTypeName[0]) << " \'"
+       << StoredTypeName << "\', not "
+       << indefiniteArticleBasedOnVowel(RetrievedTypeName[0]) << " \'"
+       << RetrievedTypeName << "\'";
+    auto R = std::make_unique<PathSensitiveBugReport>(BadVariantType, OS.str(),
+                                                      ErrNode);
+    C.emitReport(std::move(R));
+    return true;
+  }
+};
+
+bool clang::ento::shouldRegisterStdVariantChecker(
+    clang::ento::CheckerManager const &mgr) {
+  return true;
+}
+
+void clang::ento::registerStdVariantChecker(clang::ento::CheckerManager &mgr) {
+  mgr.registerChecker<StdVariantChecker>();
+}
\ No newline at end of file
diff --git a/clang/lib/StaticAnalyzer/Checkers/TaggedUnionModeling.h b/clang/lib/StaticAnalyzer/Checkers/TaggedUnionModeling.h
new file mode 100644
index 000000000000000..593f243e84ca686
--- /dev/null
+++ b/clang/lib/StaticAnalyzer/Checkers/TaggedUnionModeling.h
@@ -0,0 +1,104 @@
+//===- TaggedUnionModeling.h -------------------------------------*- C++ -*-==//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_CLANG_LIB_STATICANALYZER_CHECKER_VARIANTLIKETYPEMODELING_H
+#define LLVM_CLANG_LIB_STATICANALYZER_CHECKER_VARIANTLIKETYPEMODELING_H
+
+#include "clang/StaticAnalyzer/Checkers/BuiltinCheckerRegistration.h"
+#include "clang/StaticAnalyzer/Core/BugReporter/BugType.h"
+#include "clang/StaticAnalyzer/Core/Checker.h"
+#include "clang/StaticAnalyzer/Core/CheckerManager.h"
+#include "clang/StaticAnalyzer/Core/PathSensitive/CallDescription.h"
+#include "clang/StaticAnalyzer/Core/PathSensitive/CallEvent.h"
+#include "clang/StaticAnalyzer/Core/PathSensitive/CheckerContext.h"
+#include "llvm/ADT/FoldingSet.h"
+#include <numeric>
+
+namespace clang {
+namespace ento {
+namespace tagged_union_modeling {
+
+// The implementation of all these functions can be found in the file
+// StdVariantChecker.cpp; the signatures here must match those definitions.
+CallEventRef<> getCaller(const CallEvent &Call, const ProgramStateRef &State);
+bool isCopyConstructorCall(const CallEvent &Call);
+bool isCopyAssignmentCall(const CallEvent &Call);
+bool isMoveAssignmentCall(const CallEvent &Call);
+bool isMoveConstructorCall(const CallEvent &Call);
+bool isStdType(const Type *Type, llvm::StringRef TypeName);
+bool isStdVariant(const Type *Type);
+bool calledFromSystemHeader(const CallEvent &Call, CheckerContext &C);
+
+// When invalidating regions, we also have to follow that by invalidating the
+// corresponding custom data in the program state.
+template <class TypeMap>
+ProgramStateRef
+removeInformationStoredForDeadInstances(const CallEvent *Call,
+                                        ProgramStateRef State,
+                                        ArrayRef<const MemRegion *> Regions) {
+  // If we do not know anything about the call we shall not continue.
+  // A call that happens within a system header is an implementation detail,
+  // so we should not take it into consideration.
+  if (!Call || Call->isInSystemHeader())
+    return State;
+
+  for (const MemRegion *Region : Regions)
+    State = State->remove<TypeMap>(Region);
+
+  return State;
+}
+
+// Records in \p TypeMap which type the object denoted by \p ThisSVal holds
+// after the single-argument constructor or assignment \p Call.
+template <class TypeMap>
+void handleConstructorAndAssignment(const CallEvent &Call, CheckerContext &C,
+                                    const SVal &ThisSVal) {
+  ProgramStateRef State = Call.getState();
+  const MemRegion *ThisRegion = ThisSVal.getAsRegion();
+  // Without a state or a target region nothing can be recorded.
+  if (!State || !ThisRegion)
+    return;
+
+  SVal ArgSVal = Call.getArgSVal(0);
+  const MemRegion *ArgMemRegion = ArgSVal.getAsRegion();
+
+  bool IsCopy = isCopyConstructorCall(Call) || isCopyAssignmentCall(Call);
+  bool IsMove = isMoveConstructorCall(Call) || isMoveAssignmentCall(Call);
+  if (IsCopy || IsMove) {
+    const QualType *OtherQType =
+        ArgMemRegion ? State->get<TypeMap>(ArgMemRegion) : nullptr;
+    // If the argument of a copy constructor or assignment is unknown then
+    // the type held by the copied-to object is unknown as well.
+    if (!OtherQType) {
+      State = State->remove<TypeMap>(ThisRegion);
+    } else {
+      // Copy the type out before the map is modified below.
+      const QualType HeldType = *OtherQType;
+      // A moved-from object is only known to be destructible; forget it.
+      if (IsMove)
+        State = State->remove<TypeMap>(ArgMemRegion);
+      State = State->set<TypeMap>(ThisRegion, HeldType);
+    }
+  } else {
+    // Value constructor: the held type is the pointee of the argument's SVal
+    // type; if that is unknown, drop the entry instead of storing a null type.
+    QualType ArgQType = ArgSVal.getType(C.getASTContext());
+    QualType WoPointer =
+        ArgQType.isNull() ? QualType() : ArgQType->getPointeeType();
+    State = WoPointer.isNull() ? State->remove<TypeMap>(ThisRegion)
+                               : State->set<TypeMap>(ThisRegion, WoPointer);
+  }
+
+  C.addTransition(State);
+}
+
+} // namespace tagged_union_modeling
+} // namespace ento
+} // namespace clang
+
+#endif // LLVM_CLANG_LIB_STATICANALYZER_CHECKER_VARIANTLIKETYPEMODELING_H
\ No newline at end of file
diff --git a/clang/test/Analysis/Inputs/system-header-simulator-cxx.h b/clang/test/Analysis/Inputs/system-header-simulator-cxx.h
index 8633a8beadbff33..3ef7af2ea6c6ab4 100644
--- a/clang/test/Analysis/Inputs/system-header-simulator-cxx.h
+++ b/clang/test/Analysis/Inputs/system-header-simulator-cxx.h
@@ -249,6 +249,11 @@ namespace std {
     pair(const pair<U1, U2> &other) : first(other.first),
                                       second(other.second) {}
   };
+
+  template<class T2, class T1>
+  T2& get(pair<T1, T2>& p) ;
+  template<class T1, class T2>
+  T1& get(const pair<T1, T2>& p) ;
   
   typedef __typeof__(sizeof(int)) size_t;
 
@@ -264,6 +269,9 @@ namespace std {
     return static_cast<RvalRef>(a);
   }
 
+  template< class T >
+  using remove_reference_t = typename remove_reference<T>::type;
+
   template <class T>
   void swap(T &a, T &b) {
     T c(std::move(a));
@@ -718,6 +726,11 @@ namespace std {
   template <class _Tp, class _Up> struct  is_same           : public false_type {};
   template <class _Tp>            struct  is_same<_Tp, _Tp> : public true_type {};
 
+  #if __cplusplus >= 201703L
+  template< class T, class U >
+  inline constexpr bool is_same_v = is_same<T, U>::value;
+  #endif
+
   template <class _Tp, bool = is_const<_Tp>::value || is_reference<_Tp>::value    >
   struct __add_const             {typedef _Tp type;};
 
@@ -729,6 +742,9 @@ namespace std {
   template <class _Tp> struct  remove_const            {typedef _Tp type;};
   template <class _Tp> struct  remove_const<const _Tp> {typedef _Tp type;};
 
+  template< class T >
+  using remove_const_t = typename remove_const<T>::type;
+
   template <class _Tp> struct  add_lvalue_reference    {typedef _Tp& type;};
 
   template <class _Tp> struct is_trivially_copy_assignable
@@ -793,6 +809,9 @@ namespace std {
     return __result;
   }
 
+  template< bool B, class T = void >
+  using enable_if_t = typename enable_if<B,T>::type;
+
   template<class InputIter, class OutputIter>
   OutputIter copy_backward(InputIter II, InputIter IE, OutputIter OI) {
     return __copy_backward(II, IE, OI);
@@ -1252,4 +1271,107 @@ template <typename Ret, typename... Args> class packaged_task<Ret(Args...)> {
   // TODO: Add some actual implementation.
 };
 
+  #if __cplusplus >= 201703L
+
+  namespace detail
+  {
+    template<class T>
+    struct type_identity { using type = T; }; // or use std::type_identity (since C++20)
+ 
+    template<class T>
+    auto try_add_pointer(int) -> type_identity<typename std::remove_reference<T>::type*>;
+    template<class T>
+    auto try_add_pointer(...) -> type_identity<T>;
+  } // namespace detail
+ 
+  template<class T>
+  struct add_pointer : decltype(detail::try_add_pointer<T>(0)) {};
+
+  template< class T >
+  using add_pointer_t = typename add_pointer<T>::type;
+
+  template<class T> struct remove_cv { typedef T type; };
+  template<class T> struct remove_cv<const T> { typedef T type; };
+  template<class T> struct remove_cv<volatile T> { typedef T type; };
+  template<class T> struct remove_cv<const volatile T> { typedef T type; };
+
+  template< class T >
+  using remove_cv_t = typename remove_cv<T>::type;
+
+  // This decay does not behave exactly like std::decay, but this is enough
+  // for testing the std::variant checker
+  template<class T>
+  struct decay{typedef remove_cv_t<remove_reference_t<T>> type;};
+  template<class T>
+  using decay_t = typename decay<T>::type;
+  
+  // variant
+  template <class... Types> class variant;
+  // variant helper classes
+  template <class T> struct variant_size;
+  template <class T> struct variant_size<const T>;
+  template <class T> struct variant_size<volatile T>;
+  template <class T> struct variant_size<const volatile T>;
+  template <class T> inline constexpr size_t variant_size_v = variant_size<T>::value;
+  template <class... Types>
+  struct variant_size<variant<Types...>>;
+  template <size_t I, class T> struct variant_alternative;
+  template <size_t I, class T> struct variant_alternative<I, const T>;
+  template <size_t I, class T> struct variant_alternative<I, volatile T>;
+  template <size_t I, class T> struct variant_alternative<I, const volatile T>;
+  template <size_t I, class T>
+  using variant_alternative_t = typename variant_alternative<I, T>::type;
+  template <size_t I, class... Types>
+  struct variant_alternative<I, variant<Types...>>;
+  inline constexpr size_t variant_npos = -1;
+  template <size_t I, class... Types>
+  constexpr variant_alternative_t<I, variant<Types...>>&
+    get(variant<Types...>&);
+  template <size_t I, class... Types>
+  constexpr variant_alternative_t<I, variant<Types...>>&&
+    get(variant<Types...>&&);
+  template <size_t I, class... Types>
+  constexpr const variant_alternative_t<I, variant<Types...>>&
+    get(const variant<Types...>&);
+  template <size_t I, class... Types>
+  constexpr const variant_alternative_t<I, variant<Types...>>&&
+    get(const variant<Types...>&&);
+  template <class T, class... Types>
+  constexpr T& get(variant<Types...>&);
+  template <class T, class... Types>
+  constexpr T&& get(variant<Types...>&&);
+  template <class T, class... Types>
+  constexpr const T& get(const variant<Types...>&);
+  template <class T, class... Types>
+  constexpr const T&& get(const variant<Types...>&&);
+  template <size_t I, class... Types>
+  constexpr add_pointer_t<variant_alternative_t<I, variant<Types...>>>
+    get_if(variant<Types...>*) noexcept;
+  template <size_t I, class... Types>
+  constexpr add_pointer_t<const variant_alternative_t<I, variant<Types...>>>
+    get_if(const variant<Types...>*) noexcept;
+  template <class T, class... Types>
+  constexpr add_pointer_t<T> get_if(variant<Types...>*) noexcept;
+  template <class T, class... Types>
+  constexpr add_pointer_t<const T> get_if(const variant<Types...>*) noexcept;
+
+  template <class... Types>
+  class variant {
+  public:
+    // constructors
+    constexpr variant()= default ;
+    constexpr variant(const variant&);
+    constexpr variant(variant&&);
+    template<typename T,
+            typename = std::enable_if_t<!is_same_v<std::variant<Types...>, decay_t<T>>>>
+	  constexpr variant(T&&);
+    // assignment
+    variant& operator=(const variant&);
+    variant& operator=(variant&&) ;
+    template<typename T,
+            typename = std::enable_if_t<!is_same_v<std::variant<Types...>, decay_t<T>>>>
+    variant& operator=(T&&);
+  };
+  #endif
+
 } // namespace std
diff --git a/clang/test/Analysis/diagnostics/explicit-suppression.cpp b/clang/test/Analysis/diagnostics/explicit-suppression.cpp
index b98d0260b096594..24586e37fe207a5 100644
--- a/clang/test/Analysis/diagnostics/explicit-suppression.cpp
+++ b/clang/test/Analysis/diagnostics/explicit-suppression.cpp
@@ -19,6 +19,6 @@ class C {
 void testCopyNull(C *I, C *E) {
   std::copy(I, E, (C *)0);
 #ifndef SUPPRESSED
-  // expected-warning at ../Inputs/system-header-simulator-cxx.h:741 {{Called C++ object pointer is null}}
+  // expected-warning at ../Inputs/system-header-simulator-cxx.h:757 {{Called C++ object pointer is null}}
 #endif
 }
diff --git a/clang/test/Analysis/std-variant-checker.cpp b/clang/test/Analysis/std-variant-checker.cpp
new file mode 100644
index 000000000000000..7f136c06b19cc60
--- /dev/null
+++ b/clang/test/Analysis/std-variant-checker.cpp
@@ -0,0 +1,358 @@
+// RUN: %clang %s -std=c++17 -Xclang -verify --analyze \
+// RUN:   -Xclang -analyzer-checker=core \
+// RUN:   -Xclang -analyzer-checker=debug.ExprInspection \
+// RUN:   -Xclang -analyzer-checker=core,alpha.core.StdVariant
+
+#include "Inputs/system-header-simulator-cxx.h"
+
+class Foo{};
+
+void clang_analyzer_warnIfReached();
+void clang_analyzer_eval(int);
+
+//helper functions
+void changeVariantType(std::variant<int, char> &v) {
+  v = 25;
+}
+
+void changesToInt(std::variant<int, char> &v);
+void changesToInt(std::variant<int, char> *v);
+
+void cannotChangePtr(const std::variant<int, char> &v);
+void cannotChangePtr(const std::variant<int, char> *v);
+
+char getUnknownChar();
+
+void swap(std::variant<int, char> &v1, std::variant<int, char> &v2) {
+  std::variant<int, char> tmp = v1;
+  v1 = v2;
+  v2 = tmp;
+}
+
+void cantDo(const std::variant<int, char>& v) {
+  std::variant<int, char> vtmp = v;
+  vtmp = 5;
+  int a = std::get<int> (vtmp);
+  (void) a;
+}
+
+void changeVariantPtr(std::variant<int, char> *v) {
+  *v = 'c';
+}
+
+using var_t = std::variant<int, char>;
+using var_tt = var_t;
+using int_t = int;
+using char_t = char;
+
+// A quick sanity check to see that std::variant's std::get
+// is not being confused with std::pair's std::get.
+void wontConfuseStdGets() {
+  std::pair<int, char> p{15, '1'};
+  int a = std::get<int>(p);
+  char c = std::get<char>(p);
+  (void)a;
+  (void)c;
+}
+
+//----------------------------------------------------------------------------//
+// std::get
+//----------------------------------------------------------------------------//
+void stdGetType() {
+  std::variant<int, char> v = 25;
+  int a = std::get<int>(v);
+  char c = std::get<char>(v); // expected-warning {{std::variant 'v' held an 'int', not a 'char'}}
+  (void)a;
+  (void)c;
+}
+
+void stdGetPointer() {
+  int *p = new int;
+  std::variant<int*, char> v = p;
+  int *a = std::get<int*>(v);
+  char c = std::get<char>(v); // expected-warning {{std::variant 'v' held an 'int *', not a 'char'}}
+  (void)a;
+  (void)c;
+  delete p;
+}
+
+void stdGetObject() {
+  std::variant<int, char, Foo> v = Foo{};
+  Foo f = std::get<Foo>(v);
+  int i = std::get<int>(v); // expected-warning {{std::variant 'v' held a 'Foo', not an 'int'}}
+  (void)i;
+}
+
+void stdGetPointerAndPointee() {
+  int a = 5;
+  std::variant<int, int*> v = &a;
+  int *b = std::get<int*>(v);
+  int c = std::get<int>(v); // expected-warning {{std::variant 'v' held an 'int *', not an 'int'}}
+  (void)c;
+  (void)b;
+}
+
+void variantHoldingVariant() {
+  std::variant<std::variant<int, char>, std::variant<char, int>> v = std::variant<int,char>(25);
+  std::variant<int, char> v1 = std::get<std::variant<int,char>>(v);
+  std::variant<char, int> v2 = std::get<std::variant<char,int>>(v); // expected-warning {{std::variant 'v' held a 'std::variant<int, char>', not a 'class std::variant<char, int>'}}
+}
+
+//----------------------------------------------------------------------------//
+// Constructors and assignments
+//----------------------------------------------------------------------------//
+void copyConstructor() {
+  std::variant<int, char> v = 25;
+  std::variant<int, char> t(v);
+  int a = std::get<int> (t);
+  char c = std::get<char> (t); // expected-warning {{std::variant 't' held an 'int', not a 'char'}}
+  (void)a;
+  (void)c;
+}
+
+void copyAssignmentOperator() {
+  std::variant<int, char> v = 25;
+  std::variant<int, char> t = 'c';
+  t = v;
+  int a = std::get<int> (t);
+  char c = std::get<char> (t); // expected-warning {{std::variant 't' held an 'int', not a 'char'}}
+  (void)a;
+  (void)c;
+}
+
+void assignmentOperator() {
+  std::variant<int, char> v = 25;
+  int a = std::get<int> (v);
+  (void)a;
+  v = 'c';
+  char c = std::get<char>(v);
+  a = std::get<int>(v); // expected-warning {{std::variant 'v' held a 'char', not an 'int'}}
+  (void)a;
+  (void)c;
+}
+
+void typeChangeThreeTimes() {
+  std::variant<int, char, float> v = 25;
+  int a = std::get<int> (v);
+  (void)a;
+  v = 'c';
+  char c = std::get<char>(v);
+  v = 25;
+  a = std::get<int>(v);
+  (void)a;
+  v = 1.25f;
+  float f = std::get<float>(v);
+  a = std::get<int>(v); // expected-warning {{std::variant 'v' held a 'float', not an 'int'}}
+  (void)a;
+  (void)c;
+  (void)f;
+}
+
+void defaultConstructor() {
+  std::variant<int, char> v;
+  int i = std::get<int>(v);
+  char c = std::get<char>(v); // expected-warning {{std::variant 'v' held an 'int', not a 'char'}}
+  (void)i;
+  (void)c;
+}
+
+// Verify that we handle temporary objects correctly
+void temporaryObjectsConstructor() {
+  std::variant<int, char> v(std::variant<int, char>('c'));
+  char c = std::get<char>(v);
+  int a = std::get<int>(v); // expected-warning {{std::variant 'v' held a 'char', not an 'int'}}
+  (void)a;
+  (void)c;
+}
+
+void temporaryObjectsAssignment() {
+  std::variant<int, char> v = std::variant<int, char>('c');
+  char c = std::get<char>(v);
+  int a = std::get<int>(v); // expected-warning {{std::variant 'v' held a 'char', not an 'int'}}
+  (void)a;
+  (void)c;
+}
+
+// Verify that we handle pointer types correctly
+void pointerTypeHeld() {
+  int *p = new int;
+  std::variant<int*, char> v = p;
+  int *a = std::get<int*>(v);
+  char c = std::get<char>(v); // expected-warning {{std::variant 'v' held an 'int *', not a 'char'}}
+  (void)a;
+  (void)c;
+  delete p;
+}
+
+std::variant<int, char> get_unknown_variant();
+// Verify that the copy constructor is handled properly when the std::variant
+// has no previously activated type and we copy an object of unknown value in it.
+void copyFromUnknownVariant() {
+  std::variant<int, char> u = get_unknown_variant();
+  std::variant<int, char> v(u);
+  int a = std::get<int>(v); // no-waring
+  char c = std::get<char>(v); // no-warning
+  (void)a;
+  (void)c;
+}
+
+// Verify that copy assignment is handled properly when the std::variant
+// has a previously activated type and we copy an object of unknown value in it.
+void copyFromUnknownVariantBef() {
+  std::variant<int, char> v = 25;
+  std::variant<int, char> u = get_unknown_variant();
+  v = u;
+  int a = std::get<int>(v); // no-warning
+  char c = std::get<char>(v); // no-warning
+  (void)a;
+  (void)c;
+}
+
+//----------------------------------------------------------------------------//
+// typedef
+//----------------------------------------------------------------------------//
+
+void typefdefedVariant() {
+  var_t v = 25;
+  int a = std::get<int>(v);
+  char c = std::get<char>(v); // expected-warning {{std::variant 'v' held an 'int', not a 'char'}}
+  (void)a;
+  (void)c;
+}
+
+void typedefedTypedfefedVariant() {
+  var_tt v = 25;
+  int a = std::get<int>(v);
+  char c = std::get<char>(v); // expected-warning {{std::variant 'v' held an 'int', not a 'char'}}
+  (void)a;
+  (void)c;
+}
+
+void typedefedGet() {
+  std::variant<char, int> v = 25;
+  int a = std::get<int_t>(v);
+  char c = std::get<char_t>(v); // expected-warning {{std::variant 'v' held an 'int', not a 'char'}}
+  (void)a;
+  (void)c;
+}
+
+void typedefedPack() {
+  std::variant<int_t, char_t> v = 25;
+  int a = std::get<int>(v);
+  char c = std::get<char>(v); // expected-warning {{std::variant 'v' held an 'int', not a 'char'}}
+  (void)a;
+  (void)c;
+}
+
+void fromVariable() {
+  char o = 'c';
+  std::variant<int, char> v(o);
+  char c = std::get<char>(v);
+  int a = std::get<int>(v); // expected-warning {{std::variant 'v' held a 'char', not an 'int'}}
+  (void)a;
+  (void)c;
+}
+
+void unknowValueButKnownType() {
+  char o = getUnknownChar();
+  std::variant<int, char> v(o);
+  char c = std::get<char>(v);
+  int a = std::get<int>(v); // expected-warning {{std::variant 'v' held a 'char', not an 'int'}}
+  (void)a;
+  (void)c;
+}
+
+void createPointer() {
+  std::variant<int, char> *v = new std::variant<int, char>(15);
+  int a = std::get<int>(*v);
+  char c = std::get<char>(*v); // expected-warning {{std::variant  held an 'int', not a 'char'}}
+  (void)a;
+  (void)c;
+  delete v;
+}
+
+//----------------------------------------------------------------------------//
+// Passing std::variants to functions
+//----------------------------------------------------------------------------//
+
+// Verifying that we are not invalidating the memory region of a variant if
+// a non inlined or inlined function takes it as a constant reference or pointer
+void constNonInlineRef() {
+  std::variant<int, char> v = 'c';
+  cannotChangePtr(v);
+  char c = std::get<char>(v);
+  int a = std::get<int>(v); // expected-warning {{std::variant 'v' held a 'char', not an 'int'}}
+  (void)a;
+  (void)c;
+}
+
+void contNonInlinePtr() {
+  std::variant<int, char> v = 'c';
+  cannotChangePtr(&v);
+  char c = std::get<char>(v);
+  int a = std::get<int>(v); // expected-warning {{std::variant 'v' held a 'char', not an 'int'}}
+  (void)a;
+  (void)c;
+}
+
+void copyInAFunction() {
+  std::variant<int, char> v = 'c';
+  cantDo(v);
+  char c = std::get<char>(v);
+  int a = std::get<int>(v); // expected-warning {{std::variant 'v' held a 'char', not an 'int'}}
+  (void)a;
+  (void)c;
+
+}
+
+// Verifying that we can keep track of the type stored in std::variant when
+// it is passed to an inlined function as a reference or pointer
+void changeThruPointers() {
+  std::variant<int, char> v = 15;
+  changeVariantPtr(&v);
+  char c = std::get<char> (v);
+  int a = std::get<int> (v); // expected-warning {{std::variant 'v' held a 'char', not an 'int'}}
+  (void)a;
+  (void)c;
+}
+
+void functionCallWithCopyAssignment() {
+  var_t v1 = 15;
+  var_t v2 = 'c';
+  swap(v1, v2);
+  int a = std::get<int> (v2);
+  (void)a;
+  char c = std::get<char> (v1);
+  a = std::get<int> (v1); // expected-warning {{std::variant 'v1' held a 'char', not an 'int'}}
+  (void)a;
+  (void)c;
+}
+
+void inlineFunctionCall() {
+  std::variant<int, char> v = 'c';
+  changeVariantType(v);
+  int a = std::get<int> (v);
+  char c = std::get<char> (v); // expected-warning {{std::variant 'v' held an 'int', not a 'char'}}
+  (void)a;
+  (void)c;
+}
+
+// Verifying that we invalidate the mem region of std::variant when it is
+// passed as a non const reference or a pointer to a non inlined function.
+void nonInlineFunctionCall() {
+  std::variant<int, char> v = 'c';
+  changesToInt(v);
+  int a = std::get<int> (v); // no-warning
+  char c = std::get<char> (v); // no-warning
+  (void)a;
+  (void)c;
+}
+
+void nonInlineFunctionCallPtr() {
+  std::variant<int, char> v = 'c';
+  changesToInt(&v);
+  int a = std::get<int> (v); // no-warning
+  char c = std::get<char> (v); // no-warning
+  (void)a;
+  (void)c;
+}
\ No newline at end of file

>From 894ea8fb15a1a7f81e68aa48b3c6eaca9da86bfb Mon Sep 17 00:00:00 2001
From: Gabor Spaits <gabor.spaits at ericsson.com>
Date: Sun, 22 Oct 2023 16:36:38 +0200
Subject: [PATCH 2/5] [NFC] Namespace, comment

---
 .../Checkers/StdVariantChecker.cpp             |  3 ---
 .../Checkers/TaggedUnionModeling.h             | 18 +++++++++---------
 2 files changed, 9 insertions(+), 12 deletions(-)

diff --git a/clang/lib/StaticAnalyzer/Checkers/StdVariantChecker.cpp b/clang/lib/StaticAnalyzer/Checkers/StdVariantChecker.cpp
index 680c5567431bbfb..52ab2dfbac9c9e6 100644
--- a/clang/lib/StaticAnalyzer/Checkers/StdVariantChecker.cpp
+++ b/clang/lib/StaticAnalyzer/Checkers/StdVariantChecker.cpp
@@ -33,9 +33,6 @@ namespace clang {
 namespace ento {
 namespace tagged_union_modeling {
 
-// Returns the CallEvent representing the caller of the function
-// It is needed because the CallEvent class does not contain enough information
-// to tell who called it. Checker context is needed.
 CallEventRef<> getCaller(const CallEvent &Call, const ProgramStateRef &State) {
   const auto *CallLocationContext = Call.getLocationContext();
   if (!CallLocationContext || CallLocationContext->inTopFrame())
diff --git a/clang/lib/StaticAnalyzer/Checkers/TaggedUnionModeling.h b/clang/lib/StaticAnalyzer/Checkers/TaggedUnionModeling.h
index 593f243e84ca686..6f66a96d7aab88a 100644
--- a/clang/lib/StaticAnalyzer/Checkers/TaggedUnionModeling.h
+++ b/clang/lib/StaticAnalyzer/Checkers/TaggedUnionModeling.h
@@ -6,8 +6,8 @@
 //
 //===----------------------------------------------------------------------===//
 
-#ifndef LLVM_CLANG_LIB_STATICANALYZER_CHECKER_VARIANTLIKETYPEMODELING_H
-#define LLVM_CLANG_LIB_STATICANALYZER_CHECKER_VARIANTLIKETYPEMODELING_H
+#ifndef LLVM_CLANG_LIB_STATICANALYZER_CHECKERS_TAGGEDUNIONMODELING_H
+#define LLVM_CLANG_LIB_STATICANALYZER_CHECKERS_TAGGEDUNIONMODELING_H
 
 #include "clang/StaticAnalyzer/Checkers/BuiltinCheckerRegistration.h"
 #include "clang/StaticAnalyzer/Core/BugReporter/BugType.h"
@@ -19,12 +19,14 @@
 #include "llvm/ADT/FoldingSet.h"
 #include <numeric>
 
-namespace clang {
-namespace ento {
-namespace tagged_union_modeling {
+namespace clang::ento::tagged_union_modeling {
 
 // The implementation of all these functions can be found in the file
 // StdVariantChecker.cpp under the same directory as this file.
+
+// Returns the CallEvent representing the caller of the function
+// It is needed because the CallEvent class does not contain enough information
+// to tell who called it. Checker context is needed.
 CallEventRef<> getCaller(const CallEvent &Call, CheckerContext &C);
 bool isCopyConstructorCall(const CallEvent &Call);
 bool isCopyAssignmentCall(const CallEvent &Call);
@@ -97,8 +99,6 @@ void handleConstructorAndAssignment(const CallEvent &Call, CheckerContext &C,
   C.addTransition(State);
 }
 
-} // namespace tagged_union_modeling
-} // namespace ento
-} // namespace clang
+} // namespace clang::ento::tagged_union_modeling
 
-#endif // LLVM_CLANG_LIB_STATICANALYZER_CHECKER_VARIANTLIKETYPEMODELING_H
\ No newline at end of file
+#endif // LLVM_CLANG_LIB_STATICANALYZER_CHECKERS_TAGGEDUNIONMODELING_H
\ No newline at end of file

>From caf778112694848194587fd981219523f3b1afae Mon Sep 17 00:00:00 2001
From: Gabor Spaits <gabor.spaits at ericsson.com>
Date: Sun, 22 Oct 2023 18:17:46 +0200
Subject: [PATCH 3/5] Move functions into CallEvent class

---
 .../Core/PathSensitive/CallEvent.h            |  8 ++++
 .../Checkers/StdVariantChecker.cpp            | 40 ++++---------------
 .../Checkers/TaggedUnionModeling.h            |  9 +----
 clang/lib/StaticAnalyzer/Core/CallEvent.cpp   | 20 ++++++++++
 4 files changed, 37 insertions(+), 40 deletions(-)

diff --git a/clang/include/clang/StaticAnalyzer/Core/PathSensitive/CallEvent.h b/clang/include/clang/StaticAnalyzer/Core/PathSensitive/CallEvent.h
index 8129ebc8fdc6937..3aa27fd1df6532b 100644
--- a/clang/include/clang/StaticAnalyzer/Core/PathSensitive/CallEvent.h
+++ b/clang/include/clang/StaticAnalyzer/Core/PathSensitive/CallEvent.h
@@ -455,6 +455,14 @@ class CallEvent {
   /// If the call returns a C++ record type then the region of its return value
   /// can be retrieved from its construction context.
   std::optional<SVal> getReturnValueUnderConstruction() const;
+  
+  // Returns the CallEvent representing the caller of this function
+  const CallEventRef<> getCaller() const;
+
+  // Returns true if the function was called from a standard library function.
+  // If not or could not get the caller (it may be a top level function)
+  // returns false.
+  bool calledFromSystemHeader() const;
 
   // Iterator access to formal parameters and their types.
 private:
diff --git a/clang/lib/StaticAnalyzer/Checkers/StdVariantChecker.cpp b/clang/lib/StaticAnalyzer/Checkers/StdVariantChecker.cpp
index 52ab2dfbac9c9e6..8ddf7d7b2f9dbe3 100644
--- a/clang/lib/StaticAnalyzer/Checkers/StdVariantChecker.cpp
+++ b/clang/lib/StaticAnalyzer/Checkers/StdVariantChecker.cpp
@@ -29,22 +29,7 @@ using namespace tagged_union_modeling;
 
 REGISTER_MAP_WITH_PROGRAMSTATE(VariantHeldTypeMap, const MemRegion *, QualType)
 
-namespace clang {
-namespace ento {
-namespace tagged_union_modeling {
-
-CallEventRef<> getCaller(const CallEvent &Call, const ProgramStateRef &State) {
-  const auto *CallLocationContext = Call.getLocationContext();
-  if (!CallLocationContext || CallLocationContext->inTopFrame())
-    return nullptr;
-
-  const auto *CallStackFrameContext = CallLocationContext->getStackFrame();
-  if (!CallStackFrameContext)
-    return nullptr;
-
-  CallEventManager &CEMgr = State->getStateManager().getCallEventManager();
-  return CEMgr.getCaller(CallStackFrameContext, State);
-}
+namespace clang::ento::tagged_union_modeling {
 
 const CXXConstructorDecl *
 getConstructorDeclarationForCall(const CallEvent &Call) {
@@ -102,21 +87,7 @@ bool isStdVariant(const Type *Type) {
   return isStdType(Type, llvm::StringLiteral("variant"));
 }
 
-bool calledFromSystemHeader(const CallEvent &Call,
-                            const ProgramStateRef &State) {
-  if (CallEventRef<> Caller = getCaller(Call, State))
-    return Caller->isInSystemHeader();
-
-  return false;
-}
-
-bool calledFromSystemHeader(const CallEvent &Call, CheckerContext &C) {
-  return calledFromSystemHeader(Call, C.getState());
-}
-
-} // end of namespace tagged_union_modeling
-} // end of namespace ento
-} // end of namespace clang
+} // end of namespace clang::ento::tagged_union_modeling
 
 static std::optional<ArrayRef<TemplateArgument>>
 getTemplateArgsFromVariant(const Type *VariantType) {
@@ -171,14 +142,17 @@ class StdVariantChecker : public Checker<eval::Call, check::RegionChanges> {
                                      ArrayRef<const MemRegion *> Regions,
                                      const LocationContext *,
                                      const CallEvent *Call) const {
+    if (!Call)
+      return State;
+
     return removeInformationStoredForDeadInstances<VariantHeldTypeMap>(
-        Call, State, Regions);
+        *Call, State, Regions);
   }
 
   bool evalCall(const CallEvent &Call, CheckerContext &C) const {
     // Check if the call was not made from a system header. If it was then
     // we do an early return because it is part of the implementation.
-    if (calledFromSystemHeader(Call, C))
+    if (Call.calledFromSystemHeader())
       return false;
 
     if (StdGet.matches(Call))
diff --git a/clang/lib/StaticAnalyzer/Checkers/TaggedUnionModeling.h b/clang/lib/StaticAnalyzer/Checkers/TaggedUnionModeling.h
index 6f66a96d7aab88a..557e8a76506e611 100644
--- a/clang/lib/StaticAnalyzer/Checkers/TaggedUnionModeling.h
+++ b/clang/lib/StaticAnalyzer/Checkers/TaggedUnionModeling.h
@@ -24,29 +24,24 @@ namespace clang::ento::tagged_union_modeling {
 // The implementation of all these functions can be found in the file
 // StdVariantChecker.cpp under the same directory as this file.
 
-// Returns the CallEvent representing the caller of the function
-// It is needed because the CallEvent class does not contain enough information
-// to tell who called it. Checker context is needed.
-CallEventRef<> getCaller(const CallEvent &Call, CheckerContext &C);
 bool isCopyConstructorCall(const CallEvent &Call);
 bool isCopyAssignmentCall(const CallEvent &Call);
 bool isMoveAssignmentCall(const CallEvent &Call);
 bool isMoveConstructorCall(const CallEvent &Call);
 bool isStdType(const Type *Type, const std::string &TypeName);
 bool isStdVariant(const Type *Type);
-bool calledFromSystemHeader(const CallEvent &Call, CheckerContext &C);
 
 // When invalidating regions, we also have to follow that by invalidating the
 // corresponding custom data in the program state.
 template <class TypeMap>
 ProgramStateRef
-removeInformationStoredForDeadInstances(const CallEvent *Call,
+removeInformationStoredForDeadInstances(const CallEvent &Call,
                                         ProgramStateRef State,
                                         ArrayRef<const MemRegion *> Regions) {
   // If we do not know anything about the call we shall not continue.
   // If the call is happens within a system header it is implementation detail.
   // We should not take it into consideration.
-  if (!Call || Call->isInSystemHeader())
+  if (Call.isInSystemHeader())
     return State;
 
   for (const MemRegion *Region : Regions)
diff --git a/clang/lib/StaticAnalyzer/Core/CallEvent.cpp b/clang/lib/StaticAnalyzer/Core/CallEvent.cpp
index ad5bb66c4fff3c8..d1c456a796b2a20 100644
--- a/clang/lib/StaticAnalyzer/Core/CallEvent.cpp
+++ b/clang/lib/StaticAnalyzer/Core/CallEvent.cpp
@@ -517,6 +517,26 @@ const ConstructionContext *CallEvent::getConstructionContext() const {
   return nullptr;
 }
 
+const CallEventRef<> CallEvent::getCaller() const {
+  const auto *CallLocationContext = this->getLocationContext();
+  if (!CallLocationContext || CallLocationContext->inTopFrame())
+    return nullptr;
+
+  const auto *CallStackFrameContext = CallLocationContext->getStackFrame();
+  if (!CallStackFrameContext)
+    return nullptr;
+
+  CallEventManager &CEMgr = State->getStateManager().getCallEventManager();
+  return CEMgr.getCaller(CallStackFrameContext, State);
+}
+
+bool CallEvent::calledFromSystemHeader() const {
+  if (const CallEventRef<> Caller = getCaller())
+    return Caller->isInSystemHeader();
+
+  return false;
+}
+
 std::optional<SVal> CallEvent::getReturnValueUnderConstruction() const {
   const auto *CC = getConstructionContext();
   if (!CC)

>From 1042ac6fdde192aa7fa1a74f707a4fede9861b38 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?G=C3=A1bor=20Spaits?=
 <48805437+spaits at users.noreply.github.com>
Date: Sun, 22 Oct 2023 18:21:21 +0200
Subject: [PATCH 4/5] Add extra line break for checker definition

Co-authored-by: Balazs Benics <benicsbalazs at gmail.com>
---
 clang/include/clang/StaticAnalyzer/Checkers/Checkers.td | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/clang/include/clang/StaticAnalyzer/Checkers/Checkers.td b/clang/include/clang/StaticAnalyzer/Checkers/Checkers.td
index a93e97348606f28..3d368df0abddad2 100644
--- a/clang/include/clang/StaticAnalyzer/Checkers/Checkers.td
+++ b/clang/include/clang/StaticAnalyzer/Checkers/Checkers.td
@@ -320,7 +320,7 @@ def C11LockChecker : Checker<"C11Lock">,
 
 def StdVariantChecker : Checker<"StdVariant">,
   HelpText<"Check for bad type access for std::variant.">,
-  Documentation<NotDocumented>;
+  Documentation<Documented>;
 
 } // end "alpha.core"
 

>From 188d36e4fe5e49f0531d154f5c21866b3c36f47c Mon Sep 17 00:00:00 2001
From: Gabor Spaits <gabor.spaits at ericsson.com>
Date: Sun, 22 Oct 2023 19:10:06 +0200
Subject: [PATCH 5/5] Format changes

---
 .../Core/PathSensitive/CallEvent.h            | 93 +++++++++----------
 1 file changed, 43 insertions(+), 50 deletions(-)

diff --git a/clang/include/clang/StaticAnalyzer/Core/PathSensitive/CallEvent.h b/clang/include/clang/StaticAnalyzer/Core/PathSensitive/CallEvent.h
index 3aa27fd1df6532b..e2d72a80d37a51d 100644
--- a/clang/include/clang/StaticAnalyzer/Core/PathSensitive/CallEvent.h
+++ b/clang/include/clang/StaticAnalyzer/Core/PathSensitive/CallEvent.h
@@ -78,7 +78,7 @@ enum CallEventKind {
 
 class CallEvent;
 
-template<typename T = CallEvent>
+template <typename T = CallEvent>
 class CallEventRef : public IntrusiveRefCntPtr<const T> {
 public:
   CallEventRef(const T *Call) : IntrusiveRefCntPtr<const T>(Call) {}
@@ -94,8 +94,7 @@ class CallEventRef : public IntrusiveRefCntPtr<const T> {
 
   // Allow implicit conversions to a superclass type, since CallEventRef
   // behaves like a pointer-to-const.
-  template <typename SuperT>
-  operator CallEventRef<SuperT> () const {
+  template <typename SuperT> operator CallEventRef<SuperT>() const {
     return this->get();
   }
 };
@@ -124,9 +123,9 @@ class RuntimeDefinition {
 
 public:
   RuntimeDefinition() = default;
-  RuntimeDefinition(const Decl *InD): D(InD) {}
+  RuntimeDefinition(const Decl *InD) : D(InD) {}
   RuntimeDefinition(const Decl *InD, bool Foreign) : D(InD), Foreign(Foreign) {}
-  RuntimeDefinition(const Decl *InD, const MemRegion *InR): D(InD), R(InR) {}
+  RuntimeDefinition(const Decl *InD, const MemRegion *InR) : D(InD), R(InR) {}
 
   const Decl *getDecl() { return D; }
   bool isForeign() const { return Foreign; }
@@ -207,8 +206,9 @@ class CallEvent {
 
   /// Used to specify non-argument regions that will be invalidated as a
   /// result of this call.
-  virtual void getExtraInvalidatedValues(ValueList &Values,
-                 RegionAndSymbolInvalidationTraits *ETraits) const {}
+  virtual void
+  getExtraInvalidatedValues(ValueList &Values,
+                            RegionAndSymbolInvalidationTraits *ETraits) const {}
 
 public:
   CallEvent &operator=(const CallEvent &) = delete;
@@ -231,14 +231,10 @@ class CallEvent {
   void setForeign(bool B) const { Foreign = B; }
 
   /// The state in which the call is being evaluated.
-  const ProgramStateRef &getState() const {
-    return State;
-  }
+  const ProgramStateRef &getState() const { return State; }
 
   /// The context in which the call is being evaluated.
-  const LocationContext *getLocationContext() const {
-    return LCtx;
-  }
+  const LocationContext *getLocationContext() const { return LCtx; }
 
   const CFGBlock::ConstCFGElementRef &getCFGElementRef() const {
     return ElemRef;
@@ -270,7 +266,7 @@ class CallEvent {
     SourceLocation Loc = D->getLocation();
     if (Loc.isValid()) {
       const SourceManager &SM =
-        getState()->getStateManager().getContext().getSourceManager();
+          getState()->getStateManager().getContext().getSourceManager();
       return SM.isInSystemHeader(D->getLocation());
     }
 
@@ -324,9 +320,7 @@ class CallEvent {
   // NOTE: The exact semantics of this are still being defined!
   // We don't really want a list of hardcoded exceptions in the long run,
   // but we don't want duplicated lists of known APIs in the short term either.
-  virtual bool argumentsMayEscape() const {
-    return hasNonZeroCallbackArg();
-  }
+  virtual bool argumentsMayEscape() const { return hasNonZeroCallbackArg(); }
 
   /// Returns true if the callee is an externally-visible function in the
   /// top-level namespace, such as \c malloc.
@@ -455,7 +449,7 @@ class CallEvent {
   /// If the call returns a C++ record type then the region of its return value
   /// can be retrieved from its construction context.
   std::optional<SVal> getReturnValueUnderConstruction() const;
-  
+
   // Returns the CallEvent representing the caller of this function
   const CallEventRef<> getCaller() const;
 
@@ -587,8 +581,9 @@ class BlockCall : public CallEvent {
 
   void cloneTo(void *Dest) const override { new (Dest) BlockCall(*this); }
 
-  void getExtraInvalidatedValues(ValueList &Values,
-         RegionAndSymbolInvalidationTraits *ETraits) const override;
+  void getExtraInvalidatedValues(
+      ValueList &Values,
+      RegionAndSymbolInvalidationTraits *ETraits) const override;
 
 public:
   const CallExpr *getOriginExpr() const override {
@@ -658,14 +653,12 @@ class BlockCall : public CallEvent {
     // the block body and analyze the operator() method on the captured lambda.
     const VarDecl *LambdaVD = getRegionStoringCapturedLambda()->getDecl();
     const CXXRecordDecl *LambdaDecl = LambdaVD->getType()->getAsCXXRecordDecl();
-    CXXMethodDecl* LambdaCallOperator = LambdaDecl->getLambdaCallOperator();
+    CXXMethodDecl *LambdaCallOperator = LambdaDecl->getLambdaCallOperator();
 
     return RuntimeDefinition(LambdaCallOperator);
   }
 
-  bool argumentsMayEscape() const override {
-    return true;
-  }
+  bool argumentsMayEscape() const override { return true; }
 
   void getInitialStackFrameContents(const StackFrameContext *CalleeCtx,
                                     BindingsTy &Bindings) const override;
@@ -692,8 +685,9 @@ class CXXInstanceCall : public AnyFunctionCall {
       : AnyFunctionCall(D, St, LCtx, ElemRef) {}
   CXXInstanceCall(const CXXInstanceCall &Other) = default;
 
-  void getExtraInvalidatedValues(ValueList &Values,
-         RegionAndSymbolInvalidationTraits *ETraits) const override;
+  void getExtraInvalidatedValues(
+      ValueList &Values,
+      RegionAndSymbolInvalidationTraits *ETraits) const override;
 
 public:
   /// Returns the expression representing the implicit 'this' object.
@@ -851,7 +845,9 @@ class CXXDestructorCall : public CXXInstanceCall {
 
   CXXDestructorCall(const CXXDestructorCall &Other) = default;
 
-  void cloneTo(void *Dest) const override {new (Dest) CXXDestructorCall(*this);}
+  void cloneTo(void *Dest) const override {
+    new (Dest) CXXDestructorCall(*this);
+  }
 
 public:
   SourceRange getSourceRange() const override { return Location; }
@@ -888,8 +884,9 @@ class AnyCXXConstructorCall : public AnyFunctionCall {
     Data = Target;
   }
 
-  void getExtraInvalidatedValues(ValueList &Values,
-         RegionAndSymbolInvalidationTraits *ETraits) const override;
+  void getExtraInvalidatedValues(
+      ValueList &Values,
+      RegionAndSymbolInvalidationTraits *ETraits) const override;
 
   void getInitialStackFrameContents(const StackFrameContext *CalleeCtx,
                                     BindingsTy &Bindings) const override;
@@ -929,7 +926,9 @@ class CXXConstructorCall : public AnyCXXConstructorCall {
 
   CXXConstructorCall(const CXXConstructorCall &Other) = default;
 
-  void cloneTo(void *Dest) const override { new (Dest) CXXConstructorCall(*this); }
+  void cloneTo(void *Dest) const override {
+    new (Dest) CXXConstructorCall(*this);
+  }
 
 public:
   const CXXConstructExpr *getOriginExpr() const override {
@@ -1048,7 +1047,9 @@ class CXXAllocatorCall : public AnyFunctionCall {
       : AnyFunctionCall(E, St, LCtx, ElemRef) {}
   CXXAllocatorCall(const CXXAllocatorCall &Other) = default;
 
-  void cloneTo(void *Dest) const override { new (Dest) CXXAllocatorCall(*this); }
+  void cloneTo(void *Dest) const override {
+    new (Dest) CXXAllocatorCall(*this);
+  }
 
 public:
   const CXXNewExpr *getOriginExpr() const override {
@@ -1162,11 +1163,7 @@ class CXXDeallocatorCall : public AnyFunctionCall {
 //
 // Note to maintainers: OCM_Message should always be last, since it does not
 // need to fit in the Data field's low bits.
-enum ObjCMessageKind {
-  OCM_PropertyAccess,
-  OCM_Subscript,
-  OCM_Message
-};
+enum ObjCMessageKind { OCM_PropertyAccess, OCM_Subscript, OCM_Message };
 
 /// Represents any expression that calls an Objective-C method.
 ///
@@ -1188,8 +1185,9 @@ class ObjCMethodCall : public CallEvent {
 
   void cloneTo(void *Dest) const override { new (Dest) ObjCMethodCall(*this); }
 
-  void getExtraInvalidatedValues(ValueList &Values,
-         RegionAndSymbolInvalidationTraits *ETraits) const override;
+  void getExtraInvalidatedValues(
+      ValueList &Values,
+      RegionAndSymbolInvalidationTraits *ETraits) const override;
 
   /// Check if the selector may have multiple definitions (may have overrides).
   virtual bool canBeOverridenInSubclass(ObjCInterfaceDecl *IDecl,
@@ -1204,9 +1202,7 @@ class ObjCMethodCall : public CallEvent {
     return getOriginExpr()->getMethodDecl();
   }
 
-  unsigned getNumArgs() const override {
-    return getOriginExpr()->getNumArgs();
-  }
+  unsigned getNumArgs() const override { return getOriginExpr()->getNumArgs(); }
 
   const Expr *getArgExpr(unsigned Index) const override {
     return getOriginExpr()->getArg(Index);
@@ -1220,9 +1216,7 @@ class ObjCMethodCall : public CallEvent {
     return getOriginExpr()->getMethodFamily();
   }
 
-  Selector getSelector() const {
-    return getOriginExpr()->getSelector();
-  }
+  Selector getSelector() const { return getOriginExpr()->getSelector(); }
 
   SourceRange getSourceRange() const override;
 
@@ -1270,7 +1264,7 @@ class ObjCMethodCall : public CallEvent {
   void getInitialStackFrameContents(const StackFrameContext *CalleeCtx,
                                     BindingsTy &Bindings) const override;
 
-  ArrayRef<ParmVarDecl*> parameters() const override;
+  ArrayRef<ParmVarDecl *> parameters() const override;
 
   Kind getKind() const override { return CE_ObjCMessage; }
   StringRef getKindAsString() const override { return "ObjCMethodCall"; }
@@ -1344,8 +1338,8 @@ class CallEventManager {
   CallEventManager(llvm::BumpPtrAllocator &alloc) : Alloc(alloc) {}
 
   /// Gets an outside caller given a callee context.
-  CallEventRef<>
-  getCaller(const StackFrameContext *CalleeCtx, ProgramStateRef State);
+  CallEventRef<> getCaller(const StackFrameContext *CalleeCtx,
+                           ProgramStateRef State);
 
   /// Gets a call event for a function call, Objective-C method call,
   /// a 'new', or a 'delete' call.
@@ -1441,11 +1435,10 @@ inline void CallEvent::Release() const {
 namespace llvm {
 
 // Support isa<>, cast<>, and dyn_cast<> for CallEventRef.
-template<class T> struct simplify_type< clang::ento::CallEventRef<T>> {
+template <class T> struct simplify_type<clang::ento::CallEventRef<T>> {
   using SimpleType = const T *;
 
-  static SimpleType
-  getSimplifiedValue(clang::ento::CallEventRef<T> Val) {
+  static SimpleType getSimplifiedValue(clang::ento::CallEventRef<T> Val) {
     return Val.get();
   }
 };



More information about the cfe-commits mailing list