[clang] [clang][dataflow] Make optional checker work for types derived from optional. (PR #84138)

via cfe-commits cfe-commits at lists.llvm.org
Tue Mar 19 02:18:19 PDT 2024


https://github.com/martinboehme updated https://github.com/llvm/llvm-project/pull/84138

>From 209aac483514f6d041486f36e6dabec78598fcec Mon Sep 17 00:00:00 2001
From: Martin Braenne <mboehme at google.com>
Date: Tue, 12 Mar 2024 15:56:41 +0000
Subject: [PATCH 1/3] [clang][dataflow] Make optional checker work for types
 derived from optional.

`llvm::MaybeAlign` derives from `std::optional`, for example.

It's not an option to simply ignore these derived classes because they get cast
back to the optional classes (for example, simply when calling the optional
member functions), and our transfer functions will then run on those optional
classes and therefore require them to be properly initialized.
---
 .../Models/UncheckedOptionalAccessModel.cpp   | 174 +++++++++++++-----
 .../UncheckedOptionalAccessModelTest.cpp      |  60 ++++++
 2 files changed, 183 insertions(+), 51 deletions(-)

diff --git a/clang/lib/Analysis/FlowSensitive/Models/UncheckedOptionalAccessModel.cpp b/clang/lib/Analysis/FlowSensitive/Models/UncheckedOptionalAccessModel.cpp
index 1d31b22b6d25ff..90b4b7d04187df 100644
--- a/clang/lib/Analysis/FlowSensitive/Models/UncheckedOptionalAccessModel.cpp
+++ b/clang/lib/Analysis/FlowSensitive/Models/UncheckedOptionalAccessModel.cpp
@@ -64,39 +64,117 @@ static bool hasOptionalClassName(const CXXRecordDecl &RD) {
   return false;
 }
 
+static const CXXRecordDecl *getOptionalBaseClass(const CXXRecordDecl *RD) {
+  if (RD == nullptr)
+    return nullptr;
+  if (hasOptionalClassName(*RD))
+    return RD;
+
+  if (!RD->hasDefinition())
+    return nullptr;
+
+  for (const CXXBaseSpecifier &Base : RD->bases())
+    if (const CXXRecordDecl *BaseClass =
+            getOptionalBaseClass(Base.getType()->getAsCXXRecordDecl()))
+      return BaseClass;
+
+  return nullptr;
+}
+
 namespace {
 
 using namespace ::clang::ast_matchers;
 using LatticeTransferState = TransferState<NoopLattice>;
 
-AST_MATCHER(CXXRecordDecl, hasOptionalClassNameMatcher) {
-  return hasOptionalClassName(Node);
+AST_MATCHER(CXXRecordDecl, optionalClass) { return hasOptionalClassName(Node); }
+
+AST_MATCHER(CXXRecordDecl, optionalOrDerivedClass) {
+  return getOptionalBaseClass(&Node) != nullptr;
 }
 
-DeclarationMatcher optionalClass() {
-  return classTemplateSpecializationDecl(
-      hasOptionalClassNameMatcher(),
-      hasTemplateArgument(0, refersToType(type().bind("T"))));
+auto desugarsToOptionalType() {
+  return hasUnqualifiedDesugaredType(
+      recordType(hasDeclaration(cxxRecordDecl(optionalClass()))));
 }
 
-auto optionalOrAliasType() {
+auto desugarsToOptionalOrDerivedType() {
   return hasUnqualifiedDesugaredType(
-      recordType(hasDeclaration(optionalClass())));
+      recordType(hasDeclaration(cxxRecordDecl(optionalOrDerivedClass()))));
 }
 
-/// Matches any of the spellings of the optional types and sugar, aliases, etc.
-auto hasOptionalType() { return hasType(optionalOrAliasType()); }
+auto hasOptionalType() { return hasType(desugarsToOptionalType()); }
+
+/// Matches any of the spellings of the optional types and sugar, aliases,
+/// derived classes, etc.
+auto hasOptionalOrDerivedType() {
+  return hasType(desugarsToOptionalOrDerivedType());
+}
+
+QualType getPublicType(const Expr *E) {
+  auto *Cast = dyn_cast<ImplicitCastExpr>(E->IgnoreParens());
+  if (Cast == nullptr || Cast->getCastKind() != CK_UncheckedDerivedToBase) {
+    QualType Ty = E->getType();
+    if (Ty->isPointerType())
+      return Ty->getPointeeType();
+    return Ty;
+  }
+
+  QualType Ty = getPublicType(Cast->getSubExpr());
+
+  // Is `Ty` the type of `*this`? In this special case, we can upcast to the
+  // base class even if the base is non-public.
+  bool TyIsThisType = isa<CXXThisExpr>(Cast->getSubExpr());
+
+  for (const CXXBaseSpecifier *Base : Cast->path()) {
+    if (Base->getAccessSpecifier() != AS_public && !TyIsThisType)
+      break;
+    Ty = Base->getType();
+    TyIsThisType = false;
+  }
+
+  return Ty;
+}
+
+// Returns the least-derived type for the receiver of `MCE` that
+// `MCE.getImplicitObjectArgument()->IgnoreParenImpCasts()` can be upcast to.
+// Effectively, we upcast until we reach a non-public base class, unless that
+// base is a base of `*this`.
+//
+// This is needed to correctly match methods called on types derived from
+// `std::optional`.
+//
+// Say we have a `struct Derived : public std::optional<int> {} d;` For a call
+// `d.has_value()`, the `getImplicitObjectArgument()` looks like this:
+//
+//   ImplicitCastExpr 'const std::__optional_storage_base<int>' lvalue
+//   |            <UncheckedDerivedToBase (optional -> __optional_storage_base)>
+//   `-DeclRefExpr 'Derived' lvalue Var 'd' 'Derived'
+//
+// The type of the implicit object argument is `__optional_storage_base`
+// (since this is the internal type that `has_value()` is declared on). If we
+// call `IgnoreParenImpCasts()` on the implicit object argument, we get the
+// `DeclRefExpr`, which has type `Derived`. Neither of these types is
+// `optional`, and hence neither is sufficient for querying whether we are
+// calling a method on `optional`.
+//
+// Instead, starting with the most derived type, we need to follow the chain of
+// casts until we reach a base class that is not publicly accessible.
+QualType getPublicReceiverType(const CXXMemberCallExpr &MCE) {
+  return getPublicType(MCE.getImplicitObjectArgument());
+}
+
+AST_MATCHER_P(CXXMemberCallExpr, publicReceiverType,
+              ast_matchers::internal::Matcher<QualType>, InnerMatcher) {
+  return InnerMatcher.matches(getPublicReceiverType(Node), Finder, Builder);
+}
 
 auto isOptionalMemberCallWithNameMatcher(
     ast_matchers::internal::Matcher<NamedDecl> matcher,
     const std::optional<StatementMatcher> &Ignorable = std::nullopt) {
-  auto Exception = unless(Ignorable ? expr(anyOf(*Ignorable, cxxThisExpr()))
-                                    : cxxThisExpr());
-  return cxxMemberCallExpr(
-      on(expr(Exception,
-              anyOf(hasOptionalType(),
-                    hasType(pointerType(pointee(optionalOrAliasType())))))),
-      callee(cxxMethodDecl(matcher)));
+  return cxxMemberCallExpr(Ignorable ? on(expr(unless(*Ignorable)))
+                                     : anything(),
+                           publicReceiverType(desugarsToOptionalType()),
+                           callee(cxxMethodDecl(matcher)));
 }
 
 auto isOptionalOperatorCallWithName(
@@ -129,19 +207,19 @@ auto inPlaceClass() {
 
 auto isOptionalNulloptConstructor() {
   return cxxConstructExpr(
-      hasOptionalType(),
+      hasOptionalOrDerivedType(),
       hasDeclaration(cxxConstructorDecl(parameterCountIs(1),
                                         hasParameter(0, hasNulloptType()))));
 }
 
 auto isOptionalInPlaceConstructor() {
-  return cxxConstructExpr(hasOptionalType(),
+  return cxxConstructExpr(hasOptionalOrDerivedType(),
                           hasArgument(0, hasType(inPlaceClass())));
 }
 
 auto isOptionalValueOrConversionConstructor() {
   return cxxConstructExpr(
-      hasOptionalType(),
+      hasOptionalOrDerivedType(),
       unless(hasDeclaration(
           cxxConstructorDecl(anyOf(isCopyConstructor(), isMoveConstructor())))),
       argumentCountIs(1), hasArgument(0, unless(hasNulloptType())));
@@ -150,28 +228,30 @@ auto isOptionalValueOrConversionConstructor() {
 auto isOptionalValueOrConversionAssignment() {
   return cxxOperatorCallExpr(
       hasOverloadedOperatorName("="),
-      callee(cxxMethodDecl(ofClass(optionalClass()))),
+      callee(cxxMethodDecl(ofClass(optionalOrDerivedClass()))),
       unless(hasDeclaration(cxxMethodDecl(
           anyOf(isCopyAssignmentOperator(), isMoveAssignmentOperator())))),
       argumentCountIs(2), hasArgument(1, unless(hasNulloptType())));
 }
 
 auto isOptionalNulloptAssignment() {
-  return cxxOperatorCallExpr(hasOverloadedOperatorName("="),
-                             callee(cxxMethodDecl(ofClass(optionalClass()))),
-                             argumentCountIs(2),
-                             hasArgument(1, hasNulloptType()));
+  return cxxOperatorCallExpr(
+      hasOverloadedOperatorName("="),
+      callee(cxxMethodDecl(ofClass(optionalOrDerivedClass()))),
+      argumentCountIs(2), hasArgument(1, hasNulloptType()));
 }
 
 auto isStdSwapCall() {
   return callExpr(callee(functionDecl(hasName("std::swap"))),
-                  argumentCountIs(2), hasArgument(0, hasOptionalType()),
-                  hasArgument(1, hasOptionalType()));
+                  argumentCountIs(2),
+                  hasArgument(0, hasOptionalOrDerivedType()),
+                  hasArgument(1, hasOptionalOrDerivedType()));
 }
 
 auto isStdForwardCall() {
   return callExpr(callee(functionDecl(hasName("std::forward"))),
-                  argumentCountIs(1), hasArgument(0, hasOptionalType()));
+                  argumentCountIs(1),
+                  hasArgument(0, hasOptionalOrDerivedType()));
 }
 
 constexpr llvm::StringLiteral ValueOrCallID = "ValueOrCall";
@@ -212,8 +292,9 @@ auto isValueOrNotEqX() {
 }
 
 auto isCallReturningOptional() {
-  return callExpr(hasType(qualType(anyOf(
-      optionalOrAliasType(), referenceType(pointee(optionalOrAliasType()))))));
+  return callExpr(hasType(qualType(
+      anyOf(desugarsToOptionalOrDerivedType(),
+            referenceType(pointee(desugarsToOptionalOrDerivedType()))))));
 }
 
 template <typename L, typename R>
@@ -275,12 +356,9 @@ BoolValue *getHasValue(Environment &Env, RecordStorageLocation *OptionalLoc) {
   return HasValueVal;
 }
 
-/// Returns true if and only if `Type` is an optional type.
-bool isOptionalType(QualType Type) {
-  if (!Type->isRecordType())
-    return false;
-  const CXXRecordDecl *D = Type->getAsCXXRecordDecl();
-  return D != nullptr && hasOptionalClassName(*D);
+QualType valueTypeFromOptionalDecl(const CXXRecordDecl &RD) {
+  auto &CTSD = cast<ClassTemplateSpecializationDecl>(RD);
+  return CTSD.getTemplateArgs()[0].getAsType();
 }
 
 /// Returns the number of optional wrappers in `Type`.
@@ -288,15 +366,13 @@ bool isOptionalType(QualType Type) {
 /// For example, if `Type` is `optional<optional<int>>`, the result of this
 /// function will be 2.
 int countOptionalWrappers(const ASTContext &ASTCtx, QualType Type) {
-  if (!isOptionalType(Type))
+  const CXXRecordDecl *Optional =
+      getOptionalBaseClass(Type->getAsCXXRecordDecl());
+  if (Optional == nullptr)
     return 0;
   return 1 + countOptionalWrappers(
                  ASTCtx,
-                 cast<ClassTemplateSpecializationDecl>(Type->getAsRecordDecl())
-                     ->getTemplateArgs()
-                     .get(0)
-                     .getAsType()
-                     .getDesugaredType(ASTCtx));
+                 valueTypeFromOptionalDecl(*Optional).getDesugaredType(ASTCtx));
 }
 
 StorageLocation *getLocBehindPossiblePointer(const Expr &E,
@@ -843,13 +919,7 @@ auto buildDiagnoseMatchSwitch(
 
 ast_matchers::DeclarationMatcher
 UncheckedOptionalAccessModel::optionalClassDecl() {
-  return optionalClass();
-}
-
-static QualType valueTypeFromOptionalType(QualType OptionalTy) {
-  auto *CTSD =
-      cast<ClassTemplateSpecializationDecl>(OptionalTy->getAsCXXRecordDecl());
-  return CTSD->getTemplateArgs()[0].getAsType();
+  return cxxRecordDecl(optionalClass());
 }
 
 UncheckedOptionalAccessModel::UncheckedOptionalAccessModel(ASTContext &Ctx,
@@ -858,9 +928,11 @@ UncheckedOptionalAccessModel::UncheckedOptionalAccessModel(ASTContext &Ctx,
       TransferMatchSwitch(buildTransferMatchSwitch()) {
   Env.getDataflowAnalysisContext().setSyntheticFieldCallback(
       [&Ctx](QualType Ty) -> llvm::StringMap<QualType> {
-        if (!isOptionalType(Ty))
+        const CXXRecordDecl *Optional =
+            getOptionalBaseClass(Ty->getAsCXXRecordDecl());
+        if (Optional == nullptr)
           return {};
-        return {{"value", valueTypeFromOptionalType(Ty)},
+        return {{"value", valueTypeFromOptionalDecl(*Optional)},
                 {"has_value", Ctx.BoolTy}};
       });
 }
diff --git a/clang/unittests/Analysis/FlowSensitive/UncheckedOptionalAccessModelTest.cpp b/clang/unittests/Analysis/FlowSensitive/UncheckedOptionalAccessModelTest.cpp
index b6e4973fd7cb2b..9430730004dbd2 100644
--- a/clang/unittests/Analysis/FlowSensitive/UncheckedOptionalAccessModelTest.cpp
+++ b/clang/unittests/Analysis/FlowSensitive/UncheckedOptionalAccessModelTest.cpp
@@ -3383,6 +3383,66 @@ TEST_P(UncheckedOptionalAccessTest, LambdaCaptureStateNotPropagated) {
     }
   )");
 }
+
+TEST_P(UncheckedOptionalAccessTest, ClassDerivedFromOptional) {
+  ExpectDiagnosticsFor(R"(
+    #include "unchecked_optional_access_test.h"
+
+    struct Derived : public $ns::$optional<int> {};
+
+    void target(Derived opt) {
+      *opt;  // [[unsafe]]
+      if (opt.has_value())
+        *opt;
+
+      // The same thing, but with a pointer receiver.
+      Derived *popt = &opt;
+      **popt;  // [[unsafe]]
+      if (popt->has_value())
+        **popt;
+    }
+  )");
+}
+
+TEST_P(UncheckedOptionalAccessTest, ClassTemplateDerivedFromOptional) {
+  ExpectDiagnosticsFor(R"(
+    #include "unchecked_optional_access_test.h"
+
+    template <class T>
+    struct Derived : public $ns::$optional<T> {};
+
+    void target(Derived<int> opt) {
+      *opt;  // [[unsafe]]
+      if (opt.has_value())
+        *opt;
+
+      // The same thing, but with a pointer receiver.
+      Derived<int> *popt = &opt;
+      **popt;  // [[unsafe]]
+      if (popt->has_value())
+        **popt;
+    }
+  )");
+}
+
+TEST_P(UncheckedOptionalAccessTest, ClassDerivedPrivatelyFromOptional) {
+  // Classes that derive privately from optional can themselves still call
+  // member functions of optional. Check that we model the optional correctly
+  // in this situation.
+  ExpectDiagnosticsFor(R"(
+    #include "unchecked_optional_access_test.h"
+
+    struct Derived : private $ns::$optional<int> {
+      void Method() {
+        **this;  // [[unsafe]]
+        if (this->has_value())
+          **this;
+      }
+    };
+  )",
+                       ast_matchers::hasName("Method"));
+}
+
 // FIXME: Add support for:
 // - constructors (copy, move)
 // - assignment operators (default, copy, move)

>From 05e4114c61b7a6e69f1bd517aa11ed00007485ca Mon Sep 17 00:00:00 2001
From: Martin Braenne <mboehme at google.com>
Date: Mon, 18 Mar 2024 10:27:19 +0000
Subject: [PATCH 2/3] fixup! [clang][dataflow] Make optional checker work for
 types derived from optional.

---
 .../Models/UncheckedOptionalAccessModel.cpp   | 26 ++++++++++++-------
 1 file changed, 17 insertions(+), 9 deletions(-)

diff --git a/clang/lib/Analysis/FlowSensitive/Models/UncheckedOptionalAccessModel.cpp b/clang/lib/Analysis/FlowSensitive/Models/UncheckedOptionalAccessModel.cpp
index 90b4b7d04187df..0d36464338af8b 100644
--- a/clang/lib/Analysis/FlowSensitive/Models/UncheckedOptionalAccessModel.cpp
+++ b/clang/lib/Analysis/FlowSensitive/Models/UncheckedOptionalAccessModel.cpp
@@ -119,20 +119,28 @@ QualType getPublicType(const Expr *E) {
     return Ty;
   }
 
-  QualType Ty = getPublicType(Cast->getSubExpr());
-
-  // Is `Ty` the type of `*this`? In this special case, we can upcast to the
-  // base class even if the base is non-public.
-  bool TyIsThisType = isa<CXXThisExpr>(Cast->getSubExpr());
-
+  // Is the derived type that we're casting from the type of `*this`? In this
+  // special case, we can upcast to the base class even if the base is
+  // non-public.
+  bool CastingFromThis = isa<CXXThisExpr>(Cast->getSubExpr());
+
+  // Find the least-derived type in the path (i.e. the last entry in the list)
+  // that we can access.
+  QualType Ty;
   for (const CXXBaseSpecifier *Base : Cast->path()) {
-    if (Base->getAccessSpecifier() != AS_public && !TyIsThisType)
+    if (Base->getAccessSpecifier() != AS_public && !CastingFromThis)
       break;
     Ty = Base->getType();
-    TyIsThisType = false;
+    CastingFromThis = false;
   }
 
-  return Ty;
+  if (!Ty.isNull())
+    return Ty;
+
+  // We didn't find any public type that we could cast to. There may be more
+  // casts in `getSubExpr()`, so recurse. (If there aren't any more casts, this
+  // will return the type of `getSubExpr()`.)
+  return getPublicType(Cast->getSubExpr());
 }
 
 // Returns the least-derived type for the receiver of `MCE` that

>From 4e3d47b06d7f59723816fd4e624d3dacfbc558a3 Mon Sep 17 00:00:00 2001
From: Martin Braenne <mboehme at google.com>
Date: Tue, 19 Mar 2024 09:17:56 +0000
Subject: [PATCH 3/3] fixup! fixup! [clang][dataflow] Make optional checker
 work for types derived from optional.

---
 .../Models/UncheckedOptionalAccessModel.cpp   | 20 +++++++++----------
 1 file changed, 10 insertions(+), 10 deletions(-)

diff --git a/clang/lib/Analysis/FlowSensitive/Models/UncheckedOptionalAccessModel.cpp b/clang/lib/Analysis/FlowSensitive/Models/UncheckedOptionalAccessModel.cpp
index 0d36464338af8b..dbf4878622eba9 100644
--- a/clang/lib/Analysis/FlowSensitive/Models/UncheckedOptionalAccessModel.cpp
+++ b/clang/lib/Analysis/FlowSensitive/Models/UncheckedOptionalAccessModel.cpp
@@ -126,16 +126,16 @@ QualType getPublicType(const Expr *E) {
 
   // Find the least-derived type in the path (i.e. the last entry in the list)
   // that we can access.
-  QualType Ty;
+  const CXXBaseSpecifier *PublicBase = nullptr;
   for (const CXXBaseSpecifier *Base : Cast->path()) {
     if (Base->getAccessSpecifier() != AS_public && !CastingFromThis)
       break;
-    Ty = Base->getType();
+    PublicBase = Base;
     CastingFromThis = false;
   }
 
-  if (!Ty.isNull())
-    return Ty;
+  if (PublicBase != nullptr)
+    return PublicBase->getType();
 
   // We didn't find any public type that we could cast to. There may be more
   // casts in `getSubExpr()`, so recurse. (If there aren't any more casts, this
@@ -215,22 +215,22 @@ auto inPlaceClass() {
 
 auto isOptionalNulloptConstructor() {
   return cxxConstructExpr(
-      hasOptionalOrDerivedType(),
       hasDeclaration(cxxConstructorDecl(parameterCountIs(1),
-                                        hasParameter(0, hasNulloptType()))));
+                                        hasParameter(0, hasNulloptType()))),
+      hasOptionalOrDerivedType());
 }
 
 auto isOptionalInPlaceConstructor() {
-  return cxxConstructExpr(hasOptionalOrDerivedType(),
-                          hasArgument(0, hasType(inPlaceClass())));
+  return cxxConstructExpr(hasArgument(0, hasType(inPlaceClass())),
+                          hasOptionalOrDerivedType());
 }
 
 auto isOptionalValueOrConversionConstructor() {
   return cxxConstructExpr(
-      hasOptionalOrDerivedType(),
       unless(hasDeclaration(
           cxxConstructorDecl(anyOf(isCopyConstructor(), isMoveConstructor())))),
-      argumentCountIs(1), hasArgument(0, unless(hasNulloptType())));
+      argumentCountIs(1), hasArgument(0, unless(hasNulloptType())),
+      hasOptionalOrDerivedType());
 }
 
 auto isOptionalValueOrConversionAssignment() {



More information about the cfe-commits mailing list