[clang] [analyzer] Add std::variant checker (PR #66481)

via cfe-commits cfe-commits at lists.llvm.org
Thu Sep 21 03:40:26 PDT 2023


================
@@ -0,0 +1,324 @@
+//===- StdVariantChecker.cpp -------------------------------------*- C++ -*-==//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "clang/AST/Type.h"
+#include "clang/StaticAnalyzer/Checkers/BuiltinCheckerRegistration.h"
+#include "clang/StaticAnalyzer/Core/BugReporter/BugType.h"
+#include "clang/StaticAnalyzer/Core/Checker.h"
+#include "clang/StaticAnalyzer/Core/CheckerManager.h"
+#include "clang/StaticAnalyzer/Core/PathSensitive/CallDescription.h"
+#include "clang/StaticAnalyzer/Core/PathSensitive/CallEvent.h"
+#include "clang/StaticAnalyzer/Core/PathSensitive/CheckerContext.h"
+#include "clang/StaticAnalyzer/Core/PathSensitive/SVals.h"
+#include "llvm/ADT/FoldingSet.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/Casting.h"
+#include <optional>
+#include <string_view>
+
+#include "TaggedUnionModeling.h"
+
+using namespace clang;
+using namespace ento;
+using namespace tagged_union_modeling;
+
+REGISTER_MAP_WITH_PROGRAMSTATE(VariantHeldTypeMap, const MemRegion *, QualType)
+
+namespace clang {
+namespace ento {
+namespace tagged_union_modeling {
+
+/// Looks up the call event of the function that invoked \p Call.
+///
+/// A CallEvent does not carry enough information to tell who called it, so
+/// the caller is recovered from the stack frame of \p Call through the
+/// CallEventManager reachable from \p State. A null reference is returned
+/// when \p Call has no location context, is in the top frame, or has no
+/// stack frame.
+CallEventRef<> getCaller(const CallEvent &Call, const ProgramStateRef &State) {
+  const auto *LCtx = Call.getLocationContext();
+  const bool HasCallerFrame = LCtx && !LCtx->inTopFrame();
+  if (!HasCallerFrame)
+    return nullptr;
+
+  if (const auto *SFC = LCtx->getStackFrame()) {
+    CallEventManager &Mgr = State->getStateManager().getCallEventManager();
+    return Mgr.getCaller(SFC, State);
+  }
+  return nullptr;
+}
+
+/// Returns the constructor declaration behind \p Call, or nullptr when
+/// \p Call is not a CXXConstructorCall.
+const CXXConstructorDecl *
+getConstructorDeclarationForCall(const CallEvent &Call) {
+  if (const auto *CtorCall = dyn_cast<CXXConstructorCall>(&Call))
+    return CtorCall->getDecl();
+  return nullptr;
+}
+
+/// True when \p Call is a constructor call whose declaration is a copy
+/// constructor; false for every other call.
+bool isCopyConstructorCall(const CallEvent &Call) {
+  const CXXConstructorDecl *Ctor = getConstructorDeclarationForCall(Call);
+  return Ctor && Ctor->isCopyConstructor();
+}
+
+/// True when the declaration of \p Call is a C++ method that is a copy
+/// assignment operator; false when there is no declaration or it is not a
+/// method.
+bool isCopyAssignmentCall(const CallEvent &Call) {
+  const auto *Method = dyn_cast_or_null<CXXMethodDecl>(Call.getDecl());
+  return Method && Method->isCopyAssignmentOperator();
+}
+
+/// True when \p Call is a constructor call whose declaration is a move
+/// constructor; false for every other call.
+bool isMoveConstructorCall(const CallEvent &Call) {
+  if (const CXXConstructorDecl *Ctor = getConstructorDeclarationForCall(Call))
+    return Ctor->isMoveConstructor();
+  return false;
+}
+
+/// True when the declaration of \p Call is a C++ method that is a move
+/// assignment operator; false when there is no declaration or it is not a
+/// method.
+bool isMoveAssignmentCall(const CallEvent &Call) {
+  if (const auto *Method = dyn_cast_or_null<CXXMethodDecl>(Call.getDecl()))
+    return Method->isMoveAssignmentOperator();
+  return false;
+}
+
+/// True when \p Type is a record type named \p TypeName that is declared in
+/// the std namespace. Non-record types yield false.
+bool isStdType(const Type *Type, llvm::StringRef TypeName) {
+  if (const auto *Record = Type->getAsRecordDecl())
+    return Record->getName() == TypeName && Record->isInStdNamespace();
+  return false;
+}
+
+/// True when \p Type is the record type std::variant.
+bool isStdVariant(const Type *Type) {
+  const llvm::StringRef VariantName = "variant";
+  return isStdType(Type, VariantName);
+}
+
+/// True when the caller of \p Call (as found by getCaller) is located in a
+/// system header. Calls without a discoverable caller yield false.
+bool calledFromSystemHeader(const CallEvent &Call,
+                            const ProgramStateRef &State) {
+  CallEventRef<> Caller = getCaller(Call, State);
+  return Caller && Caller->isInSystemHeader();
+}
+
+/// Convenience overload that takes the program state from \p C.
+bool calledFromSystemHeader(const CallEvent &Call, CheckerContext &C) {
+  const ProgramStateRef &CurrentState = C.getState();
+  return calledFromSystemHeader(Call, CurrentState);
+}
+
+} // end of namespace tagged_union_modeling
+} // end of namespace ento
+} // end of namespace clang
+
+/// Returns the template arguments of \p VariantType when it is a
+/// TemplateSpecializationType, and std::nullopt otherwise.
+static std::optional<ArrayRef<TemplateArgument>>
+getTemplateArgsFromVariant(const Type *VariantType) {
+  if (const auto *Spec = VariantType->getAs<TemplateSpecializationType>())
+    return Spec->template_arguments();
+  return std::nullopt;
+}
+
+/// Returns the \p i-th template argument of the template specialization
+/// \p varType as a type.
+///
+/// Yields an empty optional when \p varType is not a template specialization,
+/// when \p i is past the end of its argument list, or when the selected
+/// argument is not a type argument.
+static std::optional<QualType>
+getNthTemplateTypeArgFromVariant(const Type *varType, unsigned i) {
+  std::optional<ArrayRef<TemplateArgument>> VariantTemplates =
+      getTemplateArgsFromVariant(varType);
+  if (!VariantTemplates)
+    return {};
+
+  // The index may originate from a std::get<N> template argument that was
+  // converted to unsigned, so it is validated before indexing the ArrayRef.
+  if (i >= VariantTemplates->size())
+    return {};
+
+  // getAsType() is only meaningful for type arguments; reject anything else.
+  const TemplateArgument &Arg = (*VariantTemplates)[i];
+  if (Arg.getKind() != TemplateArgument::ArgKind::Type)
+    return {};
+
+  return Arg.getAsType();
+}
+
+/// Returns true when \p a is an ASCII vowel (a, e, i, o, u) in either case.
+///
+/// Upper-case letters are accepted as well so that names starting with a
+/// capital vowel are treated the same as their lower-case counterparts.
+static bool isVowel(char a) {
+  switch (a) {
+  case 'a':
+  case 'e':
+  case 'i':
+  case 'o':
+  case 'u':
+  case 'A':
+  case 'E':
+  case 'I':
+  case 'O':
+  case 'U':
+    return true;
+  default:
+    return false;
+  }
+}
+
+/// Picks the indefinite article for a word starting with \p a: "an" when
+/// isVowel accepts the character, "a" otherwise.
+static llvm::StringRef indefiniteArticleBasedOnVowel(char a) {
+  return isVowel(a) ? "an" : "a";
+}
+
+class StdVariantChecker : public Checker<eval::Call, check::RegionChanges> {
+  // Call descriptors to find relevant calls
+  CallDescription VariantConstructor{{"std", "variant", "variant"}};
+  CallDescription VariantAsOp{{"std", "variant", "operator="}};
+  CallDescription StdGet{{"std", "get"}, 1, 1};
+
+  BugType BadVariantType{this, "BadVariantType", "BadVariantType"};
+
+public:
+  /// Region-change callback: forwards \p Call, \p State and \p Regions to
+  /// removeInformationStoredForDeadInstances for VariantHeldTypeMap and
+  /// returns the state it produces. The remaining parameters are unused.
+  // NOTE(review): the helper is not defined in this file (presumably it comes
+  // from TaggedUnionModeling.h and drops the map entries of the given
+  // regions) -- confirm against its definition.
+  ProgramStateRef checkRegionChanges(ProgramStateRef State,
+                                     const InvalidatedSymbols *,
+                                     ArrayRef<const MemRegion *>,
+                                     ArrayRef<const MemRegion *> Regions,
+                                     const LocationContext *,
+                                     const CallEvent *Call) const {
+    return removeInformationStoredForDeadInstances<VariantHeldTypeMap>(
+        Call, State, Regions);
+  }
+
+  /// Models the std::variant operations this checker tracks.
+  ///
+  /// Calls made from a system header are left alone. std::get calls are
+  /// forwarded to handleStdGetCall. For std::variant constructors and
+  /// operator= the zero-argument constructor goes through
+  /// handleDefaultConstructor and single-argument calls through
+  /// handleConstructorAndAssignment; any other arity is not modeled.
+  ///
+  /// \returns true when the checker took over the evaluation of \p Call.
+  bool evalCall(const CallEvent &Call, CheckerContext &C) const {
+    // Calls issued from inside a system header are part of the library
+    // implementation, so they are not modeled here.
+    if (calledFromSystemHeader(Call, C))
+      return false;
+
+    if (StdGet.matches(Call))
+      return handleStdGetCall(Call, C);
+
+    // First check the kind of the call, then whether it targets std::variant.
+    const bool IsVariantConstructor =
+        isa<CXXConstructorCall>(Call) && VariantConstructor.matches(Call);
+    const bool IsVariantAssignmentOperatorCall =
+        isa<CXXMemberOperatorCall>(Call) && VariantAsOp.matches(Call);
+    if (!IsVariantConstructor && !IsVariantAssignmentOperatorCall)
+      return false;
+
+    if (Call.getNumArgs() == 0 && IsVariantConstructor) {
+      handleDefaultConstructor(cast<CXXConstructorCall>(&Call), C);
+      return true;
+    }
+
+    if (Call.getNumArgs() != 1)
+      return false;
+
+    // The dynamic kind was established by the isa<> checks above, so cast<>
+    // (which asserts on a mismatch) is used instead of dereferencing an
+    // unchecked dyn_cast<> result.
+    SVal ThisSVal = IsVariantConstructor
+                        ? cast<CXXConstructorCall>(Call).getCXXThisVal()
+                        : cast<CXXMemberOperatorCall>(Call).getCXXThisVal();
+
+    handleConstructorAndAssignment<VariantHeldTypeMap>(Call, C, ThisSVal);
+    return true;
+  }
+
+private:
+  /// Records the held type of a default-constructed std::variant.
+  ///
+  /// A default-constructed std::variant holds a default-constructed instance
+  /// of the first of its possible types, so the first template type argument
+  /// is stored in VariantHeldTypeMap for the region of the constructed
+  /// object. Nothing is recorded when that region or the first type argument
+  /// cannot be determined.
+  void handleDefaultConstructor(const CXXConstructorCall *ConstructorCall,
+                                CheckerContext &C) const {
+    const SVal ThisVal = ConstructorCall->getCXXThisVal();
+
+    const MemRegion *VariantRegion = ThisVal.getAsRegion();
+    if (!VariantRegion)
+      return;
+
+    const QualType ThisPtrTy = ThisVal.getType(C.getASTContext());
+    const Type *VariantTy = ThisPtrTy->getPointeeType().getTypePtr();
+    const std::optional<QualType> FirstAlternative =
+        getNthTemplateTypeArgFromVariant(VariantTy, 0);
+    if (!FirstAlternative)
+      return;
+
+    ProgramStateRef Updated =
+        ConstructorCall->getState()->set<VariantHeldTypeMap>(
+            VariantRegion, *FirstAlternative);
+    C.addTransition(Updated);
+  }
+
+  bool handleStdGetCall(const CallEvent &Call, CheckerContext &C) const {
+    ProgramStateRef State = Call.getState();
+
+    const auto &ArgType = Call.getArgSVal(0)
+                              .getType(C.getASTContext())
+                              ->getPointeeType()
+                              .getTypePtr();
+    // We have to make sure that the argument is an std::variant.
+    // There is another std::get with std::pair argument
+    if (!isStdVariant(ArgType))
+      return false;
+
+    // Get the mem region of the argument std::variant and get what type
+    // information is known about it.
+    const MemRegion *ArgMemRegion = Call.getArgSVal(0).getAsRegion();
+    const QualType *StoredType = State->get<VariantHeldTypeMap>(ArgMemRegion);
+    if (!StoredType)
+      return false;
+
+    const CallExpr *CE = cast<CallExpr>(Call.getOriginExpr());
+    const FunctionDecl *FD = CE->getDirectCallee();
+    if (FD->getTemplateSpecializationArgs()->size() < 1)
+      return false;
+
+    const auto &TypeOut = FD->getTemplateSpecializationArgs()->asArray()[0];
+    // std::get's first template parameter can be the type we want to get
+    // out of the std::variant or a natural number which is the position of
+    // the requested type in the argument type list of the std::variant's
+    // argument.
+    QualType RetrievedType;
+    switch (TypeOut.getKind()) {
+    case TemplateArgument::ArgKind::Type:
+      RetrievedType = TypeOut.getAsType();
+      break;
+    case TemplateArgument::ArgKind::Integral:
+      // In the natural number case we look up which type corresponds to the
+      // number.
+      if (std::optional<QualType> NthTemplate =
+              getNthTemplateTypeArgFromVariant(
+                  ArgType, TypeOut.getAsIntegral().getSExtValue()))
+        RetrievedType = *NthTemplate;
+      else
+        return false;
+      break;
----------------
DonatNagyE wrote:

```suggestion
      if (std::optional<QualType> NthTemplate =
              getNthTemplateTypeArgFromVariant(
                  ArgType, TypeOut.getAsIntegral().getSExtValue())) {
        RetrievedType = *NthTemplate;
        break;
      }
      [[fallthrough]];
```
What would you think about this alternative? (You don't have to use it if you prefer the original.)

https://github.com/llvm/llvm-project/pull/66481


More information about the cfe-commits mailing list