[clang] [Clang] Add __builtin_get_counted_by builtin (PR #102549)

Bill Wendling via cfe-commits cfe-commits at lists.llvm.org
Wed Aug 14 14:04:12 PDT 2024


https://github.com/bwendling updated https://github.com/llvm/llvm-project/pull/102549

>From 7ba43ae2b737fbd868848a23b58b3965f8d36ce1 Mon Sep 17 00:00:00 2001
From: Bill Wendling <morbo at google.com>
Date: Tue, 6 Aug 2024 17:49:01 -0700
Subject: [PATCH 01/11] [WIP][Clang] Add __builtin_get_counted_by builtin

The __builtin_get_counted_by builtin takes a pointer to a flexible array
member and returns a pointer to the field named by that member's
"counted_by" attribute (its COUNT argument), which is a field in the
same non-anonymous struct as the flexible array member. This is useful
for setting the count field automatically, without needing the
programmer's intervention. Otherwise it's easy to write this
anti-pattern, where the array is accessed before its count is set:

  ptr = alloc(<ty>, COUNT);
  ptr->FAM[9] = 37; /* <<< Sanitizer will complain */
  ptr->count = COUNT;
---
 clang/include/clang/Basic/Builtins.td |  6 ++++
 clang/lib/CodeGen/CGBuiltin.cpp       | 22 ++++++++++++
 clang/lib/CodeGen/CGExpr.cpp          | 29 +++++++++-------
 clang/lib/CodeGen/CodeGenFunction.h   |  4 +++
 clang/lib/Sema/SemaExpr.cpp           | 49 +++++++++++++++++++++++++--
 5 files changed, 96 insertions(+), 14 deletions(-)

diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index b025a7681bfac3..254cd157d5f9d0 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -4774,3 +4774,9 @@ def ArithmeticFence : LangBuiltin<"ALL_LANGUAGES"> {
   let Attributes = [CustomTypeChecking, Constexpr];
   let Prototype = "void(...)";
 }
+
+def GetCountedBy : Builtin {
+  let Spellings = ["__builtin_get_counted_by"];
+  let Attributes = [NoThrow];
+  let Prototype = "size_t*(void*)";
+}
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 51d1162c6e403c..859249be72c07e 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -3563,6 +3563,28 @@ RValue CodeGenFunction::EmitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
     return RValue::get(emitBuiltinObjectSize(E->getArg(0), Type, ResType,
                                              /*EmittedE=*/nullptr, IsDynamic));
   }
