[clang] [CodeGen] Revamp counted_by calculations (PR #70606)

Bill Wendling via cfe-commits cfe-commits at lists.llvm.org
Sun Oct 29 15:22:43 PDT 2023


https://github.com/bwendling updated https://github.com/llvm/llvm-project/pull/70606

>From 19dd7db8ab5f98a618c717944c96b34e604fbc30 Mon Sep 17 00:00:00 2001
From: Bill Wendling <morbo at google.com>
Date: Sun, 29 Oct 2023 14:58:04 -0700
Subject: [PATCH 1/2] [CodeGen] Revamp counted_by calculations

Break down the counted_by calculations so that they correctly handle
anonymous structs, whose members are represented internally as
IndirectFieldDecls. Also simplify the code by using helper methods to find
the field referenced by counted_by and the flexible array member itself;
the old code also had problems with FAMs in sub-structs.
---
 clang/lib/CodeGen/CGBuiltin.cpp     | 91 +++++++++++++++-------------
 clang/lib/CodeGen/CGExpr.cpp        | 93 +++++++++++++++++++++++------
 clang/lib/CodeGen/CodeGenFunction.h | 12 +++-
 3 files changed, 134 insertions(+), 62 deletions(-)

diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index dce5ee5888c458e..acee2c1af1ab368 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -859,53 +859,60 @@ CodeGenFunction::emitBuiltinObjectSize(const Expr *E, unsigned Type,
   }
 
   if (IsDynamic) {
-    LangOptions::StrictFlexArraysLevelKind StrictFlexArraysLevel =
-        getLangOpts().getStrictFlexArraysLevel();
-    const Expr *Base = E->IgnoreParenImpCasts();
-
-    if (FieldDecl *FD = FindCountedByField(Base, StrictFlexArraysLevel)) {
-      const auto *ME = dyn_cast<MemberExpr>(Base);
-      llvm::Value *ObjectSize = nullptr;
-
-      if (!ME) {
-        const auto *DRE = dyn_cast<DeclRefExpr>(Base);
-        ValueDecl *VD = nullptr;
-
-        ObjectSize = ConstantInt::get(
-            ResType,
-            getContext().getTypeSize(DRE->getType()->getPointeeType()) / 8,
-            true);
-
-        if (auto *RD = DRE->getType()->getPointeeType()->getAsRecordDecl())
-          VD = RD->getLastField();
-
-        Expr *ICE = ImplicitCastExpr::Create(
-            getContext(), DRE->getType(), CK_LValueToRValue,
-            const_cast<Expr *>(cast<Expr>(DRE)), nullptr, VK_PRValue,
-            FPOptionsOverride());
-        ME = MemberExpr::CreateImplicit(getContext(), ICE, true, VD,
-                                        VD->getType(), VK_LValue, OK_Ordinary);
-      }
-
-      // At this point, we know that \p ME is a flexible array member.
-      const auto *ArrayTy = getContext().getAsArrayType(ME->getType());
+    // The code generated here calculates the size of a struct with a flexible
+    // array member that uses the counted_by attribute. There are two instances
+    // we handle:
+    //
+    //       struct s {
+    //         unsigned long flags;
+    //         int count;
+    //         int array[] __attribute__((counted_by(count)));
+    //       }
+    //
+    //   1) bdos of the flexible array itself:
+    //
+    //     __builtin_dynamic_object_size(p->array, 1) ==
+    //         p->count * sizeof(*p->array)
+    //
+    //   2) bdos of the whole struct, including the flexible array:
+    //
+    //     __builtin_dynamic_object_size(p, 1) ==
+    //        sizeof(*p) + p->count * sizeof(*p->array)
+    //
+    if (const ValueDecl *CountedByFD = FindCountedByField(E)) {
+      // Find the flexible array member.
+      const RecordDecl *OuterRD =
+        CountedByFD->getDeclContext()->getOuterLexicalRecordContext();
+      const ValueDecl *FAM = FindFlexibleArrayMemberField(getContext(),
+                                                          OuterRD);
+
+      // Get the size of the flexible array member's base type.
+      const auto *ArrayTy = getContext().getAsArrayType(FAM->getType());
       unsigned Size = getContext().getTypeSize(ArrayTy->getElementType());
 
-      llvm::Value *CountField =
-          EmitAnyExprToTemp(MemberExpr::CreateImplicit(
-                                getContext(), const_cast<Expr *>(ME->getBase()),
-                                ME->isArrow(), FD, FD->getType(), VK_LValue,
-                                OK_Ordinary))
-              .getScalarVal();
+      // Find the outer struct expr (i.e. p in p->a.b.c.d).
+      Expr *CountedByExpr = BuildCountedByFieldExpr(const_cast<Expr *>(E),
+                                                    CountedByFD);
+
+      llvm::Value *CountedByInstr =
+        EmitAnyExprToTemp(CountedByExpr).getScalarVal();
 
-      llvm::Value *Mul = Builder.CreateMul(
-          CountField, llvm::ConstantInt::get(CountField->getType(), Size / 8));
-      Mul = Builder.CreateZExtOrTrunc(Mul, ResType);
+      llvm::Constant *ArraySize =
+        llvm::ConstantInt::get(CountedByInstr->getType(), Size / 8);
 
-      if (ObjectSize)
-        return Builder.CreateAdd(ObjectSize, Mul);
+      llvm::Value *ObjectSize = Builder.CreateMul(CountedByInstr, ArraySize);
+      ObjectSize = Builder.CreateZExtOrTrunc(ObjectSize, ResType);
+
+      if (const auto *DRE = dyn_cast<DeclRefExpr>(E->IgnoreImpCasts())) {
+        // The whole struct is specificed in the __bdos.
+        QualType StructTy = DRE->getType()->getPointeeType();
+        llvm::Value *StructSize = ConstantInt::get(
+            ResType, getContext().getTypeSize(StructTy) / 8, true);
+        ObjectSize = Builder.CreateAdd(StructSize, ObjectSize);
+      }
 
-      return Mul;
+      // PULL THE STRING!!
+      return ObjectSize;
     }
   }
 
diff --git a/clang/lib/CodeGen/CGExpr.cpp b/clang/lib/CodeGen/CGExpr.cpp
index 54a1d300a9ac738..2b39194e18ed861 100644
--- a/clang/lib/CodeGen/CGExpr.cpp
+++ b/clang/lib/CodeGen/CGExpr.cpp
@@ -944,14 +944,10 @@ static llvm::Value *getArrayIndexingBound(CodeGenFunction &CGF,
       // Ignore pass_object_size here. It's not applicable on decayed pointers.
     }
 
-    if (FieldDecl *FD = CGF.FindCountedByField(Base, StrictFlexArraysLevel)) {
-      const auto *ME = dyn_cast<MemberExpr>(CE->getSubExpr());
+    if (const ValueDecl *VD = CGF.FindCountedByField(Base)) {
       IndexedType = Base->getType();
-      return CGF
-          .EmitAnyExprToTemp(MemberExpr::CreateImplicit(
-              CGF.getContext(), const_cast<Expr *>(ME->getBase()),
-              ME->isArrow(), FD, FD->getType(), VK_LValue, OK_Ordinary))
-          .getScalarVal();
+      Expr *E = CGF.BuildCountedByFieldExpr(const_cast<Expr *>(Base), VD);
+      return CGF.EmitAnyExprToTemp(E).getScalarVal();
     }
   }
 
@@ -966,9 +962,68 @@ static llvm::Value *getArrayIndexingBound(CodeGenFunction &CGF,
   return nullptr;
 }
 
-FieldDecl *CodeGenFunction::FindCountedByField(
-    const Expr *Base,
-    LangOptions::StrictFlexArraysLevelKind StrictFlexArraysLevel) {
+Expr *CodeGenFunction::BuildCountedByFieldExpr(Expr *Base,
+                                               const ValueDecl *CountedByVD) {
+  // Rebuild the counted_by access on the outermost base: first find the outer struct expr (i.e. p in p->a.b.c.d).
+  Base = Base->IgnoreImpCasts();
+  Base = Base->IgnoreParenNoopCasts(getContext());
+
+  // Work our way up the expression until we reach the DeclRefExpr. NOTE(review): only a MemberExpr base makes progress; any other base (e.g. a call, subscript, or explicit cast) leaves Base unchanged and this loop never terminates.
+  while (!isa<DeclRefExpr>(Base))
+    if (auto *ME = dyn_cast<MemberExpr>(Base->IgnoreImpCasts())) {
+      Base = ME->getBase()->IgnoreImpCasts();
+      Base = Base->IgnoreParenNoopCasts(getContext());
+    }
+
+  // Add back an implicit cast to create the required pr-value.
+  Base = ImplicitCastExpr::Create(
+      getContext(), Base->getType(), CK_LValueToRValue, Base,
+      nullptr, VK_PRValue, FPOptionsOverride());
+
+  Expr *CountedByExpr = Base;
+
+  if (const auto *IFD = dyn_cast<IndirectFieldDecl>(CountedByVD)) {
+    // The counted_by field is inside an anonymous struct / union. The
+    // IndirectFieldDecl has the correct order of FieldDecls to build this
+    // easily. (Yay!)
+    for (NamedDecl *ND : IFD->chain()) {
+      ValueDecl *VD = cast<ValueDecl>(ND);
+      CountedByExpr = MemberExpr::CreateImplicit(
+          getContext(), CountedByExpr,
+          CountedByExpr->getType()->isPointerType(), VD, VD->getType(),
+          VK_LValue, OK_Ordinary);
+    }
+  } else {
+    CountedByExpr = MemberExpr::CreateImplicit(
+        getContext(), CountedByExpr,
+        CountedByExpr->getType()->isPointerType(),
+        const_cast<ValueDecl *>(CountedByVD), CountedByVD->getType(),
+        VK_LValue, OK_Ordinary);
+  }
+
+  return CountedByExpr;
+}
+
+const ValueDecl *CodeGenFunction::FindFlexibleArrayMemberField(
+    ASTContext &Ctx, const RecordDecl *RD) {
+  LangOptions::StrictFlexArraysLevelKind StrictFlexArraysLevel =
+      getLangOpts().getStrictFlexArraysLevel();
+
+  for (const Decl *D : RD->decls()) { // Depth-first over all decls; the first match in declaration order is returned.
+    if (const ValueDecl *VD = dyn_cast<ValueDecl>(D);
+        VD && Decl::isFlexibleArrayMemberLike(Ctx, VD, VD->getType(),
+                                              StrictFlexArraysLevel, true))
+      return VD;
+
+    if (const auto *Record = dyn_cast<RecordDecl>(D)) // NOTE(review): recurses into every nested record declaration, even one not used as a member (or only via a pointer), so its FAM could win over the outer one - confirm.
+      if (const ValueDecl *VD = FindFlexibleArrayMemberField(Ctx, Record))
+        return VD;
+  }
+
+  return nullptr;
+}
+
+const ValueDecl *CodeGenFunction::FindCountedByField(const Expr *Base) {
   const ValueDecl *VD = nullptr;
 
   Base = Base->IgnoreParenImpCasts();
@@ -984,12 +1039,14 @@ FieldDecl *CodeGenFunction::FindCountedByField(
       Ty = Ty->getPointeeType();
 
     if (const auto *RD = Ty->getAsRecordDecl())
-      VD = RD->getLastField();
+      VD = FindFlexibleArrayMemberField(getContext(), RD);
   } else if (const auto *CE = dyn_cast<CastExpr>(Base)) {
     if (const auto *ME = dyn_cast<MemberExpr>(CE->getSubExpr()))
       VD = dyn_cast<ValueDecl>(ME->getMemberDecl());
   }
 
+  LangOptions::StrictFlexArraysLevelKind StrictFlexArraysLevel =
+      getLangOpts().getStrictFlexArraysLevel();
   const auto *FD = dyn_cast_if_present<FieldDecl>(VD);
   if (!FD || !FD->getParent() ||
       !Decl::isFlexibleArrayMemberLike(getContext(), FD, FD->getType(),
@@ -1000,12 +1057,14 @@ FieldDecl *CodeGenFunction::FindCountedByField(
   if (!CBA)
     return nullptr;
 
-  StringRef FieldName = CBA->getCountedByField()->getName();
-  auto It =
-      llvm::find_if(FD->getParent()->fields(), [&](const FieldDecl *Field) {
-        return FieldName == Field->getName();
-      });
-  return It != FD->getParent()->field_end() ? *It : nullptr;
+  const RecordDecl *RD = FD->getDeclContext()->getOuterLexicalRecordContext();
+  DeclarationName DName(CBA->getCountedByField());
+  DeclContext::lookup_result Lookup = RD->lookup(DName);
+
+  if (Lookup.empty())
+    return nullptr;
+
+  return dyn_cast<ValueDecl>(Lookup.front());
 }
 
 void CodeGenFunction::EmitBoundsCheck(const Expr *E, const Expr *Base,
diff --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h
index e82115e2d706cf1..64f192037ec8ce5 100644
--- a/clang/lib/CodeGen/CodeGenFunction.h
+++ b/clang/lib/CodeGen/CodeGenFunction.h
@@ -3022,11 +3022,17 @@ class CodeGenFunction : public CodeGenTypeCache {
   void EmitBoundsCheck(const Expr *E, const Expr *Base, llvm::Value *Index,
                        QualType IndexType, bool Accessed);
 
+  // Find a struct's flexible array member. It may be embedded inside multiple
+  // sub-structs, but must still be the last field.
+  const ValueDecl *FindFlexibleArrayMemberField(ASTContext &Ctx,
+                                                const RecordDecl *RD);
+
   /// Find the FieldDecl specified in a FAM's "counted_by" attribute. Returns
   /// \p nullptr if either the attribute or the field doesn't exist.
-  FieldDecl *FindCountedByField(
-      const Expr *Base,
-      LangOptions::StrictFlexArraysLevelKind StrictFlexArraysLevel);
+  const ValueDecl *FindCountedByField(const Expr *Base);
+
+  /// Build an expression accessing the "counted_by" field.
+  Expr *BuildCountedByFieldExpr(Expr *Base, const ValueDecl *CountedByVD);
 
   llvm::Value *EmitScalarPrePostIncDec(const UnaryOperator *E, LValue LV,
                                        bool isInc, bool isPre);

>From 36b5271a7729c626a93a7fec9ff3bbd325436a02 Mon Sep 17 00:00:00 2001
From: Bill Wendling <morbo at google.com>
Date: Sun, 29 Oct 2023 15:22:29 -0700
Subject: [PATCH 2/2] Reformat with clang-format

---
 clang/lib/CodeGen/CGBuiltin.cpp | 14 +++++++-------
 clang/lib/CodeGen/CGExpr.cpp    | 26 +++++++++++++-------------
 2 files changed, 20 insertions(+), 20 deletions(-)

diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index acee2c1af1ab368..26c73d07c7038e5 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -882,23 +882,23 @@ CodeGenFunction::emitBuiltinObjectSize(const Expr *E, unsigned Type,
     if (const ValueDecl *CountedByFD = FindCountedByField(E)) {
       // Find the flexible array member.
       const RecordDecl *OuterRD =
-        CountedByFD->getDeclContext()->getOuterLexicalRecordContext();
-      const ValueDecl *FAM = FindFlexibleArrayMemberField(getContext(),
-                                                          OuterRD);
+          CountedByFD->getDeclContext()->getOuterLexicalRecordContext();
+      const ValueDecl *FAM =
+          FindFlexibleArrayMemberField(getContext(), OuterRD);
 
       // Get the size of the flexible array member's base type.
       const auto *ArrayTy = getContext().getAsArrayType(FAM->getType());
       unsigned Size = getContext().getTypeSize(ArrayTy->getElementType());
 
       // Find the outer struct expr (i.e. p in p->a.b.c.d).
-      Expr *CountedByExpr = BuildCountedByFieldExpr(const_cast<Expr *>(E),
-                                                    CountedByFD);
+      Expr *CountedByExpr =
+          BuildCountedByFieldExpr(const_cast<Expr *>(E), CountedByFD);
 
       llvm::Value *CountedByInstr =
-        EmitAnyExprToTemp(CountedByExpr).getScalarVal();
+          EmitAnyExprToTemp(CountedByExpr).getScalarVal();
 
       llvm::Constant *ArraySize =
-        llvm::ConstantInt::get(CountedByInstr->getType(), Size / 8);
+          llvm::ConstantInt::get(CountedByInstr->getType(), Size / 8);
 
       llvm::Value *ObjectSize = Builder.CreateMul(CountedByInstr, ArraySize);
       ObjectSize = Builder.CreateZExtOrTrunc(ObjectSize, ResType);
diff --git a/clang/lib/CodeGen/CGExpr.cpp b/clang/lib/CodeGen/CGExpr.cpp
index 2b39194e18ed861..b39cd8b45fe8f6a 100644
--- a/clang/lib/CodeGen/CGExpr.cpp
+++ b/clang/lib/CodeGen/CGExpr.cpp
@@ -976,9 +976,9 @@ Expr *CodeGenFunction::BuildCountedByFieldExpr(Expr *Base,
     }
 
   // Add back an implicit cast to create the required pr-value.
-  Base = ImplicitCastExpr::Create(
-      getContext(), Base->getType(), CK_LValueToRValue, Base,
-      nullptr, VK_PRValue, FPOptionsOverride());
+  Base =
+      ImplicitCastExpr::Create(getContext(), Base->getType(), CK_LValueToRValue,
+                               Base, nullptr, VK_PRValue, FPOptionsOverride());
 
   Expr *CountedByExpr = Base;
 
@@ -988,24 +988,24 @@ Expr *CodeGenFunction::BuildCountedByFieldExpr(Expr *Base,
     // easily. (Yay!)
     for (NamedDecl *ND : IFD->chain()) {
       ValueDecl *VD = cast<ValueDecl>(ND);
-      CountedByExpr = MemberExpr::CreateImplicit(
-          getContext(), CountedByExpr,
-          CountedByExpr->getType()->isPointerType(), VD, VD->getType(),
-          VK_LValue, OK_Ordinary);
+      CountedByExpr =
+          MemberExpr::CreateImplicit(getContext(), CountedByExpr,
+                                     CountedByExpr->getType()->isPointerType(),
+                                     VD, VD->getType(), VK_LValue, OK_Ordinary);
     }
   } else {
     CountedByExpr = MemberExpr::CreateImplicit(
-        getContext(), CountedByExpr,
-        CountedByExpr->getType()->isPointerType(),
-        const_cast<ValueDecl *>(CountedByVD), CountedByVD->getType(),
-        VK_LValue, OK_Ordinary);
+        getContext(), CountedByExpr, CountedByExpr->getType()->isPointerType(),
+        const_cast<ValueDecl *>(CountedByVD), CountedByVD->getType(), VK_LValue,
+        OK_Ordinary);
   }
 
   return CountedByExpr;
 }
 
-const ValueDecl *CodeGenFunction::FindFlexibleArrayMemberField(
-    ASTContext &Ctx, const RecordDecl *RD) {
+const ValueDecl *
+CodeGenFunction::FindFlexibleArrayMemberField(ASTContext &Ctx,
+                                              const RecordDecl *RD) {
   LangOptions::StrictFlexArraysLevelKind StrictFlexArraysLevel =
       getLangOpts().getStrictFlexArraysLevel();
 



More information about the cfe-commits mailing list