[compiler-rt] [analyzer] Add std::variant checker (PR #66481)

Gábor Spaits via llvm-commits llvm-commits at lists.llvm.org
Fri Sep 22 03:33:42 PDT 2023


https://github.com/spaits updated https://github.com/llvm/llvm-project/pull/66481

>From ce62d3e1924b497b3e7160579a87557119c9e35d Mon Sep 17 00:00:00 2001
From: Gabor Spaits <gaborspaits1 at gmail.com>
Date: Fri, 15 Sep 2023 10:21:30 +0200
Subject: [PATCH] [analyzer] Add std::variant checker

Adding a checker that checks for bad std::variant type access.
---
 .../clang/StaticAnalyzer/Checkers/Checkers.td |   4 +
 .../StaticAnalyzer/Checkers/CMakeLists.txt    |   1 +
 .../Checkers/StdVariantChecker.cpp            | 327 ++++++++++++++++
 .../Checkers/TaggedUnionModeling.h            | 104 +++++
 .../Inputs/system-header-simulator-cxx.h      | 122 ++++++
 .../diagnostics/explicit-suppression.cpp      |   2 +-
 clang/test/Analysis/std-variant-checker.cpp   | 358 ++++++++++++++++++
 7 files changed, 917 insertions(+), 1 deletion(-)
 create mode 100644 clang/lib/StaticAnalyzer/Checkers/StdVariantChecker.cpp
 create mode 100644 clang/lib/StaticAnalyzer/Checkers/TaggedUnionModeling.h
 create mode 100644 clang/test/Analysis/std-variant-checker.cpp

diff --git a/clang/include/clang/StaticAnalyzer/Checkers/Checkers.td b/clang/include/clang/StaticAnalyzer/Checkers/Checkers.td
index 65c1595eb6245dd..052cf5f884f2773 100644
--- a/clang/include/clang/StaticAnalyzer/Checkers/Checkers.td
+++ b/clang/include/clang/StaticAnalyzer/Checkers/Checkers.td
@@ -318,6 +318,10 @@ def C11LockChecker : Checker<"C11Lock">,
   Dependencies<[PthreadLockBase]>,
   Documentation<HasDocumentation>;
 
+def StdVariantChecker : Checker<"StdVariant">,
+  HelpText<"Check for bad type access for std::variant.">,
+  Documentation<NotDocumented>;
+
 } // end "alpha.core"
 
 //===----------------------------------------------------------------------===//
