[clang] [clang][dataflow] Make optional checker work for types derived from optional. (PR #84138)

via cfe-commits cfe-commits at lists.llvm.org
Wed Mar 6 00:28:44 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-clang

Author: None (martinboehme)

<details>
<summary>Changes</summary>

Some types are derived from an optional class; `llvm::MaybeAlign` does this, for example.

It's not an option to simply ignore these derived classes because they get cast
back to the optional classes (for example, simply when calling the optional
member functions), and our transfer functions will then run on those optional
classes and therefore require them to be properly initialized.


---
Full diff: https://github.com/llvm/llvm-project/pull/84138.diff


2 Files Affected:

- (modified) clang/lib/Analysis/FlowSensitive/Models/UncheckedOptionalAccessModel.cpp (+81-67) 
- (modified) clang/unittests/Analysis/FlowSensitive/UncheckedOptionalAccessModelTest.cpp (+30) 


``````````diff
diff --git a/clang/lib/Analysis/FlowSensitive/Models/UncheckedOptionalAccessModel.cpp b/clang/lib/Analysis/FlowSensitive/Models/UncheckedOptionalAccessModel.cpp
index 1d31b22b6d25ff..73dbc027a239e9 100644
--- a/clang/lib/Analysis/FlowSensitive/Models/UncheckedOptionalAccessModel.cpp
+++ b/clang/lib/Analysis/FlowSensitive/Models/UncheckedOptionalAccessModel.cpp
@@ -64,28 +64,43 @@ static bool hasOptionalClassName(const CXXRecordDecl &RD) {
   return false;
 }
 
+static const CXXRecordDecl *getOptionalBaseClass(const CXXRecordDecl *RD) {
+  if (RD == nullptr)
+    return nullptr;
+  if (hasOptionalClassName(*RD))
+    return RD;
+
+  if (!RD->hasDefinition())
+    return nullptr;
+
+  for (const CXXBaseSpecifier &Base : RD->bases())
+    if (Base.getAccessSpecifier() == AS_public)
+      if (const CXXRecordDecl *BaseClass =
+              getOptionalBaseClass(Base.getType()->getAsCXXRecordDecl()))
+        return BaseClass;
+
+  return nullptr;
+}
+
 namespace {
 
 using namespace ::clang::ast_matchers;
 using LatticeTransferState = TransferState<NoopLattice>;
 
-AST_MATCHER(CXXRecordDecl, hasOptionalClassNameMatcher) {
-  return hasOptionalClassName(Node);
+AST_MATCHER(CXXRecordDecl, optionalOrDerivedClass) {
+  return getOptionalBaseClass(&Node) != nullptr;
 }
 
-DeclarationMatcher optionalClass() {
-  return classTemplateSpecializationDecl(
-      hasOptionalClassNameMatcher(),
-      hasTemplateArgument(0, refersToType(type().bind("T"))));
-}
-
-auto optionalOrAliasType() {
+auto desugarsToOptionalOrDerivedType() {
   return hasUnqualifiedDesugaredType(
-      recordType(hasDeclaration(optionalClass())));
+      recordType(hasDeclaration(cxxRecordDecl(optionalOrDerivedClass()))));
 }
 
-/// Matches any of the spellings of the optional types and sugar, aliases, etc.
-auto hasOptionalType() { return hasType(optionalOrAliasType()); }
+/// Matches any of the spellings of the optional types and sugar, aliases,
+/// derived classes, etc.
+auto hasOptionalOrDerivedType() {
+  return hasType(desugarsToOptionalOrDerivedType());
+}
 
 auto isOptionalMemberCallWithNameMatcher(
     ast_matchers::internal::Matcher<NamedDecl> matcher,
@@ -93,9 +108,9 @@ auto isOptionalMemberCallWithNameMatcher(
   auto Exception = unless(Ignorable ? expr(anyOf(*Ignorable, cxxThisExpr()))
                                     : cxxThisExpr());
   return cxxMemberCallExpr(
-      on(expr(Exception,
-              anyOf(hasOptionalType(),
-                    hasType(pointerType(pointee(optionalOrAliasType())))))),
+      on(expr(Exception, anyOf(hasOptionalOrDerivedType(),
+                               hasType(pointerType(pointee(
+                                   desugarsToOptionalOrDerivedType())))))),
       callee(cxxMethodDecl(matcher)));
 }
 
@@ -104,7 +119,7 @@ auto isOptionalOperatorCallWithName(
     const std::optional<StatementMatcher> &Ignorable = std::nullopt) {
   return cxxOperatorCallExpr(
       hasOverloadedOperatorName(operator_name),
-      callee(cxxMethodDecl(ofClass(optionalClass()))),
+      callee(cxxMethodDecl(ofClass(optionalOrDerivedClass()))),
       Ignorable ? callExpr(unless(hasArgument(0, *Ignorable))) : callExpr());
 }
 
@@ -112,7 +127,7 @@ auto isMakeOptionalCall() {
   return callExpr(callee(functionDecl(hasAnyName(
                       "std::make_optional", "base::make_optional",
                       "absl::make_optional", "folly::make_optional"))),
-                  hasOptionalType());
+                  hasOptionalOrDerivedType());
 }
 
 auto nulloptTypeDecl() {
@@ -129,19 +144,19 @@ auto inPlaceClass() {
 
 auto isOptionalNulloptConstructor() {
   return cxxConstructExpr(
-      hasOptionalType(),
+      hasOptionalOrDerivedType(),
       hasDeclaration(cxxConstructorDecl(parameterCountIs(1),
                                         hasParameter(0, hasNulloptType()))));
 }
 
 auto isOptionalInPlaceConstructor() {
-  return cxxConstructExpr(hasOptionalType(),
+  return cxxConstructExpr(hasOptionalOrDerivedType(),
                           hasArgument(0, hasType(inPlaceClass())));
 }
 
 auto isOptionalValueOrConversionConstructor() {
   return cxxConstructExpr(
-      hasOptionalType(),
+      hasOptionalOrDerivedType(),
       unless(hasDeclaration(
           cxxConstructorDecl(anyOf(isCopyConstructor(), isMoveConstructor())))),
       argumentCountIs(1), hasArgument(0, unless(hasNulloptType())));
@@ -150,28 +165,30 @@ auto isOptionalValueOrConversionConstructor() {
 auto isOptionalValueOrConversionAssignment() {
   return cxxOperatorCallExpr(
       hasOverloadedOperatorName("="),
-      callee(cxxMethodDecl(ofClass(optionalClass()))),
+      callee(cxxMethodDecl(ofClass(optionalOrDerivedClass()))),
       unless(hasDeclaration(cxxMethodDecl(
           anyOf(isCopyAssignmentOperator(), isMoveAssignmentOperator())))),
       argumentCountIs(2), hasArgument(1, unless(hasNulloptType())));
 }
 
 auto isOptionalNulloptAssignment() {
-  return cxxOperatorCallExpr(hasOverloadedOperatorName("="),
-                             callee(cxxMethodDecl(ofClass(optionalClass()))),
-                             argumentCountIs(2),
-                             hasArgument(1, hasNulloptType()));
+  return cxxOperatorCallExpr(
+      hasOverloadedOperatorName("="),
+      callee(cxxMethodDecl(ofClass(optionalOrDerivedClass()))),
+      argumentCountIs(2), hasArgument(1, hasNulloptType()));
 }
 
 auto isStdSwapCall() {
   return callExpr(callee(functionDecl(hasName("std::swap"))),
-                  argumentCountIs(2), hasArgument(0, hasOptionalType()),
-                  hasArgument(1, hasOptionalType()));
+                  argumentCountIs(2),
+                  hasArgument(0, hasOptionalOrDerivedType()),
+                  hasArgument(1, hasOptionalOrDerivedType()));
 }
 
 auto isStdForwardCall() {
   return callExpr(callee(functionDecl(hasName("std::forward"))),
-                  argumentCountIs(1), hasArgument(0, hasOptionalType()));
+                  argumentCountIs(1),
+                  hasArgument(0, hasOptionalOrDerivedType()));
 }
 
 constexpr llvm::StringLiteral ValueOrCallID = "ValueOrCall";
@@ -181,10 +198,11 @@ auto isValueOrStringEmptyCall() {
   return cxxMemberCallExpr(
       callee(cxxMethodDecl(hasName("empty"))),
       onImplicitObjectArgument(ignoringImplicit(
-          cxxMemberCallExpr(on(expr(unless(cxxThisExpr()))),
-                            callee(cxxMethodDecl(hasName("value_or"),
-                                                 ofClass(optionalClass()))),
-                            hasArgument(0, stringLiteral(hasSize(0))))
+          cxxMemberCallExpr(
+              on(expr(unless(cxxThisExpr()))),
+              callee(cxxMethodDecl(hasName("value_or"),
+                                   ofClass(optionalOrDerivedClass()))),
+              hasArgument(0, stringLiteral(hasSize(0))))
               .bind(ValueOrCallID))));
 }
 
@@ -192,10 +210,11 @@ auto isValueOrNotEqX() {
   auto ComparesToSame = [](ast_matchers::internal::Matcher<Stmt> Arg) {
     return hasOperands(
         ignoringImplicit(
-            cxxMemberCallExpr(on(expr(unless(cxxThisExpr()))),
-                              callee(cxxMethodDecl(hasName("value_or"),
-                                                   ofClass(optionalClass()))),
-                              hasArgument(0, Arg))
+            cxxMemberCallExpr(
+                on(expr(unless(cxxThisExpr()))),
+                callee(cxxMethodDecl(hasName("value_or"),
+                                     ofClass(optionalOrDerivedClass()))),
+                hasArgument(0, Arg))
                 .bind(ValueOrCallID)),
         ignoringImplicit(Arg));
   };
@@ -212,8 +231,9 @@ auto isValueOrNotEqX() {
 }
 
 auto isCallReturningOptional() {
-  return callExpr(hasType(qualType(anyOf(
-      optionalOrAliasType(), referenceType(pointee(optionalOrAliasType()))))));
+  return callExpr(hasType(qualType(
+      anyOf(desugarsToOptionalOrDerivedType(),
+            referenceType(pointee(desugarsToOptionalOrDerivedType()))))));
 }
 
 template <typename L, typename R>
@@ -275,12 +295,9 @@ BoolValue *getHasValue(Environment &Env, RecordStorageLocation *OptionalLoc) {
   return HasValueVal;
 }
 
-/// Returns true if and only if `Type` is an optional type.
-bool isOptionalType(QualType Type) {
-  if (!Type->isRecordType())
-    return false;
-  const CXXRecordDecl *D = Type->getAsCXXRecordDecl();
-  return D != nullptr && hasOptionalClassName(*D);
+QualType valueTypeFromOptionalDecl(const CXXRecordDecl &RD) {
+  auto &CTSD = cast<ClassTemplateSpecializationDecl>(RD);
+  return CTSD.getTemplateArgs()[0].getAsType();
 }
 
 /// Returns the number of optional wrappers in `Type`.
@@ -288,15 +305,13 @@ bool isOptionalType(QualType Type) {
 /// For example, if `Type` is `optional<optional<int>>`, the result of this
 /// function will be 2.
 int countOptionalWrappers(const ASTContext &ASTCtx, QualType Type) {
-  if (!isOptionalType(Type))
+  const CXXRecordDecl *Optional =
+      getOptionalBaseClass(Type->getAsCXXRecordDecl());
+  if (Optional == nullptr)
     return 0;
   return 1 + countOptionalWrappers(
                  ASTCtx,
-                 cast<ClassTemplateSpecializationDecl>(Type->getAsRecordDecl())
-                     ->getTemplateArgs()
-                     .get(0)
-                     .getAsType()
-                     .getDesugaredType(ASTCtx));
+                 valueTypeFromOptionalDecl(*Optional).getDesugaredType(ASTCtx));
 }
 
 StorageLocation *getLocBehindPossiblePointer(const Expr &E,
@@ -623,7 +638,7 @@ ignorableOptional(const UncheckedOptionalAccessModelOptions &Options) {
   if (Options.IgnoreSmartPointerDereference) {
     auto SmartPtrUse = expr(ignoringParenImpCasts(cxxOperatorCallExpr(
         anyOf(hasOverloadedOperatorName("->"), hasOverloadedOperatorName("*")),
-        unless(hasArgument(0, expr(hasOptionalType()))))));
+        unless(hasArgument(0, expr(hasOptionalOrDerivedType()))))));
     return expr(
         anyOf(SmartPtrUse, memberExpr(hasObjectExpression(SmartPtrUse))));
   }
@@ -758,32 +773,35 @@ auto buildTransferMatchSwitch() {
 
       // Comparisons (==, !=):
       .CaseOfCFGStmt<CXXOperatorCallExpr>(
-          isComparisonOperatorCall(hasOptionalType(), hasOptionalType()),
+          isComparisonOperatorCall(hasOptionalOrDerivedType(),
+                                   hasOptionalOrDerivedType()),
           transferOptionalAndOptionalCmp)
       .CaseOfCFGStmt<CXXOperatorCallExpr>(
-          isComparisonOperatorCall(hasOptionalType(), hasNulloptType()),
+          isComparisonOperatorCall(hasOptionalOrDerivedType(),
+                                   hasNulloptType()),
           [](const clang::CXXOperatorCallExpr *Cmp,
              const MatchFinder::MatchResult &, LatticeTransferState &State) {
             transferOptionalAndNulloptCmp(Cmp, Cmp->getArg(0), State.Env);
           })
       .CaseOfCFGStmt<CXXOperatorCallExpr>(
-          isComparisonOperatorCall(hasNulloptType(), hasOptionalType()),
+          isComparisonOperatorCall(hasNulloptType(),
+                                   hasOptionalOrDerivedType()),
           [](const clang::CXXOperatorCallExpr *Cmp,
              const MatchFinder::MatchResult &, LatticeTransferState &State) {
             transferOptionalAndNulloptCmp(Cmp, Cmp->getArg(1), State.Env);
           })
       .CaseOfCFGStmt<CXXOperatorCallExpr>(
           isComparisonOperatorCall(
-              hasOptionalType(),
-              unless(anyOf(hasOptionalType(), hasNulloptType()))),
+              hasOptionalOrDerivedType(),
+              unless(anyOf(hasOptionalOrDerivedType(), hasNulloptType()))),
           [](const clang::CXXOperatorCallExpr *Cmp,
              const MatchFinder::MatchResult &, LatticeTransferState &State) {
             transferOptionalAndValueCmp(Cmp, Cmp->getArg(0), State.Env);
           })
       .CaseOfCFGStmt<CXXOperatorCallExpr>(
           isComparisonOperatorCall(
-              unless(anyOf(hasOptionalType(), hasNulloptType())),
-              hasOptionalType()),
+              unless(anyOf(hasOptionalOrDerivedType(), hasNulloptType())),
+              hasOptionalOrDerivedType()),
           [](const clang::CXXOperatorCallExpr *Cmp,
              const MatchFinder::MatchResult &, LatticeTransferState &State) {
             transferOptionalAndValueCmp(Cmp, Cmp->getArg(1), State.Env);
@@ -843,13 +861,7 @@ auto buildDiagnoseMatchSwitch(
 
 ast_matchers::DeclarationMatcher
 UncheckedOptionalAccessModel::optionalClassDecl() {
-  return optionalClass();
-}
-
-static QualType valueTypeFromOptionalType(QualType OptionalTy) {
-  auto *CTSD =
-      cast<ClassTemplateSpecializationDecl>(OptionalTy->getAsCXXRecordDecl());
-  return CTSD->getTemplateArgs()[0].getAsType();
+  return cxxRecordDecl(optionalOrDerivedClass());
 }
 
 UncheckedOptionalAccessModel::UncheckedOptionalAccessModel(ASTContext &Ctx,
@@ -858,9 +870,11 @@ UncheckedOptionalAccessModel::UncheckedOptionalAccessModel(ASTContext &Ctx,
       TransferMatchSwitch(buildTransferMatchSwitch()) {
   Env.getDataflowAnalysisContext().setSyntheticFieldCallback(
       [&Ctx](QualType Ty) -> llvm::StringMap<QualType> {
-        if (!isOptionalType(Ty))
+        const CXXRecordDecl *Optional =
+            getOptionalBaseClass(Ty->getAsCXXRecordDecl());
+        if (Optional == nullptr)
           return {};
-        return {{"value", valueTypeFromOptionalType(Ty)},
+        return {{"value", valueTypeFromOptionalDecl(*Optional)},
                 {"has_value", Ctx.BoolTy}};
       });
 }
diff --git a/clang/unittests/Analysis/FlowSensitive/UncheckedOptionalAccessModelTest.cpp b/clang/unittests/Analysis/FlowSensitive/UncheckedOptionalAccessModelTest.cpp
index b6e4973fd7cb2b..c1d5b2183007ae 100644
--- a/clang/unittests/Analysis/FlowSensitive/UncheckedOptionalAccessModelTest.cpp
+++ b/clang/unittests/Analysis/FlowSensitive/UncheckedOptionalAccessModelTest.cpp
@@ -3383,6 +3383,36 @@ TEST_P(UncheckedOptionalAccessTest, LambdaCaptureStateNotPropagated) {
     }
   )");
 }
+
+TEST_P(UncheckedOptionalAccessTest, ClassDerivedFromOptional) {
+  ExpectDiagnosticsFor(R"(
+    #include "unchecked_optional_access_test.h"
+
+    struct Derived : public $ns::$optional<int> {};
+
+    void target(Derived opt) {
+      *opt;  // [[unsafe]]
+      if (opt)
+        *opt;
+    }
+  )");
+}
+
+TEST_P(UncheckedOptionalAccessTest, ClassTemplateDerivedFromOptional) {
+  ExpectDiagnosticsFor(R"(
+    #include "unchecked_optional_access_test.h"
+
+    template <class T>
+    struct Derived : public $ns::$optional<T> {};
+
+    void target(Derived<int> opt) {
+      *opt;  // [[unsafe]]
+      if (opt)
+        *opt;
+    }
+  )");
+}
+
 // FIXME: Add support for:
 // - constructors (copy, move)
 // - assignment operators (default, copy, move)

``````````

</details>


https://github.com/llvm/llvm-project/pull/84138


More information about the cfe-commits mailing list