[clang] d712c5e - [clang][dataflow] Make optional checker work for types derived from optional. (#84138)

via cfe-commits cfe-commits at lists.llvm.org
Tue Mar 19 04:53:54 PDT 2024


Author: martinboehme
Date: 2024-03-19T12:53:50+01:00
New Revision: d712c5ed8fab4940ae0480e01fc72a944cbb79e6

URL: https://github.com/llvm/llvm-project/commit/d712c5ed8fab4940ae0480e01fc72a944cbb79e6
DIFF: https://github.com/llvm/llvm-project/commit/d712c5ed8fab4940ae0480e01fc72a944cbb79e6.diff

LOG: [clang][dataflow] Make optional checker work for types derived from optional. (#84138)

For example, `llvm::MaybeAlign` derives from `std::optional`.

It's not an option to simply ignore these derived classes because they
get cast back to the optional classes (for example, simply when calling
the optional member functions), and our transfer functions will then run
on those optional classes and therefore require them to be properly
initialized.

Added: 
    

Modified: 
    clang/lib/Analysis/FlowSensitive/Models/UncheckedOptionalAccessModel.cpp
    clang/unittests/Analysis/FlowSensitive/UncheckedOptionalAccessModelTest.cpp

Removed: 
    


################################################################################
diff  --git a/clang/lib/Analysis/FlowSensitive/Models/UncheckedOptionalAccessModel.cpp b/clang/lib/Analysis/FlowSensitive/Models/UncheckedOptionalAccessModel.cpp
index 1d31b22b6d25ff..dbf4878622eba9 100644
--- a/clang/lib/Analysis/FlowSensitive/Models/UncheckedOptionalAccessModel.cpp
+++ b/clang/lib/Analysis/FlowSensitive/Models/UncheckedOptionalAccessModel.cpp
@@ -64,39 +64,125 @@ static bool hasOptionalClassName(const CXXRecordDecl &RD) {
   return false;
 }
 
+static const CXXRecordDecl *getOptionalBaseClass(const CXXRecordDecl *RD) {
+  if (RD == nullptr)
+    return nullptr;
+  if (hasOptionalClassName(*RD))
+    return RD;
+
+  if (!RD->hasDefinition())
+    return nullptr;
+
+  for (const CXXBaseSpecifier &Base : RD->bases())
+    if (const CXXRecordDecl *BaseClass =
+            getOptionalBaseClass(Base.getType()->getAsCXXRecordDecl()))
+      return BaseClass;
+
+  return nullptr;
+}
+
 namespace {
 
 using namespace ::clang::ast_matchers;
 using LatticeTransferState = TransferState<NoopLattice>;
 
-AST_MATCHER(CXXRecordDecl, hasOptionalClassNameMatcher) {
-  return hasOptionalClassName(Node);
+AST_MATCHER(CXXRecordDecl, optionalClass) { return hasOptionalClassName(Node); }
+
+AST_MATCHER(CXXRecordDecl, optionalOrDerivedClass) {
+  return getOptionalBaseClass(&Node) != nullptr;
 }
 
-DeclarationMatcher optionalClass() {
-  return classTemplateSpecializationDecl(
-      hasOptionalClassNameMatcher(),
-      hasTemplateArgument(0, refersToType(type().bind("T"))));
+auto desugarsToOptionalType() {
+  return hasUnqualifiedDesugaredType(
+      recordType(hasDeclaration(cxxRecordDecl(optionalClass()))));
 }
 
-auto optionalOrAliasType() {
+auto desugarsToOptionalOrDerivedType() {
   return hasUnqualifiedDesugaredType(
-      recordType(hasDeclaration(optionalClass())));
+      recordType(hasDeclaration(cxxRecordDecl(optionalOrDerivedClass()))));
+}
+
+auto hasOptionalType() { return hasType(desugarsToOptionalType()); }
+
+/// Matches any of the spellings of the optional types and sugar, aliases,
+/// derived classes, etc.
+auto hasOptionalOrDerivedType() {
+  return hasType(desugarsToOptionalOrDerivedType());
+}
+
+QualType getPublicType(const Expr *E) {
+  auto *Cast = dyn_cast<ImplicitCastExpr>(E->IgnoreParens());
+  if (Cast == nullptr || Cast->getCastKind() != CK_UncheckedDerivedToBase) {
+    QualType Ty = E->getType();
+    if (Ty->isPointerType())
+      return Ty->getPointeeType();
+    return Ty;
+  }
+
+  // Is the derived type that we're casting from the type of `*this`? In this
+  // special case, we can upcast to the base class even if the base is
+  // non-public.
+  bool CastingFromThis = isa<CXXThisExpr>(Cast->getSubExpr());
+
+  // Find the least-derived type in the path (i.e. the last entry in the list)
+  // that we can access.
+  const CXXBaseSpecifier *PublicBase = nullptr;
+  for (const CXXBaseSpecifier *Base : Cast->path()) {
+    if (Base->getAccessSpecifier() != AS_public && !CastingFromThis)
+      break;
+    PublicBase = Base;
+    CastingFromThis = false;
+  }
+
+  if (PublicBase != nullptr)
+    return PublicBase->getType();
+
+  // We didn't find any public type that we could cast to. There may be more
+  // casts in `getSubExpr()`, so recurse. (If there aren't any more casts, this
+  // will return the type of `getSubExpr()`.)
+  return getPublicType(Cast->getSubExpr());
 }
 
-/// Matches any of the spellings of the optional types and sugar, aliases, etc.
-auto hasOptionalType() { return hasType(optionalOrAliasType()); }
+// Returns the least-derived type for the receiver of `MCE` that
+// `MCE.getImplicitObjectArgument()->IgnoreParenImpCasts()` can be upcast to.
+// Effectively, we upcast until we reach a non-public base class, unless that
+// base is a base of `*this`.
+//
+// This is needed to correctly match methods called on types derived from
+// `std::optional`.
+//
+// Say we have a `struct Derived : public std::optional<int> {} d;` For a call
+// `d.has_value()`, the `getImplicitObjectArgument()` looks like this:
+//
+//   ImplicitCastExpr 'const std::__optional_storage_base<int>' lvalue
+//   |            <UncheckedDerivedToBase (optional -> __optional_storage_base)>
+//   `-DeclRefExpr 'Derived' lvalue Var 'd' 'Derived'
+//
+// The type of the implicit object argument is `__optional_storage_base`
+// (since this is the internal type that `has_value()` is declared on). If we
+// call `IgnoreParenImpCasts()` on the implicit object argument, we get the
+// `DeclRefExpr`, which has type `Derived`. Neither of these types is
+// `optional`, and hence neither is sufficient for querying whether we are
+// calling a method on `optional`.
+//
+// Instead, starting with the most derived type, we need to follow the chain of
+// casts up to the least-derived base we can access (see `getPublicType()`).
+QualType getPublicReceiverType(const CXXMemberCallExpr &MCE) {
+  return getPublicType(MCE.getImplicitObjectArgument());
+}
+
+AST_MATCHER_P(CXXMemberCallExpr, publicReceiverType,
+              ast_matchers::internal::Matcher<QualType>, InnerMatcher) {
+  return InnerMatcher.matches(getPublicReceiverType(Node), Finder, Builder);
+}
 
 auto isOptionalMemberCallWithNameMatcher(
     ast_matchers::internal::Matcher<NamedDecl> matcher,
     const std::optional<StatementMatcher> &Ignorable = std::nullopt) {
-  auto Exception = unless(Ignorable ? expr(anyOf(*Ignorable, cxxThisExpr()))
-                                    : cxxThisExpr());
-  return cxxMemberCallExpr(
-      on(expr(Exception,
-              anyOf(hasOptionalType(),
-                    hasType(pointerType(pointee(optionalOrAliasType())))))),
-      callee(cxxMethodDecl(matcher)));
+  return cxxMemberCallExpr(Ignorable ? on(expr(unless(*Ignorable)))
+                                     : anything(),
+                           publicReceiverType(desugarsToOptionalType()),
+                           callee(cxxMethodDecl(matcher)));
 }
 
 auto isOptionalOperatorCallWithName(
@@ -129,49 +215,51 @@ auto inPlaceClass() {
 
 auto isOptionalNulloptConstructor() {
   return cxxConstructExpr(
-      hasOptionalType(),
       hasDeclaration(cxxConstructorDecl(parameterCountIs(1),
-                                        hasParameter(0, hasNulloptType()))));
+                                        hasParameter(0, hasNulloptType()))),
+      hasOptionalOrDerivedType());
 }
 
 auto isOptionalInPlaceConstructor() {
-  return cxxConstructExpr(hasOptionalType(),
-                          hasArgument(0, hasType(inPlaceClass())));
+  return cxxConstructExpr(hasArgument(0, hasType(inPlaceClass())),
+                          hasOptionalOrDerivedType());
 }
 
 auto isOptionalValueOrConversionConstructor() {
   return cxxConstructExpr(
-      hasOptionalType(),
       unless(hasDeclaration(
           cxxConstructorDecl(anyOf(isCopyConstructor(), isMoveConstructor())))),
-      argumentCountIs(1), hasArgument(0, unless(hasNulloptType())));
+      argumentCountIs(1), hasArgument(0, unless(hasNulloptType())),
+      hasOptionalOrDerivedType());
 }
 
 auto isOptionalValueOrConversionAssignment() {
   return cxxOperatorCallExpr(
       hasOverloadedOperatorName("="),
-      callee(cxxMethodDecl(ofClass(optionalClass()))),
+      callee(cxxMethodDecl(ofClass(optionalOrDerivedClass()))),
       unless(hasDeclaration(cxxMethodDecl(
           anyOf(isCopyAssignmentOperator(), isMoveAssignmentOperator())))),
       argumentCountIs(2), hasArgument(1, unless(hasNulloptType())));
 }
 
 auto isOptionalNulloptAssignment() {
-  return cxxOperatorCallExpr(hasOverloadedOperatorName("="),
-                             callee(cxxMethodDecl(ofClass(optionalClass()))),
-                             argumentCountIs(2),
-                             hasArgument(1, hasNulloptType()));
+  return cxxOperatorCallExpr(
+      hasOverloadedOperatorName("="),
+      callee(cxxMethodDecl(ofClass(optionalOrDerivedClass()))),
+      argumentCountIs(2), hasArgument(1, hasNulloptType()));
 }
 
 auto isStdSwapCall() {
   return callExpr(callee(functionDecl(hasName("std::swap"))),
-                  argumentCountIs(2), hasArgument(0, hasOptionalType()),
-                  hasArgument(1, hasOptionalType()));
+                  argumentCountIs(2),
+                  hasArgument(0, hasOptionalOrDerivedType()),
+                  hasArgument(1, hasOptionalOrDerivedType()));
 }
 
 auto isStdForwardCall() {
   return callExpr(callee(functionDecl(hasName("std::forward"))),
-                  argumentCountIs(1), hasArgument(0, hasOptionalType()));
+                  argumentCountIs(1),
+                  hasArgument(0, hasOptionalOrDerivedType()));
 }
 
 constexpr llvm::StringLiteral ValueOrCallID = "ValueOrCall";
@@ -212,8 +300,9 @@ auto isValueOrNotEqX() {
 }
 
 auto isCallReturningOptional() {
-  return callExpr(hasType(qualType(anyOf(
-      optionalOrAliasType(), referenceType(pointee(optionalOrAliasType()))))));
+  return callExpr(hasType(qualType(
+      anyOf(desugarsToOptionalOrDerivedType(),
+            referenceType(pointee(desugarsToOptionalOrDerivedType()))))));
 }
 
 template <typename L, typename R>
@@ -275,12 +364,9 @@ BoolValue *getHasValue(Environment &Env, RecordStorageLocation *OptionalLoc) {
   return HasValueVal;
 }
 
-/// Returns true if and only if `Type` is an optional type.
-bool isOptionalType(QualType Type) {
-  if (!Type->isRecordType())
-    return false;
-  const CXXRecordDecl *D = Type->getAsCXXRecordDecl();
-  return D != nullptr && hasOptionalClassName(*D);
+QualType valueTypeFromOptionalDecl(const CXXRecordDecl &RD) {
+  auto &CTSD = cast<ClassTemplateSpecializationDecl>(RD);
+  return CTSD.getTemplateArgs()[0].getAsType();
 }
 
 /// Returns the number of optional wrappers in `Type`.
@@ -288,15 +374,13 @@ bool isOptionalType(QualType Type) {
 /// For example, if `Type` is `optional<optional<int>>`, the result of this
 /// function will be 2.
 int countOptionalWrappers(const ASTContext &ASTCtx, QualType Type) {
-  if (!isOptionalType(Type))
+  const CXXRecordDecl *Optional =
+      getOptionalBaseClass(Type->getAsCXXRecordDecl());
+  if (Optional == nullptr)
     return 0;
   return 1 + countOptionalWrappers(
                  ASTCtx,
-                 cast<ClassTemplateSpecializationDecl>(Type->getAsRecordDecl())
-                     ->getTemplateArgs()
-                     .get(0)
-                     .getAsType()
-                     .getDesugaredType(ASTCtx));
+                 valueTypeFromOptionalDecl(*Optional).getDesugaredType(ASTCtx));
 }
 
 StorageLocation *getLocBehindPossiblePointer(const Expr &E,
@@ -843,13 +927,7 @@ auto buildDiagnoseMatchSwitch(
 
 ast_matchers::DeclarationMatcher
 UncheckedOptionalAccessModel::optionalClassDecl() {
-  return optionalClass();
-}
-
-static QualType valueTypeFromOptionalType(QualType OptionalTy) {
-  auto *CTSD =
-      cast<ClassTemplateSpecializationDecl>(OptionalTy->getAsCXXRecordDecl());
-  return CTSD->getTemplateArgs()[0].getAsType();
+  return cxxRecordDecl(optionalClass()); // Excludes derived classes.
 }
 
 UncheckedOptionalAccessModel::UncheckedOptionalAccessModel(ASTContext &Ctx,
@@ -858,9 +936,11 @@ UncheckedOptionalAccessModel::UncheckedOptionalAccessModel(ASTContext &Ctx,
       TransferMatchSwitch(buildTransferMatchSwitch()) {
   Env.getDataflowAnalysisContext().setSyntheticFieldCallback(
       [&Ctx](QualType Ty) -> llvm::StringMap<QualType> {
-        if (!isOptionalType(Ty))
+        const CXXRecordDecl *Optional =
+            getOptionalBaseClass(Ty->getAsCXXRecordDecl());
+        if (Optional == nullptr)
           return {};
-        return {{"value", valueTypeFromOptionalType(Ty)},
+        return {{"value", valueTypeFromOptionalDecl(*Optional)},
                 {"has_value", Ctx.BoolTy}};
       });
 }

diff  --git a/clang/unittests/Analysis/FlowSensitive/UncheckedOptionalAccessModelTest.cpp b/clang/unittests/Analysis/FlowSensitive/UncheckedOptionalAccessModelTest.cpp
index b6e4973fd7cb2b..9430730004dbd2 100644
--- a/clang/unittests/Analysis/FlowSensitive/UncheckedOptionalAccessModelTest.cpp
+++ b/clang/unittests/Analysis/FlowSensitive/UncheckedOptionalAccessModelTest.cpp
@@ -3383,6 +3383,66 @@ TEST_P(UncheckedOptionalAccessTest, LambdaCaptureStateNotPropagated) {
     }
   )");
 }
+
+TEST_P(UncheckedOptionalAccessTest, ClassDerivedFromOptional) {
+  ExpectDiagnosticsFor(R"(
+    #include "unchecked_optional_access_test.h"
+
+    struct Derived : public $ns::$optional<int> {};
+
+    void target(Derived opt) {
+      *opt;  // [[unsafe]]
+      if (opt.has_value())
+        *opt;
+
+      // The same thing, but with a pointer receiver.
+      Derived *popt = &opt;
+      **popt;  // [[unsafe]]
+      if (popt->has_value())
+        **popt;
+    }
+  )");
+}
+
+TEST_P(UncheckedOptionalAccessTest, ClassTemplateDerivedFromOptional) {
+  ExpectDiagnosticsFor(R"(
+    #include "unchecked_optional_access_test.h"
+
+    template <class T>
+    struct Derived : public $ns::$optional<T> {};
+
+    void target(Derived<int> opt) {
+      *opt;  // [[unsafe]]
+      if (opt.has_value())
+        *opt;
+
+      // The same thing, but with a pointer receiver.
+      Derived<int> *popt = &opt;
+      **popt;  // [[unsafe]]
+      if (popt->has_value())
+        **popt;
+    }
+  )");
+}
+
+TEST_P(UncheckedOptionalAccessTest, ClassDerivedPrivatelyFromOptional) {
+  // Classes that derive privately from optional can themselves still call
+  // member functions of optional (e.g. through `this`). Check that we model
+  // the optional correctly in this situation.
+  ExpectDiagnosticsFor(R"(
+    #include "unchecked_optional_access_test.h"
+
+    struct Derived : private $ns::$optional<int> {
+      void Method() {
+        **this;  // [[unsafe]]
+        if (this->has_value())
+          **this;
+      }
+    };
+  )",
+                       ast_matchers::hasName("Method"));
+}
+
 // FIXME: Add support for:
 // - constructors (copy, move)
 // - assignment operators (default, copy, move)


        


More information about the cfe-commits mailing list