+  case Builtin::BI__builtin_get_counted_by: {
+    llvm::Value *Result = nullptr;
+
+    if (const MemberExpr *ME =
+            dyn_cast<MemberExpr>(E->getArg(0)->IgnoreImpCasts())) {
+      bool IsFlexibleArrayMember = ME->isFlexibleArrayMemberLike(
+              getContext(), getLangOpts().getStrictFlexArraysLevel(),
+              /*IgnoreTemplateOrMacroSubstitution=*/false);
+
+      // TODO: Probably have to handle horrible casting crap here.
+
+      // FIXME: Emit a diagnostic?
+      if (!ME->HasSideEffects(getContext()) && IsFlexibleArrayMember &&
+          ME->getMemberDecl()->getType()->isCountAttributedType()) {
+        const FieldDecl *FAMDecl = dyn_cast<FieldDecl>(ME->getMemberDecl());
+        if (const FieldDecl *CountFD = FindCountedByField(FAMDecl))
+          Result = GetCountedByFieldExprGEP(ME, FAMDecl, CountFD);
+      }
+    }
+
+    return RValue::get(Result);
+  }
   case Builtin::BI__builtin_prefetch: {
     Value *Locality, *RW, *Address = EmitScalarExpr(E->getArg(0));
     // FIXME: Technically these constants should of type 'int', yes?
diff --git a/clang/lib/CodeGen/CGExpr.cpp b/clang/lib/CodeGen/CGExpr.cpp
index a1dce741c78a11..55cd95c08e3ffc 100644
--- a/clang/lib/CodeGen/CGExpr.cpp
+++ b/clang/lib/CodeGen/CGExpr.cpp
@@ -1100,15 +1100,7 @@ static bool getGEPIndicesToField(CodeGenFunction &CGF, const RecordDecl *RD,
   return false;
 }
 
-/// This method is typically called in contexts where we can't generate
-/// side-effects, like in __builtin_dynamic_object_size. When finding
-/// expressions, only choose those that have either already been emitted or can
-/// be loaded without side-effects.
-///
-/// - \p FAMDecl: the \p Decl for the flexible array member. It may not be
-///   within the top-level struct.
-/// - \p CountDecl: must be within the same non-anonymous struct as \p FAMDecl.
-llvm::Value *CodeGenFunction::EmitLoadOfCountedByField(
+llvm::Value *CodeGenFunction::GetCountedByFieldExprGEP(
     const Expr *Base, const FieldDecl *FAMDecl, const FieldDecl *CountDecl) {
   const RecordDecl *RD = CountDecl->getParent()->getOuterLexicalRecordContext();
 
@@ -1141,12 +1133,25 @@ llvm::Value *CodeGenFunction::EmitLoadOfCountedByField(
     return nullptr;
 
   Indices.push_back(Builder.getInt32(0));
-  Res = Builder.CreateInBoundsGEP(
+  return Builder.CreateInBoundsGEP(
       ConvertType(QualType(RD->getTypeForDecl(), 0)), Res,
       RecIndicesTy(llvm::reverse(Indices)), "..counted_by.gep");
+}
 
-  return Builder.CreateAlignedLoad(ConvertType(CountDecl->getType()), Res,
-                                   getIntAlign(), "..counted_by.load");
+/// This method is typically called in contexts where we can't generate
+/// side-effects, like in __builtin_dynamic_object_size. When finding
+/// expressions, only choose those that have either already been emitted or can
+/// be loaded without side-effects.
+///
+/// - \p FAMDecl: the \p Decl for the flexible array member. It may not be
+///   within the top-level struct.
+/// - \p CountDecl: must be within the same non-anonymous struct as \p FAMDecl.
+llvm::Value *CodeGenFunction::EmitLoadOfCountedByField(
+    const Expr *Base, const FieldDecl *FAMDecl, const FieldDecl *CountDecl) {
+  if (llvm::Value *GEP = GetCountedByFieldExprGEP(Base, FAMDecl, CountDecl))
+    return Builder.CreateAlignedLoad(ConvertType(CountDecl->getType()), GEP,
+                                     getIntAlign(), "..counted_by.load");
+  return nullptr;
 }
 
 const FieldDecl *CodeGenFunction::FindCountedByField(const FieldDecl *FD) {
diff --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h
index 1c0a0e117e5607..e5f5b94bba54b0 100644
--- a/clang/lib/CodeGen/CodeGenFunction.h
+++ b/clang/lib/CodeGen/CodeGenFunction.h
@@ -3309,6 +3309,10 @@ class CodeGenFunction : public CodeGenTypeCache {
   /// \p nullptr if either the attribute or the field doesn't exist.
   const FieldDecl *FindCountedByField(const FieldDecl *FD);
 
+  llvm::Value *GetCountedByFieldExprGEP(const Expr *Base,
+                                        const FieldDecl *FAMDecl,
+                                        const FieldDecl *CountDecl);
+
   /// Build an expression accessing the "counted_by" field.
   llvm::Value *EmitLoadOfCountedByField(const Expr *Base,
                                         const FieldDecl *FAMDecl,
diff --git a/clang/lib/Sema/SemaExpr.cpp b/clang/lib/Sema/SemaExpr.cpp
index 21c8ae6bad0eae..048d0c4ae49726 100644
--- a/clang/lib/Sema/SemaExpr.cpp
+++ b/clang/lib/Sema/SemaExpr.cpp
@@ -6390,9 +6390,26 @@ ExprResult Sema::ActOnCallExpr(Scope *Scope, Expr *Fn, SourceLocation LParenLoc,
       currentEvaluationContext().ReferenceToConsteval.erase(DRE);
     }
   }
+
   return Call;
 }
 
+const FieldDecl *FindCountedByField(const FieldDecl *FD) {
+  if (!FD)
+    return nullptr;
+
+  const auto *CAT = FD->getType()->getAs<CountAttributedType>();
+  if (!CAT)
+    return nullptr;
+
+  const auto *CountDRE = cast<DeclRefExpr>(CAT->getCountExpr());
+  const auto *CountDecl = CountDRE->getDecl();
+  if (const auto *IFD = dyn_cast<IndirectFieldDecl>(CountDecl))
+    CountDecl = IFD->getAnonField();
+
+  return dyn_cast<FieldDecl>(CountDecl);
+}
+
 ExprResult Sema::BuildCallExpr(Scope *Scope, Expr *Fn, SourceLocation LParenLoc,
                                MultiExprArg ArgExprs, SourceLocation RParenLoc,
                                Expr *ExecConfig, bool IsExecConfig,
@@ -6590,8 +6607,36 @@ ExprResult Sema::BuildCallExpr(Scope *Scope, Expr *Fn, SourceLocation LParenLoc,
     return CallExpr::Create(Context, Fn, ArgExprs, Context.DependentTy,
                             VK_PRValue, RParenLoc, CurFPFeatureOverrides());
   }
-  return BuildResolvedCallExpr(Fn, NDecl, LParenLoc, ArgExprs, RParenLoc,
-                               ExecConfig, IsExecConfig);
+
+  Result = BuildResolvedCallExpr(Fn, NDecl, LParenLoc, ArgExprs, RParenLoc,
+                                 ExecConfig, IsExecConfig);
+
+  if (FunctionDecl *FDecl = dyn_cast_or_null<FunctionDecl>(NDecl);
+      FDecl && FDecl->getBuiltinID() == Builtin::BI__builtin_get_counted_by) {
+    if (const MemberExpr *ME =
+            dyn_cast<MemberExpr>(ArgExprs[0]->IgnoreImpCasts())) {
+      bool IsFlexibleArrayMember = ME->isFlexibleArrayMemberLike(
+              Context, getLangOpts().getStrictFlexArraysLevel(),
+              /*IgnoreTemplateOrMacroSubstitution=*/false);
+
+      // TODO: Probably have to handle horrible casting crap here.
+
+      // FIXME: Emit a diagnostic?
+      if (!ME->HasSideEffects(Context) && IsFlexibleArrayMember &&
+          ME->getMemberDecl()->getType()->isCountAttributedType()) {
+        const FieldDecl *FAMDecl = dyn_cast<FieldDecl>(ME->getMemberDecl());
+        if (const FieldDecl *CountFD = FindCountedByField(FAMDecl)) {
+          QualType PtrTy = Context.getPointerType(CountFD->getType());
+          Result = CStyleCastExpr::Create(
+              Context, PtrTy, VK_LValue, CK_BitCast, Result.get(), nullptr,
+              FPOptionsOverride(), Context.CreateTypeSourceInfo(PtrTy),
+              LParenLoc, RParenLoc);
+        }
+      }
+    }
+  }
+
+  return Result;
 }
 
 Expr *Sema::BuildBuiltinCallExpr(SourceLocation Loc, Builtin::ID Id,

>From 3c42b388f42c2570e6ce7b168601f6f69a959ea5 Mon Sep 17 00:00:00 2001
From: Bill Wendling <morbo at google.com>
Date: Wed, 7 Aug 2024 10:32:28 -0700
Subject: [PATCH 02/11] If the 'counted_by' attribute doesn't exist, then
 return nullptr.

---
 clang/lib/CodeGen/CGBuiltin.cpp | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 859249be72c07e..bf8fe007f979f3 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -3580,6 +3580,10 @@ RValue CodeGenFunction::EmitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
         const FieldDecl *FAMDecl = dyn_cast<FieldDecl>(ME->getMemberDecl());
         if (const FieldDecl *CountFD = FindCountedByField(FAMDecl))
           Result = GetCountedByFieldExprGEP(ME, FAMDecl, CountFD);
+      } else {
+        E->getType()->dump();
+        ConvertType(E->getType())->dump();
+        Result = llvm::ConstantPointerNull::get(cast<llvm::PointerType>(ConvertType(E->getType())));
       }
     }
 

>From 23efe69d0b279b9e14467d75ea48e1d0b1831a23 Mon Sep 17 00:00:00 2001
From: Bill Wendling <morbo at google.com>
Date: Wed, 7 Aug 2024 10:36:01 -0700
Subject: [PATCH 03/11] Remove debugging code.

---
 clang/lib/CodeGen/CGBuiltin.cpp | 2 --
 1 file changed, 2 deletions(-)

diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index bf8fe007f979f3..07d46b2f74de38 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -3581,8 +3581,6 @@ RValue CodeGenFunction::EmitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
         if (const FieldDecl *CountFD = FindCountedByField(FAMDecl))
           Result = GetCountedByFieldExprGEP(ME, FAMDecl, CountFD);
       } else {
-        E->getType()->dump();
-        ConvertType(E->getType())->dump();
         Result = llvm::ConstantPointerNull::get(cast<llvm::PointerType>(ConvertType(E->getType())));
       }
     }

>From 1c693069bda9576c7159c188003f321a6d5ed216 Mon Sep 17 00:00:00 2001
From: Bill Wendling <morbo at google.com>
Date: Wed, 7 Aug 2024 10:38:42 -0700
Subject: [PATCH 04/11] Make returning a nullptr the default.

---
 clang/lib/CodeGen/CGBuiltin.cpp | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 07d46b2f74de38..b92fe7a6469f42 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -3564,7 +3564,8 @@ RValue CodeGenFunction::EmitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
                                              /*EmittedE=*/nullptr, IsDynamic));
   }
   case Builtin::BI__builtin_get_counted_by: {
-    llvm::Value *Result = nullptr;
+    llvm::Value *Result =
+        llvm::ConstantPointerNull::get(cast<llvm::PointerType>(ConvertType(E->getType())));
 
     if (const MemberExpr *ME =
             dyn_cast<MemberExpr>(E->getArg(0)->IgnoreImpCasts())) {
@@ -3580,8 +3581,6 @@ RValue CodeGenFunction::EmitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
         const FieldDecl *FAMDecl = dyn_cast<FieldDecl>(ME->getMemberDecl());
         if (const FieldDecl *CountFD = FindCountedByField(FAMDecl))
           Result = GetCountedByFieldExprGEP(ME, FAMDecl, CountFD);
-      } else {
-        Result = llvm::ConstantPointerNull::get(cast<llvm::PointerType>(ConvertType(E->getType())));
       }
     }
 

>From c8c40044fbded04020b3b2759b8ff007ae54d043 Mon Sep 17 00:00:00 2001
From: Bill Wendling <morbo at google.com>
Date: Wed, 7 Aug 2024 10:39:05 -0700
Subject: [PATCH 05/11] clang-format

