[clang] [clang][dataflow] Make optional checker work for types derived from optional. (PR #84138)

via cfe-commits cfe-commits at lists.llvm.org
Mon Mar 18 03:53:16 PDT 2024


================
@@ -64,39 +64,117 @@ static bool hasOptionalClassName(const CXXRecordDecl &RD) {
   return false;
 }
 
+static const CXXRecordDecl *getOptionalBaseClass(const CXXRecordDecl *RD) {
+  if (RD == nullptr)
+    return nullptr;
+  if (hasOptionalClassName(*RD))
+    return RD;
+
+  if (!RD->hasDefinition())
+    return nullptr;
+
+  for (const CXXBaseSpecifier &Base : RD->bases())
+    if (const CXXRecordDecl *BaseClass =
+            getOptionalBaseClass(Base.getType()->getAsCXXRecordDecl()))
+      return BaseClass;
+
+  return nullptr;
+}
+
 namespace {
 
 using namespace ::clang::ast_matchers;
 using LatticeTransferState = TransferState<NoopLattice>;
 
-AST_MATCHER(CXXRecordDecl, hasOptionalClassNameMatcher) {
-  return hasOptionalClassName(Node);
+AST_MATCHER(CXXRecordDecl, optionalClass) { return hasOptionalClassName(Node); }
+
+AST_MATCHER(CXXRecordDecl, optionalOrDerivedClass) {
+  return getOptionalBaseClass(&Node) != nullptr;
 }
 
-DeclarationMatcher optionalClass() {
-  return classTemplateSpecializationDecl(
-      hasOptionalClassNameMatcher(),
-      hasTemplateArgument(0, refersToType(type().bind("T"))));
+auto desugarsToOptionalType() {
+  return hasUnqualifiedDesugaredType(
+      recordType(hasDeclaration(cxxRecordDecl(optionalClass()))));
 }
 
-auto optionalOrAliasType() {
+auto desugarsToOptionalOrDerivedType() {
   return hasUnqualifiedDesugaredType(
-      recordType(hasDeclaration(optionalClass())));
+      recordType(hasDeclaration(cxxRecordDecl(optionalOrDerivedClass()))));
 }
 
-/// Matches any of the spellings of the optional types and sugar, aliases, etc.
-auto hasOptionalType() { return hasType(optionalOrAliasType()); }
+auto hasOptionalType() { return hasType(desugarsToOptionalType()); }
+
+/// Matches any of the spellings of the optional types and sugar, aliases,
+/// derived classes, etc.
+auto hasOptionalOrDerivedType() {
+  return hasType(desugarsToOptionalOrDerivedType());
+}
+
+QualType getPublicType(const Expr *E) {
+  auto *Cast = dyn_cast<ImplicitCastExpr>(E->IgnoreParens());
+  if (Cast == nullptr || Cast->getCastKind() != CK_UncheckedDerivedToBase) {
+    QualType Ty = E->getType();
+    if (Ty->isPointerType())
+      return Ty->getPointeeType();
+    return Ty;
+  }
+
+  QualType Ty = getPublicType(Cast->getSubExpr());
----------------
martinboehme wrote:

I hope this has become clearer given the additional comments on what the loop is doing.

The idea behind doing the recursion first, then potentially throwing the result away, is that it makes the code proceed strictly along the path from most-derived to most-base class.

Conceptually, we want to "drill down" into the `getImplicitObjectArgument()` to find the most-derived type, then walk back along the "cast path" towards the base until we find a point where we're no longer allowed to perform the cast. The code that was here expresses this most directly (the recursive call does the "drilling down"), but as you note, it's inefficient in the case where we need to throw away the result of the recursive call.

I've changed this to use tail recursion and have added comments that hopefully make it clear what the recursion is doing.
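
To make the "walk along the cast path until we're no longer allowed to cast" idea concrete, here is a minimal standalone sketch. It is a toy model, not the code in this PR: `ToyBaseStep` and `toyPublicType` are hypothetical names, the cast path is reduced to a flat list of (base name, is-public) steps, and nothing Clang-specific (nested casts, pointer types, the AST) is modeled.

```cpp
#include <string>
#include <vector>

// Hypothetical stand-in for one step of a derived-to-base cast path.
struct ToyBaseStep {
  std::string baseName; // class reached by taking this step
  bool isPublic;        // whether this inheritance edge is public
};

// Starts at the most-derived type and walks toward the base, stopping at the
// first step that is not public. Returns the last type that was reached.
std::string toyPublicType(const std::string &mostDerived,
                          const std::vector<ToyBaseStep> &path) {
  std::string reached = mostDerived;
  for (const ToyBaseStep &step : path) {
    if (!step.isPublic)
      break; // We may not cast past a non-public base.
    reached = step.baseName;
  }
  return reached;
}
```

In this model, a path `A -> B (public) -> C (private) -> D (public)` stops at `B`: the public edge to `D` is never considered because the walk cannot get past `C`.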

https://github.com/llvm/llvm-project/pull/84138