diff --git a/clang/lib/StaticAnalyzer/Checkers/CMakeLists.txt b/clang/lib/StaticAnalyzer/Checkers/CMakeLists.txt
index ae849f59f90d3d9..d7cb51e1a0819a8 100644
--- a/clang/lib/StaticAnalyzer/Checkers/CMakeLists.txt
+++ b/clang/lib/StaticAnalyzer/Checkers/CMakeLists.txt
@@ -108,6 +108,7 @@ add_clang_library(clangStaticAnalyzerCheckers
   SmartPtrModeling.cpp
   StackAddrEscapeChecker.cpp
   StdLibraryFunctionsChecker.cpp
+  StdVariantChecker.cpp
   STLAlgorithmModeling.cpp
   StreamChecker.cpp
   StringChecker.cpp
diff --git a/clang/lib/StaticAnalyzer/Checkers/StdVariantChecker.cpp b/clang/lib/StaticAnalyzer/Checkers/StdVariantChecker.cpp
new file mode 100644
index 000000000000000..680c5567431bbfb
--- /dev/null
+++ b/clang/lib/StaticAnalyzer/Checkers/StdVariantChecker.cpp
@@ -0,0 +1,327 @@
+//===- StdVariantChecker.cpp -------------------------------------*- C++ -*-==//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "clang/AST/Type.h"
+#include "clang/StaticAnalyzer/Checkers/BuiltinCheckerRegistration.h"
+#include "clang/StaticAnalyzer/Core/BugReporter/BugType.h"
+#include "clang/StaticAnalyzer/Core/Checker.h"
+#include "clang/StaticAnalyzer/Core/CheckerManager.h"
+#include "clang/StaticAnalyzer/Core/PathSensitive/CallDescription.h"
+#include "clang/StaticAnalyzer/Core/PathSensitive/CallEvent.h"
+#include "clang/StaticAnalyzer/Core/PathSensitive/CheckerContext.h"
+#include "clang/StaticAnalyzer/Core/PathSensitive/SVals.h"
+#include "llvm/ADT/FoldingSet.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/Casting.h"
+#include <optional>
+#include <string_view>
+
+#include "TaggedUnionModeling.h"
+
+using namespace clang;
+using namespace ento;
+using namespace tagged_union_modeling;
+
+REGISTER_MAP_WITH_PROGRAMSTATE(VariantHeldTypeMap, const MemRegion *, QualType)
+
+namespace clang {
+namespace ento {
+namespace tagged_union_modeling {
+
+// Returns the CallEvent representing the caller of the function.
+// It is needed because the CallEvent class does not contain enough information
+// to tell who called it. The program state is needed to look the caller up.
+CallEventRef<> getCaller(const CallEvent &Call, const ProgramStateRef &State) {
+  const auto *CallLocationContext = Call.getLocationContext();
+  if (!CallLocationContext || CallLocationContext->inTopFrame())
+    return nullptr;
+
+  const auto *CallStackFrameContext = CallLocationContext->getStackFrame();
+  if (!CallStackFrameContext)
+    return nullptr;
+
+  CallEventManager &CEMgr = State->getStateManager().getCallEventManager();
+  return CEMgr.getCaller(CallStackFrameContext, State);
+}
+
+const CXXConstructorDecl *
+getConstructorDeclarationForCall(const CallEvent &Call) {
+  const auto *ConstructorCall = dyn_cast<CXXConstructorCall>(&Call);
+  if (!ConstructorCall)
+    return nullptr;
+
+  return ConstructorCall->getDecl();
+}
+
+bool isCopyConstructorCall(const CallEvent &Call) {
+  if (const CXXConstructorDecl *ConstructorDecl =
+          getConstructorDeclarationForCall(Call))
+    return ConstructorDecl->isCopyConstructor();
+  return false;
+}
+
+bool isCopyAssignmentCall(const CallEvent &Call) {
+  const Decl *CopyAssignmentDecl = Call.getDecl();
+
+  if (const auto *AsMethodDecl =
+          dyn_cast_or_null<CXXMethodDecl>(CopyAssignmentDecl))
+    return AsMethodDecl->isCopyAssignmentOperator();
+  return false;
+}
+
+bool isMoveConstructorCall(const CallEvent &Call) {
+  const CXXConstructorDecl *ConstructorDecl =
+      getConstructorDeclarationForCall(Call);
+  if (!ConstructorDecl)
+    return false;
+
+  return ConstructorDecl->isMoveConstructor();
+}
+
+bool isMoveAssignmentCall(const CallEvent &Call) {
+  const Decl *CopyAssignmentDecl = Call.getDecl();
+
+  const auto *AsMethodDecl =
+      dyn_cast_or_null<CXXMethodDecl>(CopyAssignmentDecl);
+  if (!AsMethodDecl)
+    return false;
+
+  return AsMethodDecl->isMoveAssignmentOperator();
+}
+
+bool isStdType(const Type *Type, llvm::StringRef TypeName) {
+  auto *Decl = Type->getAsRecordDecl();
+  if (!Decl)
+    return false;
+  return (Decl->getName() == TypeName) && Decl->isInStdNamespace();
+}
+
+bool isStdVariant(const Type *Type) {
+  return isStdType(Type, llvm::StringLiteral("variant"));
+}
+
+bool calledFromSystemHeader(const CallEvent &Call,
+                            const ProgramStateRef &State) {
+  if (CallEventRef<> Caller = getCaller(Call, State))
+    return Caller->isInSystemHeader();
+
+  return false;
+}
+
+bool calledFromSystemHeader(const CallEvent &Call, CheckerContext &C) {
+  return calledFromSystemHeader(Call, C.getState());
+}
+
+} // end of namespace tagged_union_modeling
+} // end of namespace ento
+} // end of namespace clang
+
+static std::optional<ArrayRef<TemplateArgument>>
+getTemplateArgsFromVariant(const Type *VariantType) {
+  const auto *TempSpecType = VariantType->getAs<TemplateSpecializationType>();
+  if (!TempSpecType)
+    return {};
+
+  return TempSpecType->template_arguments();
+}
+
+// Returns the i-th template argument of the variant as a type, or
+// std::nullopt when the argument list is unavailable or i is out of range.
+static std::optional<QualType>
+getNthTemplateTypeArgFromVariant(const Type *varType, unsigned i) {
+  auto Args = getTemplateArgsFromVariant(varType);
+  if (!Args || i >= Args->size())
+    return {};
+  return (*Args)[i].getAsType();
+}
+
+// Tells whether the character is an ASCII vowel, in either letter case.
+// Type names often start with a capital letter (e.g. 'Apple'), so upper-case
+// vowels are recognized too; everything else, including '\0', is no vowel.
+static bool isVowel(char a) {
+  switch (a) {
+  case 'a': case 'e': case 'i': case 'o': case 'u':
+  case 'A': case 'E': case 'I': case 'O': case 'U':
+    return true;
+  default:
+    return false;
+  }
+}
+
+static llvm::StringRef indefiniteArticleBasedOnVowel(char a) {
+  if (isVowel(a))
+    return "an";
+  return "a";
+}
+
+class StdVariantChecker : public Checker<eval::Call, check::RegionChanges> {
+  // Call descriptors to find relevant calls
+  CallDescription VariantConstructor{{"std", "variant", "variant"}};
+  CallDescription VariantAssignmentOperator{{"std", "variant", "operator="}};
+  CallDescription StdGet{{"std", "get"}, 1, 1};
+
+  BugType BadVariantType{this, "BadVariantType", "BadVariantType"};
+
+public:
+  ProgramStateRef checkRegionChanges(ProgramStateRef State,
+                                     const InvalidatedSymbols *,
+                                     ArrayRef<const MemRegion *>,
+                                     ArrayRef<const MemRegion *> Regions,
+                                     const LocationContext *,
+                                     const CallEvent *Call) const {
+    return removeInformationStoredForDeadInstances<VariantHeldTypeMap>(
+        Call, State, Regions);
+  }
+
+  bool evalCall(const CallEvent &Call, CheckerContext &C) const {
+    // Check if the call was not made from a system header. If it was then
+    // we do an early return because it is part of the implementation.
+    if (calledFromSystemHeader(Call, C))
+      return false;
+
+    if (StdGet.matches(Call))
+      return handleStdGetCall(Call, C);
+
+    // First check if a constructor call is happening. If it is a
+    // constructor call, check if it is an std::variant constructor call.
+    bool IsVariantConstructor =
+        isa<CXXConstructorCall>(Call) && VariantConstructor.matches(Call);
+    bool IsVariantAssignmentOperatorCall =
+        isa<CXXMemberOperatorCall>(Call) &&
+        VariantAssignmentOperator.matches(Call);
+
+    if (IsVariantConstructor || IsVariantAssignmentOperatorCall) {
+      if (Call.getNumArgs() == 0 && IsVariantConstructor) {
+        handleDefaultConstructor(cast<CXXConstructorCall>(&Call), C);
+        return true;
+      }
+
+      // FIXME Later this checker should be extended to handle constructors
+      // with multiple arguments.
+      if (Call.getNumArgs() != 1)
+        return false;
+
+      SVal ThisSVal;
+      if (IsVariantConstructor) {
+        const auto &AsConstructorCall = cast<CXXConstructorCall>(Call);
+        ThisSVal = AsConstructorCall.getCXXThisVal();
+      } else if (IsVariantAssignmentOperatorCall) {
+        const auto &AsMemberOpCall = cast<CXXMemberOperatorCall>(Call);
+        ThisSVal = AsMemberOpCall.getCXXThisVal();
+      } else {
+        return false;
+      }
+
+      handleConstructorAndAssignment<VariantHeldTypeMap>(Call, C, ThisSVal);
+      return true;
+    }
+    return false;
+  }
+
+private:
+  // The default constructed std::variant must be handled separately:
+  // by default the std::variant is going to hold a default constructed
+  // instance of the first type of the possible types.
+  void handleDefaultConstructor(const CXXConstructorCall *ConstructorCall,
+                                CheckerContext &C) const {
+    SVal ThisSVal = ConstructorCall->getCXXThisVal();
+
+    const auto *const ThisMemRegion = ThisSVal.getAsRegion();
+    if (!ThisMemRegion)
+      return;
+
+    std::optional<QualType> DefaultType = getNthTemplateTypeArgFromVariant(
+        ThisSVal.getType(C.getASTContext())->getPointeeType().getTypePtr(), 0);
+    if (!DefaultType)
+      return;
+
+    ProgramStateRef State = ConstructorCall->getState();
+    State = State->set<VariantHeldTypeMap>(ThisMemRegion, *DefaultType);
+    C.addTransition(State);
+  }
+
+  // Evaluates std::get on a std::variant; returns false when the call is not
+  // modeled and reports a bug when the requested and the held type differ.
+  bool handleStdGetCall(const CallEvent &Call, CheckerContext &C) const {
+    ProgramStateRef State = Call.getState();
+
+    QualType ArgQType = Call.getArgSVal(0).getType(C.getASTContext());
+    if (ArgQType.isNull() || ArgQType->getPointeeType().isNull())
+      return false;
+    const Type *ArgType = ArgQType->getPointeeType().getTypePtr();
+    // We have to make sure that the argument is an std::variant.
+    // There is another std::get with std::pair argument
+    if (!isStdVariant(ArgType))
+      return false;
+
+    // Get the mem region of the argument std::variant and look up the type
+    // information that we know about it.
+    const MemRegion *ArgMemRegion = Call.getArgSVal(0).getAsRegion();
+    const QualType *StoredType = State->get<VariantHeldTypeMap>(ArgMemRegion);
+    if (!StoredType)
+      return false;
+
+    // The callee and its template argument list may be null; check both.
+    const CallExpr *CE = cast<CallExpr>(Call.getOriginExpr());
+    const FunctionDecl *FD = CE->getDirectCallee();
+    const auto *TArgs = FD ? FD->getTemplateSpecializationArgs() : nullptr;
+    if (!TArgs || TArgs->size() < 1)
+      return false;
+    const TemplateArgument &TypeOut = TArgs->asArray()[0];
+    // std::get's first template parameter can be the type we want to get
+    // out of the std::variant or a natural number which is the position of
+    // the requested type in the argument type list of the std::variant's
+    // argument.
+    QualType RetrievedType;
+    switch (TypeOut.getKind()) {
+    case TemplateArgument::ArgKind::Type:
+      RetrievedType = TypeOut.getAsType();
+      break;
+    case TemplateArgument::ArgKind::Integral:
+      // In the natural number case we look up which type corresponds to the
+      // number.
+      if (std::optional<QualType> NthTemplate =
+              getNthTemplateTypeArgFromVariant(
+                  ArgType, TypeOut.getAsIntegral().getSExtValue())) {
+        RetrievedType = *NthTemplate;
+        break;
+      }
+      [[fallthrough]];
+    default:
+      return false;
+    }
+
+    if (RetrievedType.getCanonicalType() == StoredType->getCanonicalType())
+      return true;
+
+    ExplodedNode *ErrNode = C.generateNonFatalErrorNode();
+    if (!ErrNode)
+      return false;
+    llvm::SmallString<128> Str;
+    llvm::raw_svector_ostream OS(Str);
+    std::string StoredTypeName = StoredType->getAsString();
+    std::string RetrievedTypeName = RetrievedType.getAsString();
+    OS << "std::variant " << ArgMemRegion->getDescriptiveName() << " held "
+       << indefiniteArticleBasedOnVowel(StoredTypeName[0]) << " \'"
+       << StoredTypeName << "\', not "
+       << indefiniteArticleBasedOnVowel(RetrievedTypeName[0]) << " \'"
+       << RetrievedTypeName << "\'";
+    C.emitReport(std::make_unique<PathSensitiveBugReport>(BadVariantType,
+                                                          OS.str(), ErrNode));
+    return true;
+  }
+};
+
+bool clang::ento::shouldRegisterStdVariantChecker(
+    clang::ento::CheckerManager const &mgr) {
+  return true;
+}
+
+void clang::ento::registerStdVariantChecker(clang::ento::CheckerManager &mgr) {
+  mgr.registerChecker<StdVariantChecker>();
+}
\ No newline at end of file
diff --git a/clang/lib/StaticAnalyzer/Checkers/TaggedUnionModeling.h b/clang/lib/StaticAnalyzer/Checkers/TaggedUnionModeling.h
new file mode 100644
index 000000000000000..593f243e84ca686
--- /dev/null
+++ b/clang/lib/StaticAnalyzer/Checkers/TaggedUnionModeling.h
@@ -0,0 +1,104 @@
+//===- TaggedUnionModeling.h -------------------------------------*- C++ -*-==//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_CLANG_LIB_STATICANALYZER_CHECKER_VARIANTLIKETYPEMODELING_H
+#define LLVM_CLANG_LIB_STATICANALYZER_CHECKER_VARIANTLIKETYPEMODELING_H
+
+#include "clang/StaticAnalyzer/Checkers/BuiltinCheckerRegistration.h"
+#include "clang/StaticAnalyzer/Core/BugReporter/BugType.h"
+#include "clang/StaticAnalyzer/Core/Checker.h"
+#include "clang/StaticAnalyzer/Core/CheckerManager.h"
+#include "clang/StaticAnalyzer/Core/PathSensitive/CallDescription.h"
+#include "clang/StaticAnalyzer/Core/PathSensitive/CallEvent.h"
+#include "clang/StaticAnalyzer/Core/PathSensitive/CheckerContext.h"
+#include "llvm/ADT/FoldingSet.h"
+#include <numeric>
+
+namespace clang {
+namespace ento {
+namespace tagged_union_modeling {
+
+// The implementation of all these functions can be found in the file
+// StdVariantChecker.cpp under the same directory as this file.
+CallEventRef<> getCaller(const CallEvent &Call, const ProgramStateRef &State);
+bool isCopyConstructorCall(const CallEvent &Call);
+bool isCopyAssignmentCall(const CallEvent &Call);
+bool isMoveAssignmentCall(const CallEvent &Call);
+bool isMoveConstructorCall(const CallEvent &Call);
+bool isStdType(const Type *Type, llvm::StringRef TypeName);
+bool isStdVariant(const Type *Type);
+bool calledFromSystemHeader(const CallEvent &Call, CheckerContext &C);
+
+// When invalidating regions, we also have to follow that by invalidating the
+// corresponding custom data in the program state.
+template <class TypeMap>
+ProgramStateRef
+removeInformationStoredForDeadInstances(const CallEvent *Call,
+                                        ProgramStateRef State,
+                                        ArrayRef<const MemRegion *> Regions) {
+  // If we do not know anything about the call we shall not continue.
+  // If the call happens within a system header it is an implementation
+  // detail. We should not take it into consideration.
+  if (!Call || Call->isInSystemHeader())
+    return State;
+
+  for (const MemRegion *Region : Regions)
+    State = State->remove<TypeMap>(Region);
+
+  return State;
+}
+
+template <class TypeMap>
+void handleConstructorAndAssignment(const CallEvent &Call, CheckerContext &C,
+                                    const SVal &ThisSVal) {
+  ProgramStateRef State = Call.getState();
+
+  if (!State)
+    return;
+
+  auto ArgSVal = Call.getArgSVal(0);
+  const auto *ThisRegion = ThisSVal.getAsRegion();
+  const auto *ArgMemRegion = ArgSVal.getAsRegion();
+
+  // Make changes to the state according to type of constructor/assignment
+  bool IsCopy = isCopyConstructorCall(Call) || isCopyAssignmentCall(Call);
+  bool IsMove = isMoveConstructorCall(Call) || isMoveAssignmentCall(Call);
+  // First we handle copy and move operations
+  if (IsCopy || IsMove) {
+    const QualType *OtherQType = State->get<TypeMap>(ArgMemRegion);
+
+    // If the argument of a copy constructor or assignment is unknown then
+    // we will not know the argument of the copied to object.
+    if (!OtherQType) {
+      State = State->remove<TypeMap>(ThisRegion);
+    } else {
+      // When move semantics is used we can only know that the moved from
+      // object must be in a destructible state. Other usage of the object
+      // than destruction is undefined.
+      if (IsMove)
+        State = State->remove<TypeMap>(ArgMemRegion);
+
+      State = State->set<TypeMap>(ThisRegion, *OtherQType);
+    }
+  } else {
+    // Value constructor
+    auto ArgQType = ArgSVal.getType(C.getASTContext());
+    const Type *ArgTypePtr = ArgQType.getTypePtr();
+
+    QualType WoPointer = ArgTypePtr->getPointeeType();
+    State = State->set<TypeMap>(ThisRegion, WoPointer);
+  }
+
+  C.addTransition(State);
+}
+
+} // namespace tagged_union_modeling
+} // namespace ento
+} // namespace clang
+
+#endif // LLVM_CLANG_LIB_STATICANALYZER_CHECKER_VARIANTLIKETYPEMODELING_H
\ No newline at end of file
diff --git a/clang/test/Analysis/Inputs/system-header-simulator-cxx.h b/clang/test/Analysis/Inputs/system-header-simulator-cxx.h
index 8633a8beadbff33..3ef7af2ea6c6ab4 100644
--- a/clang/test/Analysis/Inputs/system-header-simulator-cxx.h
+++ b/clang/test/Analysis/Inputs/system-header-simulator-cxx.h
@@ -249,6 +249,11 @@ namespace std {
     pair(const pair<U1, U2> &other) : first(other.first),
                                       second(other.second) {}
   };
+
+  template<class T2, class T1>
+  T2& get(pair<T1, T2>& p) ;
+  template<class T1, class T2>
+  T1& get(const pair<T1, T2>& p) ;
   
   typedef __typeof__(sizeof(int)) size_t;
 
@@ -264,6 +269,9 @@ namespace std {
     return static_cast<RvalRef>(a);
   }
 
+  template< class T >
+  using remove_reference_t = typename remove_reference<T>::type;
+
   template <class T>
   void swap(T &a, T &b) {
     T c(std::move(a));
@@ -718,6 +726,11 @@ namespace std {
   template <class _Tp, class _Up> struct  is_same           : public false_type {};
   template <class _Tp>            struct  is_same<_Tp, _Tp> : public true_type {};
 
+  #if __cplusplus >= 201703L
+  template< class T, class U >
+  inline constexpr bool is_same_v = is_same<T, U>::value;
+  #endif
+
   template <class _Tp, bool = is_const<_Tp>::value || is_reference<_Tp>::value    >
   struct __add_const             {typedef _Tp type;};
 
@@ -729,6 +742,9 @@ namespace std {
   template <class _Tp> struct  remove_const            {typedef _Tp type;};
   template <class _Tp> struct  remove_const<const _Tp> {typedef _Tp type;};
 
+  template< class T >
+  using remove_const_t = typename remove_const<T>::type;
+
   template <class _Tp> struct  add_lvalue_reference    {typedef _Tp& type;};
 
   template <class _Tp> struct is_trivially_copy_assignable
@@ -793,6 +809,9 @@ namespace std {
     return __result;
   }
 
+  template< bool B, class T = void >
+  using enable_if_t = typename enable_if<B,T>::type;
+
   template<class InputIter, class OutputIter>
   OutputIter copy_backward(InputIter II, InputIter IE, OutputIter OI) {
     return __copy_backward(II, IE, OI);
@@ -1252,4 +1271,107 @@ template <typename Ret, typename... Args> class packaged_task<Ret(Args...)> {
   // TODO: Add some actual implementation.
 };
 
+  #if __cplusplus >= 201703L
+
+  namespace detail
+  {
+    template<class T>
+    struct type_identity { using type = T; }; // or use std::type_identity (since C++20)
+ 
+    template<class T>
+    auto try_add_pointer(int) -> type_identity<typename std::remove_reference<T>::type*>;
+    template<class T>
+    auto try_add_pointer(...) -> type_identity<T>;
+  } // namespace detail
+ 
+  template<class T>
+  struct add_pointer : decltype(detail::try_add_pointer<T>(0)) {};
+
+  template< class T >
+  using add_pointer_t = typename add_pointer<T>::type;
+
+  template<class T> struct remove_cv { typedef T type; };
+  template<class T> struct remove_cv<const T> { typedef T type; };
+  template<class T> struct remove_cv<volatile T> { typedef T type; };
+  template<class T> struct remove_cv<const volatile T> { typedef T type; };
+
+  template< class T >
+  using remove_cv_t = typename remove_cv<T>::type;
+
+  // This decay does not behave exactly like std::decay, but this is enough
+  // for testing the std::variant checker
+  template<class T>
+  struct decay{typedef remove_cv_t<remove_reference_t<T>> type;};
+  template<class T>
+  using decay_t = typename decay<T>::type;
+  
+  // variant
+  template <class... Types> class variant;
+  // variant helper classes
+  template <class T> struct variant_size;
+  template <class T> struct variant_size<const T>;
+  template <class T> struct variant_size<volatile T>;
+  template <class T> struct variant_size<const volatile T>;
+  template <class T> inline constexpr size_t variant_size_v = variant_size<T>::value;
+  template <class... Types>
+  struct variant_size<variant<Types...>>;
+  template <size_t I, class T> struct variant_alternative;
+  template <size_t I, class T> struct variant_alternative<I, const T>;
+  template <size_t I, class T> struct variant_alternative<I, volatile T>;
+  template <size_t I, class T> struct variant_alternative<I, const volatile T>;
+  template <size_t I, class T>
+  using variant_alternative_t = typename variant_alternative<I, T>::type;
+  template <size_t I, class... Types>
+  struct variant_alternative<I, variant<Types...>>;
+  inline constexpr size_t variant_npos = -1;
+  template <size_t I, class... Types>
+  constexpr variant_alternative_t<I, variant<Types...>>&
+    get(variant<Types...>&);
+  template <size_t I, class... Types>
+  constexpr variant_alternative_t<I, variant<Types...>>&&
+    get(variant<Types...>&&);
+  template <size_t I, class... Types>
+  constexpr const variant_alternative_t<I, variant<Types...>>&
+    get(const variant<Types...>&);
+  template <size_t I, class... Types>
+  constexpr const variant_alternative_t<I, variant<Types...>>&&
+    get(const variant<Types...>&&);
+  template <class T, class... Types>
+  constexpr T& get(variant<Types...>&);
+  template <class T, class... Types>
+  constexpr T&& get(variant<Types...>&&);
+  template <class T, class... Types>
+  constexpr const T& get(const variant<Types...>&);
+  template <class T, class... Types>
+  constexpr const T&& get(const variant<Types...>&&);
+  template <size_t I, class... Types>
+  constexpr add_pointer_t<variant_alternative_t<I, variant<Types...>>>
+    get_if(variant<Types...>*) noexcept;
+  template <size_t I, class... Types>
+  constexpr add_pointer_t<const variant_alternative_t<I, variant<Types...>>>
+    get_if(const variant<Types...>*) noexcept;
+  template <class T, class... Types>
+  constexpr add_pointer_t<T> get_if(variant<Types...>*) noexcept;
+  template <class T, class... Types>
+  constexpr add_pointer_t<const T> get_if(const variant<Types...>*) noexcept;
+
+  template <class... Types>
+  class variant {
+  public:
+    // constructors
+    constexpr variant()= default ;
+    constexpr variant(const variant&);
+    constexpr variant(variant&&);
+    template<typename T,
+            typename = std::enable_if_t<!is_same_v<std::variant<Types...>, decay_t<T>>>>
+	  constexpr variant(T&&);
+    // assignment
+    variant& operator=(const variant&);
+    variant& operator=(variant&&) ;
+    template<typename T,
+            typename = std::enable_if_t<!is_same_v<std::variant<Types...>, decay_t<T>>>>
+    variant& operator=(T&&);
+  };
+  #endif
+
 } // namespace std
diff --git a/clang/test/Analysis/diagnostics/explicit-suppression.cpp b/clang/test/Analysis/diagnostics/explicit-suppression.cpp
index b98d0260b096594..24586e37fe207a5 100644
--- a/clang/test/Analysis/diagnostics/explicit-suppression.cpp
+++ b/clang/test/Analysis/diagnostics/explicit-suppression.cpp
@@ -19,6 +19,6 @@ class C {
 void testCopyNull(C *I, C *E) {
   std::copy(I, E, (C *)0);
 #ifndef SUPPRESSED
-  // expected-warning@../Inputs/system-header-simulator-cxx.h:741 {{Called C++ object pointer is null}}
+  // expected-warning@../Inputs/system-header-simulator-cxx.h:757 {{Called C++ object pointer is null}}
 #endif
 }
diff --git a/clang/test/Analysis/std-variant-checker.cpp b/clang/test/Analysis/std-variant-checker.cpp
new file mode 100644
index 000000000000000..7f136c06b19cc60
--- /dev/null
+++ b/clang/test/Analysis/std-variant-checker.cpp
@@ -0,0 +1,358 @@
+// RUN: %clang %s -std=c++17 -Xclang -verify --analyze \
+// RUN:   -Xclang -analyzer-checker=core \
+// RUN:   -Xclang -analyzer-checker=debug.ExprInspection \
+// RUN:   -Xclang -analyzer-checker=core,alpha.core.StdVariant
+
+#include "Inputs/system-header-simulator-cxx.h"
+
+class Foo{};
+
+void clang_analyzer_warnIfReached();
+void clang_analyzer_eval(int);
+
+//helper functions
+void changeVariantType(std::variant<int, char> &v) {
+  v = 25;
+}
+
+void changesToInt(std::variant<int, char> &v);
+void changesToInt(std::variant<int, char> *v);
+
+void cannotChangePtr(const std::variant<int, char> &v);
+void cannotChangePtr(const std::variant<int, char> *v);
+
+char getUnknownChar();
+
+void swap(std::variant<int, char> &v1, std::variant<int, char> &v2) {
+  std::variant<int, char> tmp = v1;
+  v1 = v2;
+  v2 = tmp;
+}
+
+void cantDo(const std::variant<int, char>& v) {
+  std::variant<int, char> vtmp = v;
+  vtmp = 5;
+  int a = std::get<int> (vtmp);
+  (void) a;
+}
+
+void changeVariantPtr(std::variant<int, char> *v) {
+  *v = 'c';
+}
+
+using var_t = std::variant<int, char>;
+using var_tt = var_t;
+using int_t = int;
+using char_t = char;
+
+// A quick sanity check to see that std::variant's std::get
+// is not being confused with std::pairs std::get.
+void wontConfuseStdGets() {
+  std::pair<int, char> p{15, '1'};
+  int a = std::get<int>(p);
+  char c = std::get<char>(p);
+  (void)a;
+  (void)c;
+}
+
+//----------------------------------------------------------------------------//
+// std::get
+//----------------------------------------------------------------------------//
+void stdGetType() {
+  std::variant<int, char> v = 25;
+  int a = std::get<int>(v);
+  char c = std::get<char>(v); // expected-warning {{std::variant 'v' held an 'int', not a 'char'}}
+  (void)a;
+  (void)c;
+}
+
+void stdGetPointer() {
+  int *p = new int;
+  std::variant<int*, char> v = p;
+  int *a = std::get<int*>(v);
+  char c = std::get<char>(v); // expected-warning {{std::variant 'v' held an 'int *', not a 'char'}}
+  (void)a;
+  (void)c;
+  delete p;
+}
+
+void stdGetObject() {
+  std::variant<int, char, Foo> v = Foo{};
+  Foo f = std::get<Foo>(v);
+  int i = std::get<int>(v); // expected-warning {{std::variant 'v' held a 'Foo', not an 'int'}}
+  (void)i;
+}
+
+void stdGetPointerAndPointee() {
+  int a = 5;
+  std::variant<int, int*> v = &a;
+  int *b = std::get<int*>(v);
+  int c = std::get<int>(v); // expected-warning {{std::variant 'v' held an 'int *', not an 'int'}}
+  (void)c;
+  (void)b;
+}
+
+void variantHoldingVariant() {
+  std::variant<std::variant<int, char>, std::variant<char, int>> v = std::variant<int,char>(25);
+  std::variant<int, char> v1 = std::get<std::variant<int,char>>(v);
+  std::variant<char, int> v2 = std::get<std::variant<char,int>>(v); // expected-warning {{std::variant 'v' held a 'std::variant<int, char>', not a 'class std::variant<char, int>'}}
+}
+
+//----------------------------------------------------------------------------//
+// Constructors and assignments
+//----------------------------------------------------------------------------//
+void copyConstructor() {
+  std::variant<int, char> v = 25;
+  std::variant<int, char> t(v);
+  int a = std::get<int> (t);
+  char c = std::get<char> (t); // expected-warning {{std::variant 't' held an 'int', not a 'char'}}
+  (void)a;
+  (void)c;
+}
+
+void copyAssignmentOperator() {
+  std::variant<int, char> v = 25;
+  std::variant<int, char> t = 'c';
+  t = v;
+  int a = std::get<int> (t);
+  char c = std::get<char> (t); // expected-warning {{std::variant 't' held an 'int', not a 'char'}}
+  (void)a;
+  (void)c;
+}
+
+void assignmentOperator() {
+  std::variant<int, char> v = 25;
+  int a = std::get<int> (v);
+  (void)a;
+  v = 'c';
+  char c = std::get<char>(v);
+  a = std::get<int>(v); // expected-warning {{std::variant 'v' held a 'char', not an 'int'}}
+  (void)a;
+  (void)c;
+}
+
+void typeChangeThreeTimes() {
+  std::variant<int, char, float> v = 25;
+  int a = std::get<int> (v);
+  (void)a;
+  v = 'c';
+  char c = std::get<char>(v);
+  v = 25;
+  a = std::get<int>(v);
+  (void)a;
+  v = 1.25f;
+  float f = std::get<float>(v);
+  a = std::get<int>(v); // expected-warning {{std::variant 'v' held a 'float', not an 'int'}}
+  (void)a;
+  (void)c;
+  (void)f;
+}
+
+void defaultConstructor() {
+  std::variant<int, char> v;
+  int i = std::get<int>(v);
+  char c = std::get<char>(v); // expected-warning {{std::variant 'v' held an 'int', not a 'char'}}
+  (void)i;
+  (void)c;
+}
+
+// Verify that we handle temporary objects correctly
+void temporaryObjectsConstructor() {
+  std::variant<int, char> v(std::variant<int, char>('c'));
+  char c = std::get<char>(v);
+  int a = std::get<int>(v); // expected-warning {{std::variant 'v' held a 'char', not an 'int'}}
+  (void)a;
+  (void)c;
+}
+
+void temporaryObjectsAssignment() {
+  std::variant<int, char> v = std::variant<int, char>('c');
+  char c = std::get<char>(v);
+  int a = std::get<int>(v); // expected-warning {{std::variant 'v' held a 'char', not an 'int'}}
+  (void)a;
+  (void)c;
+}
+
+// Verify that we handle pointer types correctly
+void pointerTypeHeld() {
+  int *p = new int;
+  std::variant<int*, char> v = p;
+  int *a = std::get<int*>(v);
+  char c = std::get<char>(v); // expected-warning {{std::variant 'v' held an 'int *', not a 'char'}}
+  (void)a;
+  (void)c;
+  delete p;
+}
+
+std::variant<int, char> get_unknown_variant();
+// Verify that the copy constructor is handled properly when the std::variant
+// has no previously activated type and we copy an object of unknown value in it.
+void copyFromUnknownVariant() {
+  std::variant<int, char> u = get_unknown_variant();
+  std::variant<int, char> v(u);
+  int a = std::get<int>(v); // no-warning
+  char c = std::get<char>(v); // no-warning
+  (void)a;
+  (void)c;
+}
+
+// Verify that the copy assignment is handled properly when the std::variant
+// has a previously activated type and we copy an object of unknown value into it.
+void copyFromUnknownVariantBef() {
+  std::variant<int, char> v = 25;
+  std::variant<int, char> u = get_unknown_variant();
+  v = u;
+  int a = std::get<int>(v); // no-warning
+  char c = std::get<char>(v); // no-warning
+  (void)a;
+  (void)c;
+}
+
+//----------------------------------------------------------------------------//
+// typedef
+//----------------------------------------------------------------------------//
+
+void typefdefedVariant() {
+  var_t v = 25;
+  int a = std::get<int>(v);
+  char c = std::get<char>(v); // expected-warning {{std::variant 'v' held an 'int', not a 'char'}}
+  (void)a;
+  (void)c;
+}
+
+void typedefedTypedfefedVariant() {
+  var_tt v = 25;
+  int a = std::get<int>(v);
+  char c = std::get<char>(v); // expected-warning {{std::variant 'v' held an 'int', not a 'char'}}
+  (void)a;
+  (void)c;
+}
+
+void typedefedGet() {
+  std::variant<char, int> v = 25;
+  int a = std::get<int_t>(v);
+  char c = std::get<char_t>(v); // expected-warning {{std::variant 'v' held an 'int', not a 'char'}}
+  (void)a;
+  (void)c;
+}
+
+void typedefedPack() {
+  std::variant<int_t, char_t> v = 25;
+  int a = std::get<int>(v);
+  char c = std::get<char>(v); // expected-warning {{std::variant 'v' held an 'int', not a 'char'}}
+  (void)a;
+  (void)c;
+}
+
+void fromVariable() {
+  char o = 'c';
+  std::variant<int, char> v(o);
+  char c = std::get<char>(v);
+  int a = std::get<int>(v); // expected-warning {{std::variant 'v' held a 'char', not an 'int'}}
+  (void)a;
+  (void)c;
+}
+
+void unknowValueButKnownType() {
+  char o = getUnknownChar();
+  std::variant<int, char> v(o);
+  char c = std::get<char>(v);
+  int a = std::get<int>(v); // expected-warning {{std::variant 'v' held a 'char', not an 'int'}}
+  (void)a;
+  (void)c;
+}
+
+void createPointer() {
+  std::variant<int, char> *v = new std::variant<int, char>(15);
+  int a = std::get<int>(*v);
+  char c = std::get<char>(*v); // expected-warning {{std::variant  held an 'int', not a 'char'}}
+  (void)a;
+  (void)c;
+  delete v;
+}
+
+//----------------------------------------------------------------------------//
+// Passing std::variants to functions
+//----------------------------------------------------------------------------//
+
+// Verifying that we are not invalidating the memory region of a variant if
+// a non-inlined or inlined function takes it as a constant reference or pointer
+void constNonInlineRef() {
+  std::variant<int, char> v = 'c';
+  cannotChangePtr(v);
+  char c = std::get<char>(v);
+  int a = std::get<int>(v); // expected-warning {{std::variant 'v' held a 'char', not an 'int'}}
+  (void)a;
+  (void)c;
+}
+
+void contNonInlinePtr() {
+  std::variant<int, char> v = 'c';
+  cannotChangePtr(&v);
+  char c = std::get<char>(v);
+  int a = std::get<int>(v); // expected-warning {{std::variant 'v' held a 'char', not an 'int'}}
+  (void)a;
+  (void)c;
+}
+
+void copyInAFunction() {
+  std::variant<int, char> v = 'c';
+  cantDo(v);
+  char c = std::get<char>(v);
+  int a = std::get<int>(v); // expected-warning {{std::variant 'v' held a 'char', not an 'int'}}
+  (void)a;
+  (void)c;
+
+}
+
+// Verifying that we can keep track of the type stored in std::variant when
+// it is passed to an inlined function as a reference or pointer
+void changeThruPointers() {
+  std::variant<int, char> v = 15;
+  changeVariantPtr(&v);
+  char c = std::get<char> (v);
+  int a = std::get<int> (v); // expected-warning {{std::variant 'v' held a 'char', not an 'int'}}
+  (void)a;
+  (void)c;
+}
+
+void functionCallWithCopyAssignment() {
+  var_t v1 = 15;
+  var_t v2 = 'c';
+  swap(v1, v2);
+  int a = std::get<int> (v2);
+  (void)a;
+  char c = std::get<char> (v1);
+  a = std::get<int> (v1); // expected-warning {{std::variant 'v1' held a 'char', not an 'int'}}
+  (void)a;
+  (void)c;
+}
+
+void inlineFunctionCall() {
+  std::variant<int, char> v = 'c';
+  changeVariantType(v);
+  int a = std::get<int> (v);
+  char c = std::get<char> (v); // expected-warning {{std::variant 'v' held an 'int', not a 'char'}}
+  (void)a;
+  (void)c;
+}
+
+// Verifying that we invalidate the memory region of std::variant when it is
+// passed as a non-const reference or a pointer to a non-inlined function.
+void nonInlineFunctionCall() {
+  std::variant<int, char> v = 'c';
+  changesToInt(v);
+  int a = std::get<int> (v); // no-warning
+  char c = std::get<char> (v); // no-warning
+  (void)a;
+  (void)c;
+}
+
+void nonInlineFunctionCallPtr() {
+  std::variant<int, char> v = 'c';
+  changesToInt(&v);
+  int a = std::get<int> (v); // no-warning
+  char c = std::get<char> (v); // no-warning
+  (void)a;
+  (void)c;
+}
\ No newline at end of file



More information about the llvm-commits mailing list