---
 clang/lib/CodeGen/CGBuiltin.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index b92fe7a6469f42..a703d1173b9be2 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -3564,8 +3564,8 @@ RValue CodeGenFunction::EmitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
                                              /*EmittedE=*/nullptr, IsDynamic));
   }
   case Builtin::BI__builtin_get_counted_by: {
-    llvm::Value *Result =
-        llvm::ConstantPointerNull::get(cast<llvm::PointerType>(ConvertType(E->getType())));
+    llvm::Value *Result = llvm::ConstantPointerNull::get(
+        cast<llvm::PointerType>(ConvertType(E->getType())));
 
     if (const MemberExpr *ME =
             dyn_cast<MemberExpr>(E->getArg(0)->IgnoreImpCasts())) {

>From 4532b8e295c426b687a622b60b7177451300f2b8 Mon Sep 17 00:00:00 2001
From: Bill Wendling <morbo at google.com>
Date: Wed, 7 Aug 2024 13:54:22 -0700
Subject: [PATCH 06/11] Use a visitor to retrieve the MemberExpr. It's
 intentionally harsh and ignores pretty much everything that may 'hide' the
 underlying MemberExpr.

---
 clang/lib/CodeGen/CGBuiltin.cpp | 47 +++++++++++++++++++++++++++----
 clang/lib/Sema/SemaExpr.cpp     | 50 +++++++++++++++++++++++++++++----
 2 files changed, 87 insertions(+), 10 deletions(-)

diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index a703d1173b9be2..2cba1971280b7f 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -27,6 +27,7 @@
 #include "clang/AST/Decl.h"
 #include "clang/AST/OSLog.h"
 #include "clang/AST/OperationKinds.h"
+#include "clang/AST/StmtVisitor.h"
 #include "clang/Basic/TargetBuiltins.h"
 #include "clang/Basic/TargetInfo.h"
 #include "clang/Basic/TargetOptions.h"
@@ -2536,6 +2537,45 @@ static RValue EmitHipStdParUnsupportedBuiltin(CodeGenFunction *CGF,
   return RValue::get(CGF->Builder.CreateCall(UBF, Args));
 }
 
+namespace {
+
+/// MemberExprVisitor - Find the MemberExpr through all of the casts, array
+/// subscripts, and unary ops. This intentionally avoids all of them because
+/// we're interested only in the MemberExpr to check if it's a flexible array
+/// member.
+class MemberExprVisitor
+    : public ConstStmtVisitor<MemberExprVisitor, const Expr *> {
+public:
+  //===--------------------------------------------------------------------===//
+  //                            Visitor Methods
+  //===--------------------------------------------------------------------===//
+
+  const Expr *Visit(const Expr *E) {
+    return ConstStmtVisitor<MemberExprVisitor, const Expr *>::Visit(E);
+  }
+  const Expr *VisitStmt(const Stmt *S) { return nullptr; }
+
+  const Expr *VisitMemberExpr(const MemberExpr *E) { return E; }
+
+  const Expr *VisitArraySubscriptExpr(const ArraySubscriptExpr *E) {
+    return Visit(E->getBase());
+  }
+  const Expr *VisitCastExpr(const CastExpr *E) {
+    return Visit(E->getSubExpr());
+  }
+  const Expr *VisitParenExpr(const ParenExpr *E) {
+    return Visit(E->getSubExpr());
+  }
+  const Expr *VisitUnaryAddrOf(const clang::UnaryOperator *E) {
+    return Visit(E->getSubExpr());
+  }
+  const Expr *VisitUnaryDeref(const clang::UnaryOperator *E) {
+    return Visit(E->getSubExpr());
+  }
+};
+
+} // anonymous namespace
+
 RValue CodeGenFunction::EmitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
                                         const CallExpr *E,
                                         ReturnValueSlot ReturnValue) {
@@ -3567,15 +3607,12 @@ RValue CodeGenFunction::EmitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
     llvm::Value *Result = llvm::ConstantPointerNull::get(
         cast<llvm::PointerType>(ConvertType(E->getType())));
 
-    if (const MemberExpr *ME =
-            dyn_cast<MemberExpr>(E->getArg(0)->IgnoreImpCasts())) {
+    if (const Expr *Ptr = MemberExprVisitor().Visit(E->getArg(0))) {
+      const MemberExpr *ME = cast<MemberExpr>(Ptr);
       bool IsFlexibleArrayMember = ME->isFlexibleArrayMemberLike(
               getContext(), getLangOpts().getStrictFlexArraysLevel(),
               /*IgnoreTemplateOrMacroSubstitution=*/false);
 
-      // TODO: Probably have to handle horrible casting crap here.
-
-      // FIXME: Emit a diagnostic?
       if (!ME->HasSideEffects(getContext()) && IsFlexibleArrayMember &&
           ME->getMemberDecl()->getType()->isCountAttributedType()) {
         const FieldDecl *FAMDecl = dyn_cast<FieldDecl>(ME->getMemberDecl());
diff --git a/clang/lib/Sema/SemaExpr.cpp b/clang/lib/Sema/SemaExpr.cpp
index 048d0c4ae49726..8091af764f9b19 100644
--- a/clang/lib/Sema/SemaExpr.cpp
+++ b/clang/lib/Sema/SemaExpr.cpp
@@ -28,6 +28,7 @@
 #include "clang/AST/OperationKinds.h"
 #include "clang/AST/ParentMapContext.h"
 #include "clang/AST/RecursiveASTVisitor.h"
+#include "clang/AST/StmtVisitor.h"
 #include "clang/AST/Type.h"
 #include "clang/AST/TypeLoc.h"
 #include "clang/Basic/Builtins.h"
@@ -6410,6 +6411,45 @@ const FieldDecl *FindCountedByField(const FieldDecl *FD) {
   return dyn_cast<FieldDecl>(CountDecl);
 }
 
+namespace {
+
+/// MemberExprVisitor - Find the MemberExpr through all of the casts, array
+/// subscripts, and unary ops. This intentionally avoids all of them because
+/// we're interested only in the MemberExpr to check if it's a flexible array
+/// member.
+class MemberExprVisitor
+    : public ConstStmtVisitor<MemberExprVisitor, const Expr *> {
+public:
+  //===--------------------------------------------------------------------===//
+  //                            Visitor Methods
+  //===--------------------------------------------------------------------===//
+
+  const Expr *Visit(const Expr *E) {
+    return ConstStmtVisitor<MemberExprVisitor, const Expr *>::Visit(E);
+  }
+  const Expr *VisitStmt(const Stmt *S) { return nullptr; }
+
+  const Expr *VisitMemberExpr(const MemberExpr *E) { return E; }
+
+  const Expr *VisitArraySubscriptExpr(const ArraySubscriptExpr *E) {
+    return Visit(E->getBase());
+  }
+  const Expr *VisitCastExpr(const CastExpr *E) {
+    return Visit(E->getSubExpr());
+  }
+  const Expr *VisitParenExpr(const ParenExpr *E) {
+    return Visit(E->getSubExpr());
+  }
+  const Expr *VisitUnaryAddrOf(const UnaryOperator *E) {
+    return Visit(E->getSubExpr());
+  }
+  const Expr *VisitUnaryDeref(const UnaryOperator *E) {
+    return Visit(E->getSubExpr());
+  }
+};
+
+} // anonymous namespace
+
 ExprResult Sema::BuildCallExpr(Scope *Scope, Expr *Fn, SourceLocation LParenLoc,
                                MultiExprArg ArgExprs, SourceLocation RParenLoc,
                                Expr *ExecConfig, bool IsExecConfig,
@@ -6613,19 +6653,19 @@ ExprResult Sema::BuildCallExpr(Scope *Scope, Expr *Fn, SourceLocation LParenLoc,
 
   if (FunctionDecl *FDecl = dyn_cast_or_null<FunctionDecl>(NDecl);
       FDecl && FDecl->getBuiltinID() == Builtin::BI__builtin_get_counted_by) {
-    if (const MemberExpr *ME =
-            dyn_cast<MemberExpr>(ArgExprs[0]->IgnoreImpCasts())) {
+    if (const Expr *Ptr = MemberExprVisitor().Visit(ArgExprs[0])) {
+      const MemberExpr *ME = cast<MemberExpr>(Ptr);
       bool IsFlexibleArrayMember = ME->isFlexibleArrayMemberLike(
               Context, getLangOpts().getStrictFlexArraysLevel(),
               /*IgnoreTemplateOrMacroSubstitution=*/false);
 
-      // TODO: Probably have to handle horrible casting crap here.
-
-      // FIXME: Emit a diagnostic?
       if (!ME->HasSideEffects(Context) && IsFlexibleArrayMember &&
           ME->getMemberDecl()->getType()->isCountAttributedType()) {
         const FieldDecl *FAMDecl = dyn_cast<FieldDecl>(ME->getMemberDecl());
         if (const FieldDecl *CountFD = FindCountedByField(FAMDecl)) {
+          // The builtin returns a 'size_t *', however 'size_t' might not be
+          // the type of the count field. Thus we create an explicit c-style
+          // cast to ensure the proper types going forward.
           QualType PtrTy = Context.getPointerType(CountFD->getType());
           Result = CStyleCastExpr::Create(
               Context, PtrTy, VK_LValue, CK_BitCast, Result.get(), nullptr,

>From 5a74079039c39202711e0cc20c3412a05367d129 Mon Sep 17 00:00:00 2001
From: Bill Wendling <morbo at google.com>
Date: Wed, 7 Aug 2024 15:11:45 -0700
Subject: [PATCH 07/11] Add testcases.

---
 clang/test/CodeGen/builtin-get-counted-by.c | 83 +++++++++++++++++++++
 clang/test/Sema/builtin-get-counted-by.c    | 22 ++++++
 2 files changed, 105 insertions(+)
 create mode 100644 clang/test/CodeGen/builtin-get-counted-by.c
 create mode 100644 clang/test/Sema/builtin-get-counted-by.c

diff --git a/clang/test/CodeGen/builtin-get-counted-by.c b/clang/test/CodeGen/builtin-get-counted-by.c
new file mode 100644
index 00000000000000..8209db6a77111e
--- /dev/null
+++ b/clang/test/CodeGen/builtin-get-counted-by.c
@@ -0,0 +1,83 @@
+// NOTE: Assertions have been autogenerated by utils/update_cc_test_checks.py UTC_ARGS: --version 5
+// RUN: %clang_cc1 -triple x86_64-unknown-unknown -O2 -emit-llvm -o - %s | FileCheck %s --check-prefix=X86_64
+// RUN: %clang_cc1 -triple i386-unknown-unknown -O2 -emit-llvm -o - %s | FileCheck %s --check-prefix=I386
+
+struct s {
+  char x;
+  short count;
+  int array[] __attribute__((counted_by(count)));
+};
+
+// X86_64-LABEL: define dso_local noalias noundef ptr @test1(
+// X86_64-SAME: i32 noundef [[SIZE:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
+// X86_64-NEXT:  [[ENTRY:.*:]]
+// X86_64-NEXT:    [[CONV:%.*]] = sext i32 [[SIZE]] to i64
+// X86_64-NEXT:    [[MUL:%.*]] = shl nsw i64 [[CONV]], 2
+// X86_64-NEXT:    [[ADD:%.*]] = add nsw i64 [[MUL]], 4
+// X86_64-NEXT:    [[CALL:%.*]] = tail call ptr @malloc(i64 noundef [[ADD]]) #[[ATTR3:[0-9]+]]
+// X86_64-NEXT:    [[CONV1:%.*]] = trunc i32 [[SIZE]] to i16
+// X86_64-NEXT:    [[DOT_COUNTED_BY_GEP:%.*]] = getelementptr inbounds i8, ptr [[CALL]], i64 2
+// X86_64-NEXT:    store i16 [[CONV1]], ptr [[DOT_COUNTED_BY_GEP]], align 2, !tbaa [[TBAA2:![0-9]+]]
+// X86_64-NEXT:    ret ptr [[CALL]]
+//
+// I386-LABEL: define dso_local noalias noundef ptr @test1(
+// I386-SAME: i32 noundef [[SIZE:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
+// I386-NEXT:  [[ENTRY:.*:]]
+// I386-NEXT:    [[MUL:%.*]] = shl i32 [[SIZE]], 2
+// I386-NEXT:    [[ADD:%.*]] = add i32 [[MUL]], 4
+// I386-NEXT:    [[CALL:%.*]] = tail call ptr @malloc(i32 noundef [[ADD]]) #[[ATTR3:[0-9]+]]
+// I386-NEXT:    [[CONV:%.*]] = trunc i32 [[SIZE]] to i16
+// I386-NEXT:    [[DOT_COUNTED_BY_GEP:%.*]] = getelementptr inbounds i8, ptr [[CALL]], i32 2
+// I386-NEXT:    store i16 [[CONV]], ptr [[DOT_COUNTED_BY_GEP]], align 2, !tbaa [[TBAA3:![0-9]+]]
+// I386-NEXT:    ret ptr [[CALL]]
+//
+struct s *test1(int size) {
+  struct s *p = __builtin_malloc(sizeof(struct s) + sizeof(int) * size);
+
+  *__builtin_get_counted_by(p->array) = size;
+  *__builtin_get_counted_by(&p->array[0]) = size;
+  return p;
+}
+
+struct z {
+  char x;
+  short count;
+  int array[];
+};
+
+// X86_64-LABEL: define dso_local noalias noundef ptr @test2(
+// X86_64-SAME: i32 noundef [[SIZE:%.*]]) local_unnamed_addr #[[ATTR2:[0-9]+]] {
+// X86_64-NEXT:  [[ENTRY:.*:]]
+// X86_64-NEXT:    [[CONV:%.*]] = sext i32 [[SIZE]] to i64
+// X86_64-NEXT:    [[MUL:%.*]] = shl nsw i64 [[CONV]], 2
+// X86_64-NEXT:    [[ADD:%.*]] = add nsw i64 [[MUL]], 4
+// X86_64-NEXT:    [[CALL:%.*]] = tail call ptr @malloc(i64 noundef [[ADD]]) #[[ATTR3]]
+// X86_64-NEXT:    ret ptr [[CALL]]
+//
+// I386-LABEL: define dso_local noalias noundef ptr @test2(
+// I386-SAME: i32 noundef [[SIZE:%.*]]) local_unnamed_addr #[[ATTR2:[0-9]+]] {
+// I386-NEXT:  [[ENTRY:.*:]]
+// I386-NEXT:    [[MUL:%.*]] = shl i32 [[SIZE]], 2
+// I386-NEXT:    [[ADD:%.*]] = add i32 [[MUL]], 4
+// I386-NEXT:    [[CALL:%.*]] = tail call ptr @malloc(i32 noundef [[ADD]]) #[[ATTR3]]
+// I386-NEXT:    ret ptr [[CALL]]
+//
+struct z *test2(int size) {
+  struct z *p = __builtin_malloc(sizeof(struct z) + sizeof(int) * size);
+
+  if (__builtin_get_counted_by(&p->array[0]))
+    *__builtin_get_counted_by(&p->array[0]) = size;
+
+  return p;
+}
+//.
+// X86_64: [[TBAA2]] = !{[[META3:![0-9]+]], [[META3]], i64 0}
+// X86_64: [[META3]] = !{!"short", [[META4:![0-9]+]], i64 0}
+// X86_64: [[META4]] = !{!"omnipotent char", [[META5:![0-9]+]], i64 0}
+// X86_64: [[META5]] = !{!"Simple C/C++ TBAA"}
+//.
+// I386: [[TBAA3]] = !{[[META4:![0-9]+]], [[META4]], i64 0}
+// I386: [[META4]] = !{!"short", [[META5:![0-9]+]], i64 0}
+// I386: [[META5]] = !{!"omnipotent char", [[META6:![0-9]+]], i64 0}
+// I386: [[META6]] = !{!"Simple C/C++ TBAA"}
+//.
diff --git a/clang/test/Sema/builtin-get-counted-by.c b/clang/test/Sema/builtin-get-counted-by.c
new file mode 100644
index 00000000000000..18cef35b0509a1
--- /dev/null
+++ b/clang/test/Sema/builtin-get-counted-by.c
@@ -0,0 +1,22 @@
+// RUN: %clang_cc1 -fsyntax-only -verify %s
+
+struct fam_struct {
+  char x;
+  short count;
+  int array[] __attribute__((counted_by(count)));
+} *p;
+
+struct non_fam_struct {
+  char x;
+  short count;
+  int array[];
+} *q;
+
+void foo(int size) {
+  *__builtin_get_counted_by(p->array) = size;
+
+  if (__builtin_get_counted_by(q->array))
+    *__builtin_get_counted_by(q->array) = size;
+
+  *__builtin_get_counted_by(p->count) = size; // expected-error{{incompatible integer to pointer conversion passing 'short' to parameter of type 'void *'}}
+}

>From d9e9c6a879bf5fae9ec4453d41fe40abbeb865f9 Mon Sep 17 00:00:00 2001
From: Bill Wendling <morbo at google.com>
Date: Thu, 8 Aug 2024 16:21:35 -0700
Subject: [PATCH 08/11] Reformat

---
 clang/lib/CodeGen/CGBuiltin.cpp | 4 ++--
 clang/lib/Sema/SemaExpr.cpp     | 4 ++--
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 58fc0dfe45e1b0..de46c538fdcc30 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -3610,8 +3610,8 @@ RValue CodeGenFunction::EmitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
     if (const Expr *Ptr = MemberExprVisitor().Visit(E->getArg(0))) {
       const MemberExpr *ME = cast<MemberExpr>(Ptr);
       bool IsFlexibleArrayMember = ME->isFlexibleArrayMemberLike(
-              getContext(), getLangOpts().getStrictFlexArraysLevel(),
-              /*IgnoreTemplateOrMacroSubstitution=*/false);
+          getContext(), getLangOpts().getStrictFlexArraysLevel(),
+          /*IgnoreTemplateOrMacroSubstitution=*/false);
 
       if (!ME->HasSideEffects(getContext()) && IsFlexibleArrayMember &&
           ME->getMemberDecl()->getType()->isCountAttributedType()) {
diff --git a/clang/lib/Sema/SemaExpr.cpp b/clang/lib/Sema/SemaExpr.cpp
index 33bc71d621ddd5..3b80bcfd3ab26d 100644
--- a/clang/lib/Sema/SemaExpr.cpp
+++ b/clang/lib/Sema/SemaExpr.cpp
@@ -6656,8 +6656,8 @@ ExprResult Sema::BuildCallExpr(Scope *Scope, Expr *Fn, SourceLocation LParenLoc,
     if (const Expr *Ptr = MemberExprVisitor().Visit(ArgExprs[0])) {
       const MemberExpr *ME = cast<MemberExpr>(Ptr);
       bool IsFlexibleArrayMember = ME->isFlexibleArrayMemberLike(
-              Context, getLangOpts().getStrictFlexArraysLevel(),
-              /*IgnoreTemplateOrMacroSubstitution=*/false);
+          Context, getLangOpts().getStrictFlexArraysLevel(),
+          /*IgnoreTemplateOrMacroSubstitution=*/false);
 
       if (!ME->HasSideEffects(Context) && IsFlexibleArrayMember &&
           ME->getMemberDecl()->getType()->isCountAttributedType()) {

>From 10efe640a00520e2d074e7acfb65cb956a8ce1ac Mon Sep 17 00:00:00 2001
From: Bill Wendling <morbo at google.com>
Date: Thu, 8 Aug 2024 16:51:24 -0700
Subject: [PATCH 09/11] Combine the code to grab the MemberExpr, if it exists.

---
 clang/include/clang/AST/Expr.h  | 10 +++++--
 clang/lib/AST/Expr.cpp          | 43 ++++++++++++++++++++++++++++++
 clang/lib/CodeGen/CGBuiltin.cpp | 46 ++-------------------------------
 clang/lib/Sema/SemaExpr.cpp     | 46 ++-------------------------------
 4 files changed, 55 insertions(+), 90 deletions(-)

diff --git a/clang/include/clang/AST/Expr.h b/clang/include/clang/AST/Expr.h
index 5b813bfc2faf90..98524ffd7c02db 100644
--- a/clang/include/clang/AST/Expr.h
+++ b/clang/include/clang/AST/Expr.h
@@ -537,8 +537,8 @@ class Expr : public ValueStmt {
   /// semantically correspond to a bool.
   bool isKnownToHaveBooleanValue(bool Semantic = true) const;
 
-  /// Check whether this array fits the idiom of a flexible array member,
-  /// depending on the value of -fstrict-flex-array.
+  /// isFlexibleArrayMemberLike - Check whether this array fits the idiom of a
+  /// flexible array member, depending on the value of -fstrict-flex-array.
   /// When IgnoreTemplateOrMacroSubstitution is set, it doesn't consider sizes
   /// resulting from the substitution of a macro or a template as special sizes.
   bool isFlexibleArrayMemberLike(
@@ -546,6 +546,12 @@ class Expr : public ValueStmt {
       LangOptions::StrictFlexArraysLevelKind StrictFlexArraysLevel,
       bool IgnoreTemplateOrMacroSubstitution = false) const;
 
+  /// getMemberExpr - Find the first MemberExpr of the Expr. This method
+  /// intentionally looks through all casts, array subscripts, and unary
+  /// operators to find an underlying MemberExpr. If one doesn't exist, it
+  /// returns a nullptr.
+  const MemberExpr *getMemberExpr() const;
+
   /// isIntegerConstantExpr - Return the value if this expression is a valid
   /// integer constant expression.  If not a valid i-c-e, return std::nullopt
   /// and fill in Loc (if specified) with the location of the invalid
diff --git a/clang/lib/AST/Expr.cpp b/clang/lib/AST/Expr.cpp
index 9d5b8167d0ee62..3a5599fac1d156 100644
--- a/clang/lib/AST/Expr.cpp
+++ b/clang/lib/AST/Expr.cpp
@@ -222,6 +222,49 @@ bool Expr::isFlexibleArrayMemberLike(
                                          IgnoreTemplateOrMacroSubstitution);
 }
 
+namespace {
+
+/// MemberExprVisitor - Find the MemberExpr through all of the casts, array
+/// subscripts, and unary ops. This intentionally avoids all of them because
+/// we're interested only in the MemberExpr to check if it's a flexible array
+/// member.
+class MemberExprVisitor
+    : public ConstStmtVisitor<MemberExprVisitor, const Expr *> {
+public:
+  //===--------------------------------------------------------------------===//
+  //                            Visitor Methods
+  //===--------------------------------------------------------------------===//
+
+  const Expr *Visit(const Expr *E) {
+    return ConstStmtVisitor<MemberExprVisitor, const Expr *>::Visit(E);
+  }
+  const Expr *VisitStmt(const Stmt *S) { return nullptr; }
+
+  const Expr *VisitMemberExpr(const MemberExpr *E) { return E; }
+
+  const Expr *VisitArraySubscriptExpr(const ArraySubscriptExpr *E) {
+    return Visit(E->getBase());
+  }
+  const Expr *VisitCastExpr(const CastExpr *E) {
+    return Visit(E->getSubExpr());
+  }
+  const Expr *VisitParenExpr(const ParenExpr *E) {
+    return Visit(E->getSubExpr());
+  }
+  const Expr *VisitUnaryAddrOf(const UnaryOperator *E) {
+    return Visit(E->getSubExpr());
+  }
+  const Expr *VisitUnaryDeref(const UnaryOperator *E) {
+    return Visit(E->getSubExpr());
+  }
+};
+
+} // anonymous namespace
+
+const MemberExpr *Expr::getMemberExpr() const {
+  return dyn_cast_if_present<MemberExpr>(MemberExprVisitor().Visit(this));
+}
+
 const ValueDecl *
 Expr::getAsBuiltinConstantDeclRef(const ASTContext &Context) const {
   Expr::EvalResult Eval;
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index de46c538fdcc30..95dd32eeee10cd 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -27,7 +27,6 @@
 #include "clang/AST/Decl.h"
 #include "clang/AST/OSLog.h"
 #include "clang/AST/OperationKinds.h"
-#include "clang/AST/StmtVisitor.h"
 #include "clang/Basic/TargetBuiltins.h"
 #include "clang/Basic/TargetInfo.h"
 #include "clang/Basic/TargetOptions.h"
@@ -2537,45 +2536,6 @@ static RValue EmitHipStdParUnsupportedBuiltin(CodeGenFunction *CGF,
   return RValue::get(CGF->Builder.CreateCall(UBF, Args));
 }
 
-namespace {
-
-/// MemberExprVisitor - Find the MemberExpr through all of the casts, array
-/// subscripts, and unary ops. This intentionally avoids all of them because
-/// we're interested only in the MemberExpr to check if it's a flexible array
-/// member.
-class MemberExprVisitor
-    : public ConstStmtVisitor<MemberExprVisitor, const Expr *> {
-public:
-  //===--------------------------------------------------------------------===//
-  //                            Visitor Methods
-  //===--------------------------------------------------------------------===//
-
-  const Expr *Visit(const Expr *E) {
-    return ConstStmtVisitor<MemberExprVisitor, const Expr *>::Visit(E);
-  }
-  const Expr *VisitStmt(const Stmt *S) { return nullptr; }
-
-  const Expr *VisitMemberExpr(const MemberExpr *E) { return E; }
-
-  const Expr *VisitArraySubscriptExpr(const ArraySubscriptExpr *E) {
-    return Visit(E->getBase());
-  }
-  const Expr *VisitCastExpr(const CastExpr *E) {
-    return Visit(E->getSubExpr());
-  }
-  const Expr *VisitParenExpr(const ParenExpr *E) {
-    return Visit(E->getSubExpr());
-  }
-  const Expr *VisitUnaryAddrOf(const clang::UnaryOperator *E) {
-    return Visit(E->getSubExpr());
-  }
-  const Expr *VisitUnaryDeref(const clang::UnaryOperator *E) {
-    return Visit(E->getSubExpr());
-  }
-};
-
-} // anonymous namespace
-
 RValue CodeGenFunction::EmitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
                                         const CallExpr *E,
                                         ReturnValueSlot ReturnValue) {
@@ -3607,11 +3567,9 @@ RValue CodeGenFunction::EmitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
     llvm::Value *Result = llvm::ConstantPointerNull::get(
         cast<llvm::PointerType>(ConvertType(E->getType())));
 
-    if (const Expr *Ptr = MemberExprVisitor().Visit(E->getArg(0))) {
-      const MemberExpr *ME = cast<MemberExpr>(Ptr);
+    if (const MemberExpr *ME = E->getArg(0)->getMemberExpr()) {
       bool IsFlexibleArrayMember = ME->isFlexibleArrayMemberLike(
-          getContext(), getLangOpts().getStrictFlexArraysLevel(),
-          /*IgnoreTemplateOrMacroSubstitution=*/false);
+          getContext(), getLangOpts().getStrictFlexArraysLevel());
 
       if (!ME->HasSideEffects(getContext()) && IsFlexibleArrayMember &&
           ME->getMemberDecl()->getType()->isCountAttributedType()) {
diff --git a/clang/lib/Sema/SemaExpr.cpp b/clang/lib/Sema/SemaExpr.cpp
index 3b80bcfd3ab26d..45f540e417cd43 100644
--- a/clang/lib/Sema/SemaExpr.cpp
+++ b/clang/lib/Sema/SemaExpr.cpp
@@ -28,7 +28,6 @@
 #include "clang/AST/OperationKinds.h"
 #include "clang/AST/ParentMapContext.h"
 #include "clang/AST/RecursiveASTVisitor.h"
-#include "clang/AST/StmtVisitor.h"
 #include "clang/AST/Type.h"
 #include "clang/AST/TypeLoc.h"
 #include "clang/Basic/Builtins.h"
@@ -6411,45 +6410,6 @@ const FieldDecl *FindCountedByField(const FieldDecl *FD) {
   return dyn_cast<FieldDecl>(CountDecl);
 }
 
-namespace {
-
-/// MemberExprVisitor - Find the MemberExpr through all of the casts, array
-/// subscripts, and unary ops. This intentionally avoids all of them because
-/// we're interested only in the MemberExpr to check if it's a flexible array
-/// member.
-class MemberExprVisitor
-    : public ConstStmtVisitor<MemberExprVisitor, const Expr *> {
-public:
-  //===--------------------------------------------------------------------===//
-  //                            Visitor Methods
-  //===--------------------------------------------------------------------===//
-
-  const Expr *Visit(const Expr *E) {
-    return ConstStmtVisitor<MemberExprVisitor, const Expr *>::Visit(E);
-  }
-  const Expr *VisitStmt(const Stmt *S) { return nullptr; }
-
-  const Expr *VisitMemberExpr(const MemberExpr *E) { return E; }
-
-  const Expr *VisitArraySubscriptExpr(const ArraySubscriptExpr *E) {
-    return Visit(E->getBase());
-  }
-  const Expr *VisitCastExpr(const CastExpr *E) {
-    return Visit(E->getSubExpr());
-  }
-  const Expr *VisitParenExpr(const ParenExpr *E) {
-    return Visit(E->getSubExpr());
-  }
-  const Expr *VisitUnaryAddrOf(const UnaryOperator *E) {
-    return Visit(E->getSubExpr());
-  }
-  const Expr *VisitUnaryDeref(const UnaryOperator *E) {
-    return Visit(E->getSubExpr());
-  }
-};
-
-} // anonymous namespace
-
 ExprResult Sema::BuildCallExpr(Scope *Scope, Expr *Fn, SourceLocation LParenLoc,
                                MultiExprArg ArgExprs, SourceLocation RParenLoc,
                                Expr *ExecConfig, bool IsExecConfig,
@@ -6653,11 +6613,9 @@ ExprResult Sema::BuildCallExpr(Scope *Scope, Expr *Fn, SourceLocation LParenLoc,
 
   if (FunctionDecl *FDecl = dyn_cast_or_null<FunctionDecl>(NDecl);
       FDecl && FDecl->getBuiltinID() == Builtin::BI__builtin_get_counted_by) {
-    if (const Expr *Ptr = MemberExprVisitor().Visit(ArgExprs[0])) {
-      const MemberExpr *ME = cast<MemberExpr>(Ptr);
+    if (const MemberExpr *ME = ArgExprs[0]->getMemberExpr()) {
       bool IsFlexibleArrayMember = ME->isFlexibleArrayMemberLike(
-          Context, getLangOpts().getStrictFlexArraysLevel(),
-          /*IgnoreTemplateOrMacroSubstitution=*/false);
+          Context, getLangOpts().getStrictFlexArraysLevel());
 
       if (!ME->HasSideEffects(Context) && IsFlexibleArrayMember &&
           ME->getMemberDecl()->getType()->isCountAttributedType()) {

>From ad8fb9433cb935eec394dd7960d5bde52c792679 Mon Sep 17 00:00:00 2001
From: Bill Wendling <morbo at google.com>
Date: Fri, 9 Aug 2024 16:38:50 -0700
Subject: [PATCH 10/11] Move FindCountedByField into a central place.

---
 clang/include/clang/AST/Decl.h      |  4 ++++
 clang/lib/AST/Decl.cpp              | 13 +++++++++++++
 clang/lib/CodeGen/CGBuiltin.cpp     |  4 ++--
 clang/lib/CodeGen/CGExpr.cpp        | 18 +-----------------
 clang/lib/CodeGen/CodeGenFunction.h |  4 ----
 clang/lib/Sema/SemaExpr.cpp         | 19 +------------------
 6 files changed, 21 insertions(+), 41 deletions(-)

diff --git a/clang/include/clang/AST/Decl.h b/clang/include/clang/AST/Decl.h
index 561a9d872acfb0..5535703ef906f4 100644
--- a/clang/include/clang/AST/Decl.h
+++ b/clang/include/clang/AST/Decl.h
@@ -3206,6 +3206,10 @@ class FieldDecl : public DeclaratorDecl, public Mergeable<FieldDecl> {
   /// Set the C++11 in-class initializer for this member.
   void setInClassInitializer(Expr *NewInit);
 
+  /// Find the FieldDecl specified in a FAM's "counted_by" attribute. Returns
+  /// \p nullptr if either the attribute or the field doesn't exist.
+  const FieldDecl *FindCountedByField() const;
+
 private:
   void setLazyInClassInitializer(LazyDeclStmtPtr NewInit);
 
diff --git a/clang/lib/AST/Decl.cpp b/clang/lib/AST/Decl.cpp
index d832ce4190ff1a..3f4cd1339b3205 100644
--- a/clang/lib/AST/Decl.cpp
+++ b/clang/lib/AST/Decl.cpp
@@ -4678,6 +4678,19 @@ void FieldDecl::printName(raw_ostream &OS, const PrintingPolicy &Policy) const {
   DeclaratorDecl::printName(OS, Policy);
 }
 
+const FieldDecl *FieldDecl::FindCountedByField() const {
+  const auto *CAT = getType()->getAs<CountAttributedType>();
+  if (!CAT)
+    return nullptr;
+
+  const auto *CountDRE = cast<DeclRefExpr>(CAT->getCountExpr());
+  const auto *CountDecl = CountDRE->getDecl();
+  if (const auto *IFD = dyn_cast<IndirectFieldDecl>(CountDecl))
+    CountDecl = IFD->getAnonField();
+
+  return dyn_cast<FieldDecl>(CountDecl);
+}
+
 //===----------------------------------------------------------------------===//
 // TagDecl Implementation
 //===----------------------------------------------------------------------===//
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 95dd32eeee10cd..c4ec7af63da12c 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -989,7 +989,7 @@ CodeGenFunction::emitFlexibleArrayMemberSize(const Expr *E, unsigned Type,
     // attribute.
     return nullptr;
 
-  const FieldDecl *CountedByFD = FindCountedByField(FAMDecl);
+  const FieldDecl *CountedByFD = FAMDecl->FindCountedByField();
   if (!CountedByFD)
     // Can't find the field referenced by the "counted_by" attribute.
     return nullptr;
@@ -3574,7 +3574,7 @@ RValue CodeGenFunction::EmitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
       if (!ME->HasSideEffects(getContext()) && IsFlexibleArrayMember &&
           ME->getMemberDecl()->getType()->isCountAttributedType()) {
         const FieldDecl *FAMDecl = dyn_cast<FieldDecl>(ME->getMemberDecl());
-        if (const FieldDecl *CountFD = FindCountedByField(FAMDecl))
+        if (const FieldDecl *CountFD = FAMDecl->FindCountedByField())
           Result = GetCountedByFieldExprGEP(ME, FAMDecl, CountFD);
       }
     }
diff --git a/clang/lib/CodeGen/CGExpr.cpp b/clang/lib/CodeGen/CGExpr.cpp
index 55cd95c08e3ffc..375bc911e0df1f 100644
--- a/clang/lib/CodeGen/CGExpr.cpp
+++ b/clang/lib/CodeGen/CGExpr.cpp
@@ -1154,22 +1154,6 @@ llvm::Value *CodeGenFunction::EmitLoadOfCountedByField(
   return nullptr;
 }
 
-const FieldDecl *CodeGenFunction::FindCountedByField(const FieldDecl *FD) {
-  if (!FD)
-    return nullptr;
-
-  const auto *CAT = FD->getType()->getAs<CountAttributedType>();
-  if (!CAT)
-    return nullptr;
-
-  const auto *CountDRE = cast<DeclRefExpr>(CAT->getCountExpr());
-  const auto *CountDecl = CountDRE->getDecl();
-  if (const auto *IFD = dyn_cast<IndirectFieldDecl>(CountDecl))
-    CountDecl = IFD->getAnonField();
-
-  return dyn_cast<FieldDecl>(CountDecl);
-}
-
 void CodeGenFunction::EmitBoundsCheck(const Expr *E, const Expr *Base,
                                       llvm::Value *Index, QualType IndexType,
                                       bool Accessed) {
@@ -4309,7 +4293,7 @@ LValue CodeGenFunction::EmitArraySubscriptExpr(const ArraySubscriptExpr *E,
           ME->isFlexibleArrayMemberLike(getContext(), StrictFlexArraysLevel) &&
           ME->getMemberDecl()->getType()->isCountAttributedType()) {
         const FieldDecl *FAMDecl = dyn_cast<FieldDecl>(ME->getMemberDecl());
-        if (const FieldDecl *CountFD = FindCountedByField(FAMDecl)) {
+        if (const FieldDecl *CountFD = FAMDecl->FindCountedByField()) {
           if (std::optional<int64_t> Diff =
                   getOffsetDifferenceInBits(*this, CountFD, FAMDecl)) {
             CharUnits OffsetDiff = CGM.getContext().toCharUnitsFromBits(*Diff);
diff --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h
index e5f5b94bba54b0..dadf5838d35554 100644
--- a/clang/lib/CodeGen/CodeGenFunction.h
+++ b/clang/lib/CodeGen/CodeGenFunction.h
@@ -3305,10 +3305,6 @@ class CodeGenFunction : public CodeGenTypeCache {
                                         const FieldDecl *FAMDecl,
                                         uint64_t &Offset);
 
-  /// Find the FieldDecl specified in a FAM's "counted_by" attribute. Returns
-  /// \p nullptr if either the attribute or the field doesn't exist.
-  const FieldDecl *FindCountedByField(const FieldDecl *FD);
-
   llvm::Value *GetCountedByFieldExprGEP(const Expr *Base,
                                         const FieldDecl *FAMDecl,
                                         const FieldDecl *CountDecl);
diff --git a/clang/lib/Sema/SemaExpr.cpp b/clang/lib/Sema/SemaExpr.cpp
index 45f540e417cd43..202ce7abc10991 100644
--- a/clang/lib/Sema/SemaExpr.cpp
+++ b/clang/lib/Sema/SemaExpr.cpp
@@ -6390,26 +6390,9 @@ ExprResult Sema::ActOnCallExpr(Scope *Scope, Expr *Fn, SourceLocation LParenLoc,
       currentEvaluationContext().ReferenceToConsteval.erase(DRE);
     }
   }
-
   return Call;
 }
 
-const FieldDecl *FindCountedByField(const FieldDecl *FD) {
-  if (!FD)
-    return nullptr;
-
-  const auto *CAT = FD->getType()->getAs<CountAttributedType>();
-  if (!CAT)
-    return nullptr;
-
-  const auto *CountDRE = cast<DeclRefExpr>(CAT->getCountExpr());
-  const auto *CountDecl = CountDRE->getDecl();
-  if (const auto *IFD = dyn_cast<IndirectFieldDecl>(CountDecl))
-    CountDecl = IFD->getAnonField();
-
-  return dyn_cast<FieldDecl>(CountDecl);
-}
-
 ExprResult Sema::BuildCallExpr(Scope *Scope, Expr *Fn, SourceLocation LParenLoc,
                                MultiExprArg ArgExprs, SourceLocation RParenLoc,
                                Expr *ExecConfig, bool IsExecConfig,
@@ -6620,7 +6603,7 @@ ExprResult Sema::BuildCallExpr(Scope *Scope, Expr *Fn, SourceLocation LParenLoc,
       if (!ME->HasSideEffects(Context) && IsFlexibleArrayMember &&
           ME->getMemberDecl()->getType()->isCountAttributedType()) {
         const FieldDecl *FAMDecl = dyn_cast<FieldDecl>(ME->getMemberDecl());
-        if (const FieldDecl *CountFD = FindCountedByField(FAMDecl)) {
+        if (const FieldDecl *CountFD = FAMDecl->FindCountedByField()) {
           // The builtin returns a 'size_t *', however 'size_t' might not be
           // the type of the count field. Thus we create an explicit c-style
           // cast to ensure the proper types going forward.

>From 7c1facf8378ad3733c0ec6b74cc17e4429c64cda Mon Sep 17 00:00:00 2001
From: Bill Wendling <morbo at google.com>
Date: Wed, 14 Aug 2024 12:11:27 -0700
Subject: [PATCH 11/11] Use CustomTypeChecking for the builtin. It allows us
 to explicitly set the return type based on the 'count' field's type.

---
 clang/include/clang/Basic/Builtins.td         |  4 +-
 .../clang/Basic/DiagnosticSemaKinds.td        |  5 +++
 clang/include/clang/Sema/Sema.h               |  2 +
 clang/lib/AST/Expr.cpp                        |  5 ---
 clang/lib/Sema/SemaChecking.cpp               | 42 +++++++++++++++++++
 clang/lib/Sema/SemaExpr.cpp                   | 29 +------------
 6 files changed, 53 insertions(+), 34 deletions(-)

diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index 254cd157d5f9d0..33f26ebab7f392 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -4777,6 +4777,6 @@ def ArithmeticFence : LangBuiltin<"ALL_LANGUAGES"> {
 
 def GetCountedBy : Builtin {
   let Spellings = ["__builtin_get_counted_by"];
-  let Attributes = [NoThrow];
-  let Prototype = "size_t*(void*)";
+  let Attributes = [NoThrow, CustomTypeChecking];
+  let Prototype = "int(...)";
 }
diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td
index 5cdf36660b2a66..e72d5eaf742828 100644
--- a/clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -6624,6 +6624,11 @@ def warn_counted_by_attr_elt_type_unknown_size :
   Warning<err_counted_by_attr_pointee_unknown_size.Summary>,
   InGroup<BoundsSafetyCountedByEltTyUnknownSize>;
 
+def err_builtin_get_counted_by_has_side_effects : Error<
+  "__builtin_get_counted_by cannot have side-effects">;
+def err_builtin_get_counted_by_must_be_pointer : Error<
+  "__builtin_get_counted_by argument must be a pointer">;
+
 let CategoryName = "ARC Semantic Issue" in {
 
 // ARC-mode diagnostics.
diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h
index b7bd6c2433efd6..bdf27ea40ef318 100644
--- a/clang/include/clang/Sema/Sema.h
+++ b/clang/include/clang/Sema/Sema.h
@@ -2545,6 +2545,8 @@ class Sema final : public SemaBase {
 
   bool BuiltinNonDeterministicValue(CallExpr *TheCall);
 
+  bool BuiltinGetCountedBy(CallExpr *TheCall);
+
   // Matrix builtin handling.
   ExprResult BuiltinMatrixTranspose(CallExpr *TheCall, ExprResult CallResult);
   ExprResult BuiltinMatrixColumnMajorLoad(CallExpr *TheCall,
diff --git a/clang/lib/AST/Expr.cpp b/clang/lib/AST/Expr.cpp
index 3a5599fac1d156..1bbff09f1d55ec 100644
--- a/clang/lib/AST/Expr.cpp
+++ b/clang/lib/AST/Expr.cpp
@@ -235,11 +235,6 @@ class MemberExprVisitor
   //                            Visitor Methods
   //===--------------------------------------------------------------------===//
 
-  const Expr *Visit(const Expr *E) {
-    return ConstStmtVisitor<MemberExprVisitor, const Expr *>::Visit(E);
-  }
-  const Expr *VisitStmt(const Stmt *S) { return nullptr; }
-
   const Expr *VisitMemberExpr(const MemberExpr *E) { return E; }
 
   const Expr *VisitArraySubscriptExpr(const ArraySubscriptExpr *E) {
diff --git a/clang/lib/Sema/SemaChecking.cpp b/clang/lib/Sema/SemaChecking.cpp
index ee143381cf4f79..1d8a6acce5d916 100644
--- a/clang/lib/Sema/SemaChecking.cpp
+++ b/clang/lib/Sema/SemaChecking.cpp
@@ -2911,6 +2911,10 @@ Sema::CheckBuiltinFunctionCall(FunctionDecl *FDecl, unsigned BuiltinID,
     }
     break;
   }
+  case Builtin::BI__builtin_get_counted_by:
+    if (BuiltinGetCountedBy(TheCall))
+      return ExprError();
+    break;
   }
 
   if (getLangOpts().HLSL && HLSL().CheckBuiltinFunctionCall(BuiltinID, TheCall))
@@ -5503,6 +5507,44 @@ bool Sema::BuiltinSetjmp(CallExpr *TheCall) {
   return false;
 }
 
+bool Sema::BuiltinGetCountedBy(CallExpr *TheCall) {
+  if (checkArgCount(TheCall, 1))
+    return true;
+
+  ExprResult ArgRes = UsualUnaryConversions(TheCall->getArg(0));
+  if (ArgRes.isInvalid())
+    return true;
+
+  const Expr *Arg = ArgRes.get();
+  if (!isa<PointerType>(Arg->getType()))
+    return Diag(Arg->getBeginLoc(),
+                diag::err_builtin_get_counted_by_must_be_pointer)
+           << Arg->getSourceRange();
+
+  if (Arg->HasSideEffects(Context))
+    return Diag(Arg->getBeginLoc(),
+                diag::err_builtin_get_counted_by_has_side_effects)
+           << Arg->getSourceRange();
+
+  // Use 'void *' as the default return type. If the argument doesn't have the
+  // 'counted_by' attribute, it'll return a "nullptr."
+  TheCall->setType(Context.VoidPtrTy);
+
+  if (const MemberExpr *ME = Arg->getMemberExpr();
+      ME &&
+      ME->isFlexibleArrayMemberLike(Context,
+                                    getLangOpts().getStrictFlexArraysLevel()) &&
+      ME->getMemberDecl()->getType()->isCountAttributedType()) {
+    if (const FieldDecl *FAMDecl = dyn_cast<FieldDecl>(ME->getMemberDecl()))
+      if (const FieldDecl *CountFD = FAMDecl->FindCountedByField())
+        // The proper return type should be a pointer to the type of the
+        // counted_by's 'count' field.
+        TheCall->setType(Context.getPointerType(CountFD->getType()));
+  }
+
+  return false;
+}
+
 namespace {
 
 class UncoveredArgHandler {
diff --git a/clang/lib/Sema/SemaExpr.cpp b/clang/lib/Sema/SemaExpr.cpp
index 202ce7abc10991..14e96d5ef4cd11 100644
--- a/clang/lib/Sema/SemaExpr.cpp
+++ b/clang/lib/Sema/SemaExpr.cpp
@@ -6591,33 +6591,8 @@ ExprResult Sema::BuildCallExpr(Scope *Scope, Expr *Fn, SourceLocation LParenLoc,
                             VK_PRValue, RParenLoc, CurFPFeatureOverrides());
   }
 
-  Result = BuildResolvedCallExpr(Fn, NDecl, LParenLoc, ArgExprs, RParenLoc,
-                                 ExecConfig, IsExecConfig);
-
-  if (FunctionDecl *FDecl = dyn_cast_or_null<FunctionDecl>(NDecl);
-      FDecl && FDecl->getBuiltinID() == Builtin::BI__builtin_get_counted_by) {
-    if (const MemberExpr *ME = ArgExprs[0]->getMemberExpr()) {
-      bool IsFlexibleArrayMember = ME->isFlexibleArrayMemberLike(
-          Context, getLangOpts().getStrictFlexArraysLevel());
-
-      if (!ME->HasSideEffects(Context) && IsFlexibleArrayMember &&
-          ME->getMemberDecl()->getType()->isCountAttributedType()) {
-        const FieldDecl *FAMDecl = dyn_cast<FieldDecl>(ME->getMemberDecl());
-        if (const FieldDecl *CountFD = FAMDecl->FindCountedByField()) {
-          // The builtin returns a 'size_t *', however 'size_t' might not be
-          // the type of the count field. Thus we create an explicit c-style
-          // cast to ensure the proper types going forward.
-          QualType PtrTy = Context.getPointerType(CountFD->getType());
-          Result = CStyleCastExpr::Create(
-              Context, PtrTy, VK_LValue, CK_BitCast, Result.get(), nullptr,
-              FPOptionsOverride(), Context.CreateTypeSourceInfo(PtrTy),
-              LParenLoc, RParenLoc);
-        }
-      }
-    }
-  }
-
-  return Result;
+  return BuildResolvedCallExpr(Fn, NDecl, LParenLoc, ArgExprs, RParenLoc,
+                               ExecConfig, IsExecConfig);
 }
 
 Expr *Sema::BuildBuiltinCallExpr(SourceLocation Loc, Builtin::ID Id,



More information about the cfe-commits mailing list