[llvm] [clang-tools-extra] [clang] Turn 'counted_by' into a type attribute and parse it into 'CountAttributedType' (PR #78000)

Yeoul Na via llvm-commits llvm-commits at lists.llvm.org
Thu Feb 1 11:11:27 PST 2024


https://github.com/rapidsna updated https://github.com/llvm/llvm-project/pull/78000

>From 1a17c254ddf09cd4faf5217b2f72da3f44622f8a Mon Sep 17 00:00:00 2001
From: Yeoul Na <yeoul_na at apple.com>
Date: Mon, 18 Dec 2023 10:58:16 +0900
Subject: [PATCH 01/12] [BoundsSafety] Introduce CountAttributedType

CountAttributedType is a sugar type to represent a type with
a 'counted_by' attribute and the like, which provides bounds
information to the underlying type. The type holds the argument
of the attribute as an expression. Additionally, the type
holds metadata about declarations referenced by the expression in
order to make it easier for Sema to access declarations on which
the type depends.
---
 clang/include/clang/AST/ASTContext.h          |   7 +
 clang/include/clang/AST/PropertiesBase.td     |   1 +
 clang/include/clang/AST/RecursiveASTVisitor.h |   9 ++
 clang/include/clang/AST/Type.h                | 152 ++++++++++++++++++
 clang/include/clang/AST/TypeLoc.h             |  26 +++
 clang/include/clang/AST/TypeProperties.td     |  19 +++
 clang/include/clang/Basic/TypeNodes.td        |   2 +
 .../clang/Serialization/ASTRecordReader.h     |   2 +
 .../clang/Serialization/ASTRecordWriter.h     |   5 +
 .../clang/Serialization/TypeBitCodes.def      |   1 +
 clang/lib/AST/ASTContext.cpp                  |  56 +++++++
 clang/lib/AST/ASTImporter.cpp                 |  12 ++
 clang/lib/AST/ASTStructuralEquivalence.cpp    |   7 +
 clang/lib/AST/ItaniumMangle.cpp               |   1 +
 clang/lib/AST/Type.cpp                        |  57 +++++++
 clang/lib/AST/TypeLoc.cpp                     |   4 +
 clang/lib/AST/TypePrinter.cpp                 |  23 +++
 clang/lib/CodeGen/CGDebugInfo.cpp             |   1 +
 clang/lib/CodeGen/CodeGenFunction.cpp         |   1 +
 clang/lib/Sema/SemaExpr.cpp                   |   1 +
 clang/lib/Sema/TreeTransform.h                |   7 +
 clang/lib/Serialization/ASTReader.cpp         |   8 +
 clang/lib/Serialization/ASTWriter.cpp         |   4 +
 clang/tools/libclang/CIndex.cpp               |   4 +
 24 files changed, 410 insertions(+)

diff --git a/clang/include/clang/AST/ASTContext.h b/clang/include/clang/AST/ASTContext.h
index 3e46a5da3fc04..c8354fbb108a2 100644
--- a/clang/include/clang/AST/ASTContext.h
+++ b/clang/include/clang/AST/ASTContext.h
@@ -247,6 +247,8 @@ class ASTContext : public RefCountedBase<ASTContext> {
       DependentBitIntTypes;
   llvm::FoldingSet<BTFTagAttributedType> BTFTagAttributedTypes;
 
+  mutable llvm::FoldingSet<CountAttributedType> CountAttributedTypes;
+
   mutable llvm::FoldingSet<QualifiedTemplateName> QualifiedTemplateNames;
   mutable llvm::FoldingSet<DependentTemplateName> DependentTemplateNames;
   mutable llvm::FoldingSet<SubstTemplateTemplateParmStorage>
@@ -1338,6 +1340,11 @@ class ASTContext : public RefCountedBase<ASTContext> {
     return CanQualType::CreateUnsafe(getPointerType((QualType) T));
   }
 
+  QualType
+  getCountAttributedType(QualType T, Expr *CountExpr, bool CountInBytes,
+                         bool OrNull,
+                         ArrayRef<TypeCoupledDeclRefInfo> DependentDecls) const;
+
   /// Return the uniqued reference to a type adjusted from the original
   /// type to a new type.
   QualType getAdjustedType(QualType Orig, QualType New) const;
diff --git a/clang/include/clang/AST/PropertiesBase.td b/clang/include/clang/AST/PropertiesBase.td
index d86c4eba6a225..25ddfd105ab87 100644
--- a/clang/include/clang/AST/PropertiesBase.td
+++ b/clang/include/clang/AST/PropertiesBase.td
@@ -143,6 +143,7 @@ def UInt32 : CountPropertyType<"uint32_t">;
 def UInt64 : CountPropertyType<"uint64_t">;
 def UnaryTypeTransformKind : EnumPropertyType<"UnaryTransformType::UTTKind">;
 def VectorKind : EnumPropertyType<"VectorKind">;
+def TypeCoupledDeclRefInfo : PropertyType;
 
 def ExceptionSpecInfo : PropertyType<"FunctionProtoType::ExceptionSpecInfo"> {
   let BufferElementTypes = [ QualType ];
diff --git a/clang/include/clang/AST/RecursiveASTVisitor.h b/clang/include/clang/AST/RecursiveASTVisitor.h
index 8f2714e142bbe..b57ab36939d07 100644
--- a/clang/include/clang/AST/RecursiveASTVisitor.h
+++ b/clang/include/clang/AST/RecursiveASTVisitor.h
@@ -1099,6 +1099,12 @@ DEF_TRAVERSE_TYPE(InjectedClassNameType, {})
 DEF_TRAVERSE_TYPE(AttributedType,
                   { TRY_TO(TraverseType(T->getModifiedType())); })
 
+DEF_TRAVERSE_TYPE(CountAttributedType, {
+  if (T->getCountExpr())
+    TRY_TO(TraverseStmt(T->getCountExpr()));
+  TRY_TO(TraverseType(T->desugar()));
+})
+
 DEF_TRAVERSE_TYPE(BTFTagAttributedType,
                   { TRY_TO(TraverseType(T->getWrappedType())); })
 
@@ -1385,6 +1391,9 @@ DEF_TRAVERSE_TYPELOC(MacroQualifiedType,
 DEF_TRAVERSE_TYPELOC(AttributedType,
                      { TRY_TO(TraverseTypeLoc(TL.getModifiedLoc())); })
 
+DEF_TRAVERSE_TYPELOC(CountAttributedType,
+                     { TRY_TO(TraverseTypeLoc(TL.getInnerLoc())); })
+
 DEF_TRAVERSE_TYPELOC(BTFTagAttributedType,
                      { TRY_TO(TraverseTypeLoc(TL.getWrappedLoc())); })
 
diff --git a/clang/include/clang/AST/Type.h b/clang/include/clang/AST/Type.h
index d4e5310fb3abc..489644ca5acf8 100644
--- a/clang/include/clang/AST/Type.h
+++ b/clang/include/clang/AST/Type.h
@@ -61,6 +61,7 @@ class BTFTypeTagAttr;
 class ExtQuals;
 class QualType;
 class ConceptDecl;
+class ValueDecl;
 class TagDecl;
 class TemplateParameterList;
 class Type;
@@ -2000,6 +2001,21 @@ class alignas(TypeAlignment) Type : public ExtQualsTypeCommonBase {
     unsigned NumExpansions;
   };
 
+  class CountAttributedTypeBitfields {
+    friend class CountAttributedType;
+
+    LLVM_PREFERRED_TYPE(TypeBitfields)
+    unsigned : NumTypeBits;
+
+    /// The limit is 15.
+    static constexpr unsigned NumCoupledDeclsBits = 4;
+    unsigned NumCoupledDecls : NumCoupledDeclsBits;
+    LLVM_PREFERRED_TYPE(bool)
+    unsigned CountInBytes : 1;
+    LLVM_PREFERRED_TYPE(bool)
+    unsigned OrNull : 1;
+  };
+
   union {
     TypeBitfields TypeBits;
     ArrayTypeBitfields ArrayTypeBits;
@@ -2022,6 +2038,7 @@ class alignas(TypeAlignment) Type : public ExtQualsTypeCommonBase {
     DependentTemplateSpecializationTypeBitfields
       DependentTemplateSpecializationTypeBits;
     PackExpansionTypeBitfields PackExpansionTypeBits;
+    CountAttributedTypeBitfields CountAttributedTypeBits;
   };
 
 private:
@@ -2719,6 +2736,14 @@ template <> const TemplateSpecializationType *Type::getAs() const;
 /// until it reaches an AttributedType or a non-sugared type.
 template <> const AttributedType *Type::getAs() const;
 
+/// This will check for a BoundsAttributedType by removing any existing
+/// sugar until it reaches a BoundsAttributedType or a non-sugared type.
+template <> const BoundsAttributedType *Type::getAs() const;
+
+/// This will check for a CountAttributedType by removing any existing
+/// sugar until it reaches a CountAttributedType or a non-sugared type.
+template <> const CountAttributedType *Type::getAs() const;
+
 // We can do canonical leaf types faster, because we don't have to
 // worry about preserving child type decoration.
 #define TYPE(Class, Base)
@@ -2917,6 +2942,133 @@ class PointerType : public Type, public llvm::FoldingSetNode {
   static bool classof(const Type *T) { return T->getTypeClass() == Pointer; }
 };
 
+/// [BoundsSafety] Represents a declaration referenced by the argument of
+/// the `counted_by` attribute and the like.
+class TypeCoupledDeclRefInfo {
+public:
+  using BaseTy = llvm::PointerIntPair<ValueDecl *, 1, unsigned>;
+
+private:
+  enum {
+    DerefShift = 0,
+    DerefMask = 1,
+  };
+  BaseTy Data;
+
+public:
+  /// \p D is a declaration referenced by the argument of the attribute.
+  /// \p Deref indicates whether \p D is referenced in dereferenced form,
+  /// e.g., \p Deref is true for `*n` in `int *__counted_by(*n)`.
+  TypeCoupledDeclRefInfo(ValueDecl *D = nullptr, bool Deref = false);
+
+  bool isDeref() const;
+  ValueDecl *getDecl() const;
+  unsigned getInt() const;
+  void *getOpaqueValue() const;
+  bool operator==(const TypeCoupledDeclRefInfo &Other) const;
+  void setFromOpaqueValue(void *V);
+};
+
+/// [BoundsSafety] Represents a parent type class for CountAttributedType and
+/// similar sugar types that will be introduced to represent a type with a
+/// bounds attribute.
+///
+/// Provides a common interface to navigate declarations referred to by the
+/// bounds expression.
+
+class BoundsAttributedType : public Type, public llvm::FoldingSetNode {
+  QualType WrappedTy;
+
+protected:
+  ArrayRef<TypeCoupledDeclRefInfo> Decls; // stored in trailing objects
+
+  BoundsAttributedType(TypeClass TC, QualType Wrapped, QualType Canon);
+
+public:
+  bool isSugared() const { return true; }
+  QualType desugar() const { return WrappedTy; }
+
+  using decl_iterator = const TypeCoupledDeclRefInfo *;
+  using decl_range = llvm::iterator_range<decl_iterator>;
+
+  decl_iterator dependent_decl_begin() const { return Decls.begin(); }
+  decl_iterator dependent_decl_end() const { return Decls.end(); }
+
+  unsigned getNumCoupledDecls() const { return Decls.size(); }
+
+  decl_range dependent_decls() const {
+    return decl_range(dependent_decl_begin(), dependent_decl_end());
+  }
+
+  ArrayRef<TypeCoupledDeclRefInfo> getCoupledDecls() const {
+    return {dependent_decl_begin(), dependent_decl_end()};
+  }
+
+  bool referencesFieldDecls() const;
+
+  static bool classof(const Type *T) {
+    switch (T->getTypeClass()) {
+    case CountAttributed:
+      return true;
+    default:
+      return false;
+    }
+  }
+};
+
+/// Represents a sugar type with `__counted_by` or `__sized_by` annotations,
+/// including their `_or_null` variants.
+class CountAttributedType final
+    : public BoundsAttributedType,
+      public llvm::TrailingObjects<CountAttributedType,
+                                   TypeCoupledDeclRefInfo> {
+  friend class ASTContext;
+
+  Expr *CountExpr;
+  /// \p CountExpr represents the argument of __counted_by or the like.
+  /// \p CountInBytes indicates that \p CountExpr is a byte count (i.e.,
+  /// __sized_by(_or_null)). \p OrNull means it's an or_null variant (i.e.,
+  /// __counted_by_or_null or __sized_by_or_null). \p CoupledDecls contains
+  /// the list of declarations referenced by \p CountExpr, which the type
+  /// depends on for the bounds information.
+  CountAttributedType(QualType Wrapped, QualType Canon, Expr *CountExpr,
+                      bool CountInBytes, bool OrNull,
+                      ArrayRef<TypeCoupledDeclRefInfo> CoupledDecls);
+
+  unsigned numTrailingObjects(OverloadToken<TypeCoupledDeclRefInfo>) const {
+    return CountAttributedTypeBits.NumCoupledDecls;
+  }
+
+public:
+  enum DynamicCountPointerKind {
+    CountedBy = 0,
+    SizedBy,
+    CountedByOrNull,
+    SizedByOrNull,
+  };
+
+  Expr *getCountExpr() const { return CountExpr; }
+  bool isCountInBytes() const { return CountAttributedTypeBits.CountInBytes; }
+  bool isOrNull() const { return CountAttributedTypeBits.OrNull; }
+
+  DynamicCountPointerKind getKind() const {
+    if (isOrNull())
+      return isCountInBytes() ? SizedByOrNull : CountedByOrNull;
+    return isCountInBytes() ? SizedBy : CountedBy;
+  }
+
+  void Profile(llvm::FoldingSetNodeID &ID) {
+    Profile(ID, desugar(), CountExpr, isCountInBytes(), isOrNull());
+  }
+
+  static void Profile(llvm::FoldingSetNodeID &ID, QualType WrappedTy,
+                      Expr *CountExpr, bool CountInBytes, bool OrNull);
+
+  static bool classof(const Type *T) {
+    return T->getTypeClass() == CountAttributed;
+  }
+};
+
 /// Represents a type which was implicitly adjusted by the semantic
 /// engine for arbitrary reasons.  For example, array and function types can
 /// decay, and function types can have their calling conventions adjusted.
diff --git a/clang/include/clang/AST/TypeLoc.h b/clang/include/clang/AST/TypeLoc.h
index 471deb14aba51..fa3cd46ffd3fc 100644
--- a/clang/include/clang/AST/TypeLoc.h
+++ b/clang/include/clang/AST/TypeLoc.h
@@ -1110,6 +1110,32 @@ class ObjCInterfaceTypeLoc : public ConcreteTypeLoc<ObjCObjectTypeLoc,
   }
 };
 
+struct BoundsAttributedLocInfo {};
+class BoundsAttributedTypeLoc
+    : public ConcreteTypeLoc<UnqualTypeLoc, BoundsAttributedTypeLoc,
+                             BoundsAttributedType, BoundsAttributedLocInfo> {
+public:
+  TypeLoc getInnerLoc() const { return this->getInnerTypeLoc(); }
+  QualType getInnerType() const { return this->getTypePtr()->desugar(); }
+  void initializeLocal(ASTContext &Context, SourceLocation Loc) {
+    // nothing to do
+  }
+  unsigned getLocalDataAlignment() const { return 1; }
+  unsigned getLocalDataSize() const { return 0; }
+};
+
+class CountAttributedTypeLoc final
+    : public InheritingConcreteTypeLoc<BoundsAttributedTypeLoc,
+                                       CountAttributedTypeLoc,
+                                       CountAttributedType> {
+public:
+  Expr *getCountExpr() const { return getTypePtr()->getCountExpr(); }
+  bool isCountInBytes() const { return getTypePtr()->isCountInBytes(); }
+  bool isOrNull() const { return getTypePtr()->isOrNull(); }
+
+  SourceRange getLocalSourceRange() const;
+};
+
 struct MacroQualifiedLocInfo {
   SourceLocation ExpansionLoc;
 };
diff --git a/clang/include/clang/AST/TypeProperties.td b/clang/include/clang/AST/TypeProperties.td
index 682c869b0c584..8697903d8d7f6 100644
--- a/clang/include/clang/AST/TypeProperties.td
+++ b/clang/include/clang/AST/TypeProperties.td
@@ -25,6 +25,25 @@ let Class = PointerType in {
   def : Creator<[{ return ctx.getPointerType(pointeeType); }]>;
 }
 
+let Class = CountAttributedType in {
+  def : Property<"WrappedTy", QualType> {
+    let Read = [{ node->desugar() }];
+  }
+  def : Property<"CountExpr", ExprRef> {
+    let Read = [{ node->getCountExpr() }];
+  }
+  def : Property<"CountInBytes", Bool> {
+    let Read = [{ node->isCountInBytes() }];
+  }
+  def : Property<"OrNull", Bool> {
+    let Read = [{ node->isOrNull() }];
+  }
+  def : Property<"CoupledDecls", Array<TypeCoupledDeclRefInfo>> {
+    let Read = [{ node->getCoupledDecls() }];
+  }
+  def : Creator<[{ return ctx.getCountAttributedType(WrappedTy, CountExpr, CountInBytes, OrNull, CoupledDecls); }]>;
+}
+
 let Class = AdjustedType in {
   def : Property<"originalType", QualType> {
     let Read = [{ node->getOriginalType() }];
diff --git a/clang/include/clang/Basic/TypeNodes.td b/clang/include/clang/Basic/TypeNodes.td
index 649b071cebb94..da733716fa1f6 100644
--- a/clang/include/clang/Basic/TypeNodes.td
+++ b/clang/include/clang/Basic/TypeNodes.td
@@ -107,6 +107,8 @@ def ObjCTypeParamType : TypeNode<Type>, NeverCanonical;
 def ObjCObjectType : TypeNode<Type>;
 def ObjCInterfaceType : TypeNode<ObjCObjectType>, LeafType;
 def ObjCObjectPointerType : TypeNode<Type>;
+def BoundsAttributedType : TypeNode<Type, 1>;
+def CountAttributedType : TypeNode<BoundsAttributedType>, NeverCanonical;
 def PipeType : TypeNode<Type>;
 def AtomicType : TypeNode<Type>;
 def BitIntType : TypeNode<Type>;
diff --git a/clang/include/clang/Serialization/ASTRecordReader.h b/clang/include/clang/Serialization/ASTRecordReader.h
index 80a1359fd16aa..5d3e95cb5d630 100644
--- a/clang/include/clang/Serialization/ASTRecordReader.h
+++ b/clang/include/clang/Serialization/ASTRecordReader.h
@@ -221,6 +221,8 @@ class ASTRecordReader
     return Reader->ReadSelector(*F, Record, Idx);
   }
 
+  TypeCoupledDeclRefInfo readTypeCoupledDeclRefInfo();
+
   /// Read a declaration name, advancing Idx.
   // DeclarationName readDeclarationName(); (inherited)
   DeclarationNameLoc readDeclarationNameLoc(DeclarationName Name);
diff --git a/clang/include/clang/Serialization/ASTRecordWriter.h b/clang/include/clang/Serialization/ASTRecordWriter.h
index 9a64735c9fa55..e007d4a70843a 100644
--- a/clang/include/clang/Serialization/ASTRecordWriter.h
+++ b/clang/include/clang/Serialization/ASTRecordWriter.h
@@ -141,6 +141,11 @@ class ASTRecordWriter
     AddSourceLocation(Loc);
   }
 
+  void writeTypeCoupledDeclRefInfo(TypeCoupledDeclRefInfo Info) {
+    writeDeclRef(Info.getDecl());
+    writeBool(Info.isDeref());
+  }
+
   /// Emit a source range.
   void AddSourceRange(SourceRange Range, LocSeq *Seq = nullptr) {
     return Writer->AddSourceRange(Range, *Record, Seq);
diff --git a/clang/include/clang/Serialization/TypeBitCodes.def b/clang/include/clang/Serialization/TypeBitCodes.def
index 89ae1a2fa3954..c04dacdaa1582 100644
--- a/clang/include/clang/Serialization/TypeBitCodes.def
+++ b/clang/include/clang/Serialization/TypeBitCodes.def
@@ -64,5 +64,6 @@ TYPE_BIT_CODE(ConstantMatrix, CONSTANT_MATRIX, 52)
 TYPE_BIT_CODE(DependentSizedMatrix, DEPENDENT_SIZE_MATRIX, 53)
 TYPE_BIT_CODE(Using, USING, 54)
 TYPE_BIT_CODE(BTFTagAttributed, BTFTAG_ATTRIBUTED, 55)
+TYPE_BIT_CODE(CountAttributed, COUNT_ATTRIBUTED, 56)
 
 #undef TYPE_BIT_CODE
diff --git a/clang/lib/AST/ASTContext.cpp b/clang/lib/AST/ASTContext.cpp
index d9cefcaa84d7e..43d2da0e60f36 100644
--- a/clang/lib/AST/ASTContext.cpp
+++ b/clang/lib/AST/ASTContext.cpp
@@ -2341,6 +2341,9 @@ TypeInfo ASTContext::getTypeInfoImpl(const Type *T) const {
     return getTypeInfo(
                   cast<AttributedType>(T)->getEquivalentType().getTypePtr());
 
+  case Type::CountAttributed:
+    return getTypeInfo(cast<CountAttributedType>(T)->desugar().getTypePtr());
+
   case Type::BTFTagAttributed:
     return getTypeInfo(
         cast<BTFTagAttributedType>(T)->getWrappedType().getTypePtr());
@@ -3106,6 +3109,33 @@ QualType ASTContext::removePtrSizeAddrSpace(QualType T) const {
   return T;
 }
 
+QualType ASTContext::getCountAttributedType(
+    QualType WrappedTy, Expr *CountExpr, bool CountInBytes, bool OrNull,
+    ArrayRef<TypeCoupledDeclRefInfo> DependentDecls) const {
+  assert(WrappedTy->isPointerType() || WrappedTy->isArrayType());
+  assert(CountExpr->isPRValue());
+
+  llvm::FoldingSetNodeID ID;
+  CountAttributedType::Profile(ID, WrappedTy, CountExpr, CountInBytes, OrNull);
+
+  void *InsertPos = nullptr;
+  CountAttributedType *CATy =
+      CountAttributedTypes.FindNodeOrInsertPos(ID, InsertPos);
+  if (CATy)
+    return QualType(CATy, 0);
+
+  QualType CanonTy = getCanonicalType(WrappedTy);
+  size_t Size = CountAttributedType::totalSizeToAlloc<TypeCoupledDeclRefInfo>(
+      DependentDecls.size());
+  CATy = (CountAttributedType *)Allocate(Size, TypeAlignment);
+  new (CATy) CountAttributedType(WrappedTy, CanonTy, CountExpr, CountInBytes,
+                                 OrNull, DependentDecls);
+  Types.push_back(CATy);
+  CountAttributedTypes.InsertNode(CATy, InsertPos);
+
+  return QualType(CATy, 0);
+}
+
 const FunctionType *ASTContext::adjustFunctionType(const FunctionType *T,
                                                    FunctionType::ExtInfo Info) {
   if (T->getExtInfo() == Info)
@@ -13158,6 +13188,32 @@ static QualType getCommonSugarTypeNode(ASTContext &Ctx, const Type *X,
       return QualType();
     return Ctx.getUsingType(CD, Ctx.getQualifiedType(Underlying));
   }
+  case Type::CountAttributed: {
+    const auto *DX = cast<CountAttributedType>(X),
+               *DY = cast<CountAttributedType>(Y);
+    if (DX->isCountInBytes() != DY->isCountInBytes())
+      return QualType();
+    if (DX->isOrNull() != DY->isOrNull())
+      return QualType();
+    const auto CEX = DX->getCountExpr();
+    const auto CEY = DY->getCountExpr();
+    const auto CDX = DX->getCoupledDecls();
+    if (Ctx.hasSameExpr(CEX, CEY))
+      return Ctx.getCountAttributedType(Ctx.getQualifiedType(Underlying), CEX,
+                                        DX->isCountInBytes(), DX->isOrNull(),
+                                        CDX);
+    if (!CEX->isIntegerConstantExpr(Ctx) || !CEY->isIntegerConstantExpr(Ctx))
+      return QualType();
+    // Distinct expressions may fold to the same constant. Compare by value:
+    // the constants may differ in bit width or signedness (e.g. 4 vs 4UL).
+    const auto VX = CEX->getIntegerConstantExpr(Ctx);
+    const auto VY = CEY->getIntegerConstantExpr(Ctx);
+    if (!llvm::APSInt::isSameValue(*VX, *VY))
+      return QualType();
+    return Ctx.getCountAttributedType(Ctx.getQualifiedType(Underlying), CEX,
+                                      DX->isCountInBytes(), DX->isOrNull(),
+                                      CDX);
+  }
   }
   llvm_unreachable("Unhandled Type Class");
 }
diff --git a/clang/lib/AST/ASTImporter.cpp b/clang/lib/AST/ASTImporter.cpp
index b762d6a4cd380..6bad0b701cea1 100644
--- a/clang/lib/AST/ASTImporter.cpp
+++ b/clang/lib/AST/ASTImporter.cpp
@@ -1504,6 +1504,18 @@ ExpectedType ASTNodeImporter::VisitAttributedType(const AttributedType *T) {
       *ToModifiedTypeOrErr, *ToEquivalentTypeOrErr);
 }
 
+ExpectedType
+ASTNodeImporter::VisitCountAttributedType(const CountAttributedType *T) {
+  Error Err = Error::success();
+  QualType ToTy = importChecked(Err, T->desugar());
+  Expr *ToCount = importChecked(Err, T->getCountExpr());
+  if (Err)
+    return std::move(Err);
+  // FIXME: Import CoupledDecls; they still refer to source-context decls.
+  return Importer.getToContext().getCountAttributedType(
+      ToTy, ToCount, T->isCountInBytes(), T->isOrNull(), T->getCoupledDecls());
+}
+
 ExpectedType ASTNodeImporter::VisitTemplateTypeParmType(
     const TemplateTypeParmType *T) {
   Expected<TemplateTypeParmDecl *> ToDeclOrErr = import(T->getDecl());
diff --git a/clang/lib/AST/ASTStructuralEquivalence.cpp b/clang/lib/AST/ASTStructuralEquivalence.cpp
index a9e0d1698a917..c0e731af0ac88 100644
--- a/clang/lib/AST/ASTStructuralEquivalence.cpp
+++ b/clang/lib/AST/ASTStructuralEquivalence.cpp
@@ -1008,6 +1008,13 @@ static bool IsStructurallyEquivalent(StructuralEquivalenceContext &Context,
       return false;
     break;
 
+  case Type::CountAttributed:
+    if (!IsStructurallyEquivalent(Context,
+                                  cast<CountAttributedType>(T1)->desugar(),
+                                  cast<CountAttributedType>(T2)->desugar()))
+      return false;
+    break;
+
   case Type::BTFTagAttributed:
     if (!IsStructurallyEquivalent(
             Context, cast<BTFTagAttributedType>(T1)->getWrappedType(),
diff --git a/clang/lib/AST/ItaniumMangle.cpp b/clang/lib/AST/ItaniumMangle.cpp
index b1678479888eb..d2464d9121f1a 100644
--- a/clang/lib/AST/ItaniumMangle.cpp
+++ b/clang/lib/AST/ItaniumMangle.cpp
@@ -2431,6 +2431,7 @@ bool CXXNameMangler::mangleUnresolvedTypeOrSimpleId(QualType Ty,
   case Type::MacroQualified:
   case Type::BitInt:
   case Type::DependentBitInt:
+  case Type::CountAttributed:
     llvm_unreachable("type is illegal as a nested name specifier");
 
   case Type::SubstTemplateTypeParmPack:
diff --git a/clang/lib/AST/Type.cpp b/clang/lib/AST/Type.cpp
index b419fc8836b03..3c11b4be12559 100644
--- a/clang/lib/AST/Type.cpp
+++ b/clang/lib/AST/Type.cpp
@@ -385,6 +385,26 @@ void DependentBitIntType::Profile(llvm::FoldingSetNodeID &ID,
   NumBitsExpr->Profile(ID, Context, true);
 }
 
+bool BoundsAttributedType::referencesFieldDecls() const {
+  for (const auto &Decl : dependent_decls())
+    if (isa<FieldDecl>(Decl.getDecl()))
+      return true;
+  return false;
+}
+
+void CountAttributedType::Profile(llvm::FoldingSetNodeID &ID,
+                                  QualType WrappedTy, Expr *CountExpr,
+                                  bool CountInBytes, bool OrNull) {
+  ID.AddPointer(WrappedTy.getAsOpaquePtr());
+  ID.AddBoolean(CountInBytes);
+  ID.AddBoolean(OrNull);
+  // We profile it as a pointer as the StmtProfiler considers parameter
+  // expressions on function declaration and function definition as the
+  // same, resulting in count expression being evaluated with ParamDecl
+  // not in the function scope.
+  ID.AddPointer(CountExpr);
+}
+
 /// getArrayElementTypeNoTypeQual - If this is an array type, return the
 /// element type of the array, potentially with type qualifiers missing.
 /// This method should never be used when type qualifiers are meaningful.
@@ -3685,6 +3705,43 @@ void FunctionProtoType::Profile(llvm::FoldingSetNodeID &ID,
           getExtProtoInfo(), Ctx, isCanonicalUnqualified());
 }
 
+TypeCoupledDeclRefInfo::TypeCoupledDeclRefInfo(ValueDecl *D, bool Deref)
+    : Data(D, Deref << DerefShift) {}
+
+bool TypeCoupledDeclRefInfo::isDeref() const {
+  return Data.getInt() & DerefMask;
+}
+ValueDecl *TypeCoupledDeclRefInfo::getDecl() const { return Data.getPointer(); }
+unsigned TypeCoupledDeclRefInfo::getInt() const { return Data.getInt(); }
+void *TypeCoupledDeclRefInfo::getOpaqueValue() const {
+  return Data.getOpaqueValue();
+}
+bool TypeCoupledDeclRefInfo::operator==(
+    const TypeCoupledDeclRefInfo &Other) const {
+  return getOpaqueValue() == Other.getOpaqueValue();
+}
+void TypeCoupledDeclRefInfo::setFromOpaqueValue(void *V) {
+  Data.setFromOpaqueValue(V);
+}
+
+BoundsAttributedType::BoundsAttributedType(TypeClass TC, QualType Wrapped,
+                                           QualType Canon)
+    : Type(TC, Canon, Wrapped->getDependence()), WrappedTy(Wrapped) {}
+
+CountAttributedType::CountAttributedType(
+    QualType Wrapped, QualType Canon, Expr *CountExpr, bool CountInBytes,
+    bool OrNull, ArrayRef<TypeCoupledDeclRefInfo> CoupledDecls)
+    : BoundsAttributedType(CountAttributed, Wrapped, Canon),
+      CountExpr(CountExpr) {
+  CountAttributedTypeBits.NumCoupledDecls = CoupledDecls.size();
+  assert(CountAttributedTypeBits.NumCoupledDecls == CoupledDecls.size() &&
+         "too many coupled decls for the NumCoupledDecls bitfield");
+  CountAttributedTypeBits.CountInBytes = CountInBytes;
+  CountAttributedTypeBits.OrNull = OrNull;
+  auto *DeclSlot = getTrailingObjects<TypeCoupledDeclRefInfo>();
+  Decls = llvm::ArrayRef(DeclSlot, llvm::copy(CoupledDecls, DeclSlot));
+}
+
 TypedefType::TypedefType(TypeClass tc, const TypedefNameDecl *D,
                          QualType Underlying, QualType can)
     : Type(tc, can, toSemanticDependence(can->getDependence())),
diff --git a/clang/lib/AST/TypeLoc.cpp b/clang/lib/AST/TypeLoc.cpp
index e12b9b50f6e72..c18e2e0b45af6 100644
--- a/clang/lib/AST/TypeLoc.cpp
+++ b/clang/lib/AST/TypeLoc.cpp
@@ -516,6 +516,10 @@ SourceRange AttributedTypeLoc::getLocalSourceRange() const {
   return getAttr() ? getAttr()->getRange() : SourceRange();
 }
 
+SourceRange CountAttributedTypeLoc::getLocalSourceRange() const {
+  return getCountExpr() ? getCountExpr()->getSourceRange() : SourceRange();
+}
+
 SourceRange BTFTagAttributedTypeLoc::getLocalSourceRange() const {
   return getAttr() ? getAttr()->getRange() : SourceRange();
 }
diff --git a/clang/lib/AST/TypePrinter.cpp b/clang/lib/AST/TypePrinter.cpp
index f694124292736..ab4d74385c1d8 100644
--- a/clang/lib/AST/TypePrinter.cpp
+++ b/clang/lib/AST/TypePrinter.cpp
@@ -286,6 +286,7 @@ bool TypePrinter::canPrefixQualifiers(const Type *T,
     case Type::PackExpansion:
     case Type::SubstTemplateTypeParm:
     case Type::MacroQualified:
+    case Type::CountAttributed:
       CanPrefixQualifiers = false;
       break;
 
@@ -1682,6 +1683,28 @@ void TypePrinter::printPackExpansionAfter(const PackExpansionType *T,
   OS << "...";
 }
 
+void TypePrinter::printCountAttributedBefore(const CountAttributedType *T,
+                                             raw_ostream &OS) {
+  printBefore(T->desugar(), OS);
+
+  if (T->isCountInBytes() && T->isOrNull())
+    OS << " __sized_by_or_null(";
+  else if (T->isCountInBytes())
+    OS << " __sized_by(";
+  else if (T->isOrNull())
+    OS << " __counted_by_or_null(";
+  else
+    OS << " __counted_by(";
+  if (T->getCountExpr())
+    T->getCountExpr()->printPretty(OS, nullptr, Policy);
+  OS << ')';
+}
+
+void TypePrinter::printCountAttributedAfter(const CountAttributedType *T,
+                                            raw_ostream &OS) {
+  printAfter(T->desugar(), OS);
+}
+
 void TypePrinter::printAttributedBefore(const AttributedType *T,
                                         raw_ostream &OS) {
   // FIXME: Generate this with TableGen.
diff --git a/clang/lib/CodeGen/CGDebugInfo.cpp b/clang/lib/CodeGen/CGDebugInfo.cpp
index 236d53bee4e8f..653e04a5bc870 100644
--- a/clang/lib/CodeGen/CGDebugInfo.cpp
+++ b/clang/lib/CodeGen/CGDebugInfo.cpp
@@ -3644,6 +3644,7 @@ llvm::DIType *CGDebugInfo::CreateTypeNode(QualType Ty, llvm::DIFile *Unit) {
   case Type::TemplateSpecialization:
     return CreateType(cast<TemplateSpecializationType>(Ty), Unit);
 
+  case Type::CountAttributed:
   case Type::Auto:
   case Type::Attributed:
   case Type::BTFTagAttributed:
diff --git a/clang/lib/CodeGen/CodeGenFunction.cpp b/clang/lib/CodeGen/CodeGenFunction.cpp
index 2673e4a5cee7b..730632d1ed8eb 100644
--- a/clang/lib/CodeGen/CodeGenFunction.cpp
+++ b/clang/lib/CodeGen/CodeGenFunction.cpp
@@ -2399,6 +2399,7 @@ void CodeGenFunction::EmitVariablyModifiedType(QualType type) {
     case Type::BTFTagAttributed:
     case Type::SubstTemplateTypeParm:
     case Type::MacroQualified:
+    case Type::CountAttributed:
       // Keep walking after single level desugaring.
       type = type.getSingleStepDesugaredType(getContext());
       break;
diff --git a/clang/lib/Sema/SemaExpr.cpp b/clang/lib/Sema/SemaExpr.cpp
index 2f48ea237cdfa..ca8c26afb48cd 100644
--- a/clang/lib/Sema/SemaExpr.cpp
+++ b/clang/lib/Sema/SemaExpr.cpp
@@ -4706,6 +4706,7 @@ static void captureVariablyModifiedType(ASTContext &Context, QualType T,
     case Type::BTFTagAttributed:
     case Type::SubstTemplateTypeParm:
     case Type::MacroQualified:
+    case Type::CountAttributed:
       // Keep walking after single level desugaring.
       T = T.getSingleStepDesugaredType(Context);
       break;
diff --git a/clang/lib/Sema/TreeTransform.h b/clang/lib/Sema/TreeTransform.h
index 1a1bc87d2b320..27897bb87fb9b 100644
--- a/clang/lib/Sema/TreeTransform.h
+++ b/clang/lib/Sema/TreeTransform.h
@@ -7119,6 +7119,13 @@ QualType TreeTransform<Derived>::TransformAttributedType(TypeLocBuilder &TLB,
       });
 }
 
+template <typename Derived>
+QualType TreeTransform<Derived>::TransformCountAttributedType(
+    TypeLocBuilder &TLB, CountAttributedTypeLoc TL) {
+  // TODO
+  llvm_unreachable("Unexpected TreeTransform for CountAttributedType");
+}
+
 template <typename Derived>
 QualType TreeTransform<Derived>::TransformBTFTagAttributedType(
     TypeLocBuilder &TLB, BTFTagAttributedTypeLoc TL) {
diff --git a/clang/lib/Serialization/ASTReader.cpp b/clang/lib/Serialization/ASTReader.cpp
index 287f9a0300be5..9f1c06f542a66 100644
--- a/clang/lib/Serialization/ASTReader.cpp
+++ b/clang/lib/Serialization/ASTReader.cpp
@@ -6990,6 +6990,10 @@ void TypeLocReader::VisitAttributedTypeLoc(AttributedTypeLoc TL) {
   TL.setAttr(ReadAttr());
 }
 
+void TypeLocReader::VisitCountAttributedTypeLoc(CountAttributedTypeLoc TL) {
+  // nothing to do
+}
+
 void TypeLocReader::VisitBTFTagAttributedTypeLoc(BTFTagAttributedTypeLoc TL) {
   // Nothing to do.
 }
@@ -9110,6 +9114,10 @@ DeclarationNameInfo ASTRecordReader::readDeclarationNameInfo() {
   return NameInfo;
 }
 
+TypeCoupledDeclRefInfo ASTRecordReader::readTypeCoupledDeclRefInfo() {
+  return TypeCoupledDeclRefInfo(readDeclAs<ValueDecl>(), readBool());
+}
+
 void ASTRecordReader::readQualifierInfo(QualifierInfo &Info) {
   Info.QualifierLoc = readNestedNameSpecifierLoc();
   unsigned NumTPLists = readInt();
diff --git a/clang/lib/Serialization/ASTWriter.cpp b/clang/lib/Serialization/ASTWriter.cpp
index 9950fa9c08faa..5b37146948326 100644
--- a/clang/lib/Serialization/ASTWriter.cpp
+++ b/clang/lib/Serialization/ASTWriter.cpp
@@ -510,6 +510,10 @@ void TypeLocWriter::VisitAttributedTypeLoc(AttributedTypeLoc TL) {
   Record.AddAttr(TL.getAttr());
 }
 
+void TypeLocWriter::VisitCountAttributedTypeLoc(CountAttributedTypeLoc TL) {
+  // nothing to do
+}
+
 void TypeLocWriter::VisitBTFTagAttributedTypeLoc(BTFTagAttributedTypeLoc TL) {
   // Nothing to do.
 }
diff --git a/clang/tools/libclang/CIndex.cpp b/clang/tools/libclang/CIndex.cpp
index 841522a0f4788..16a1469d59f66 100644
--- a/clang/tools/libclang/CIndex.cpp
+++ b/clang/tools/libclang/CIndex.cpp
@@ -1773,6 +1773,10 @@ bool CursorVisitor::VisitAttributedTypeLoc(AttributedTypeLoc TL) {
   return Visit(TL.getModifiedLoc());
 }
 
+bool CursorVisitor::VisitCountAttributedTypeLoc(CountAttributedTypeLoc TL) {
+  return Visit(TL.getInnerLoc());
+}
+
 bool CursorVisitor::VisitBTFTagAttributedTypeLoc(BTFTagAttributedTypeLoc TL) {
   return Visit(TL.getWrappedLoc());
 }

>From e65a99a00faba70960e75b4b8edb5acecb675197 Mon Sep 17 00:00:00 2001
From: Yeoul Na <yeoul_na at apple.com>
Date: Wed, 20 Dec 2023 16:09:17 +0900
Subject: [PATCH 02/12] Adjust CountedBy attribute definition and parse
 CountAttributedType

---
 clang/include/clang/Basic/Attr.td             |   7 +-
 .../clang/Basic/DiagnosticSemaKinds.td        |  14 +-
 clang/include/clang/Parse/Parser.h            |   7 +
 clang/include/clang/Sema/Sema.h               |  12 +-
 clang/lib/AST/Type.cpp                        |   8 +
 clang/lib/CodeGen/CGExpr.cpp                  |  35 +---
 clang/lib/Parse/ParseDecl.cpp                 |  50 +++++
 clang/lib/Sema/SemaDecl.cpp                   |   6 -
 clang/lib/Sema/SemaDeclAttr.cpp               | 173 +++++++-----------
 clang/lib/Sema/SemaExpr.cpp                   |  22 ++-
 clang/lib/Sema/SemaType.cpp                   |  11 ++
 11 files changed, 198 insertions(+), 147 deletions(-)

diff --git a/clang/include/clang/Basic/Attr.td b/clang/include/clang/Basic/Attr.td
index a03b0e44e15f7..af2680be0da90 100644
--- a/clang/include/clang/Basic/Attr.td
+++ b/clang/include/clang/Basic/Attr.td
@@ -4373,10 +4373,11 @@ def CodeAlign: StmtAttr {
   }];
 }
 
-def CountedBy : InheritableAttr {
+def CountedBy : DeclOrTypeAttr {
   let Spellings = [Clang<"counted_by">];
-  let Subjects = SubjectList<[Field]>;
-  let Args = [IdentifierArgument<"CountedByField">];
+  let Subjects = SubjectList<[Field], ErrorDiag>;
+  let Args = [ExprArgument<"Count">, IntArgument<"NestedLevel">];
+  let ParseArgumentsAsUnevaluated = 1;
   let Documentation = [CountedByDocs];
   let LangOpts = [COnly];
   // FIXME: This is ugly. Let using a DeclArgument would be nice, but a Decl
diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td
index 1a79892e40030..ae8b90964c3fc 100644
--- a/clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -6445,12 +6445,18 @@ def err_flexible_array_count_not_in_same_struct : Error<
   "'counted_by' field %0 isn't within the same struct as the flexible array">;
 def err_counted_by_attr_not_on_flexible_array_member : Error<
   "'counted_by' only applies to C99 flexible array members">;
-def err_counted_by_attr_refers_to_flexible_array : Error<
-  "'counted_by' cannot refer to the flexible array %0">;
+def err_counted_by_attr_refer_to_itself : Error<
+  "'counted_by' cannot refer to the flexible array member %0">;
 def err_counted_by_must_be_in_structure : Error<
   "field %0 in 'counted_by' not inside structure">;
-def err_flexible_array_counted_by_attr_field_not_integer : Error<
-  "field %0 in 'counted_by' must be a non-boolean integer type">;
+def err_counted_by_attr_argument_not_integer : Error<
+  "'counted_by' requires a non-boolean integer type argument">;
+def err_counted_by_attr_only_support_simple_decl_reference : Error<
+  "'counted_by' argument must be a simple declaration reference">;
+def err_counted_by_attr_in_union : Error<
+  "'counted_by' cannot be applied to a union member">;
+def err_counted_by_attr_refer_to_union : Error<
+  "'counted_by' argument cannot refer to a union member">;
 def note_flexible_array_counted_by_attr_field : Note<
   "field %0 declared here">;
 
diff --git a/clang/include/clang/Parse/Parser.h b/clang/include/clang/Parse/Parser.h
index e50a4d05b4599..f88191ef5f52c 100644
--- a/clang/include/clang/Parse/Parser.h
+++ b/clang/include/clang/Parse/Parser.h
@@ -3054,6 +3054,13 @@ class Parser : public CodeCompletionHandler {
                                  SourceLocation ScopeLoc,
                                  ParsedAttr::Form Form);
 
+  void ParseBoundsAttribute(IdentifierInfo &AttrName,
+                            SourceLocation AttrNameLoc,
+                            ParsedAttributes &Attrs,
+                            IdentifierInfo *ScopeName,
+                            SourceLocation ScopeLoc,
+                            ParsedAttr::Form Form);
+
   void ParseTypeofSpecifier(DeclSpec &DS);
   SourceLocation ParseDecltypeSpecifier(DeclSpec &DS);
   void AnnotateExistingDecltypeSpecifier(const DeclSpec &DS,
diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h
index cf2d4fbe6d3ba..027afb5d703d1 100644
--- a/clang/include/clang/Sema/Sema.h
+++ b/clang/include/clang/Sema/Sema.h
@@ -1341,7 +1341,7 @@ class Sema final {
     /// \brief Describes whether we are in an expression constext which we have
     /// to handle differently.
     enum ExpressionKind {
-      EK_Decltype, EK_TemplateArgument, EK_Other
+      EK_Decltype, EK_TemplateArgument, EK_BoundsAttrArgument, EK_Other
     } ExprContext;
 
     // A context can be nested in both a discarded statement context and
@@ -1436,6 +1436,11 @@ class Sema final {
   std::tuple<MangleNumberingContext *, Decl *>
   getCurrentMangleNumberContext(const DeclContext *DC);
 
+  bool isBoundsAttrArgument() const {
+    return ExprEvalContexts.back().ExprContext == 
+        ExpressionEvaluationContextRecord::ExpressionKind::EK_BoundsAttrArgument; 
+  }
+
 
   /// SpecialMemberOverloadResult - The overloading result for a special member
   /// function.
@@ -2612,6 +2617,9 @@ class Sema final {
   QualType BuiltinChangeSignedness(QualType BaseType, UTTKind UKind,
                                    SourceLocation Loc);
 
+  QualType BuildCountAttributedArrayType(QualType WrappedTy, Expr *CountExpr,
+                              const llvm::SmallVector<TypeCoupledDeclRefInfo, 1> &Decls);
+
   //===--------------------------------------------------------------------===//
   // Symbol table / Decl tracking callbacks: SemaDecl.cpp.
   //
@@ -4799,8 +4807,6 @@ class Sema final {
   bool CheckAlwaysInlineAttr(const Stmt *OrigSt, const Stmt *CurSt,
                              const AttributeCommonInfo &A);
 
-  bool CheckCountedByAttr(Scope *Scope, const FieldDecl *FD);
-
   /// Adjust the calling convention of a method to be the ABI default if it
   /// wasn't specified explicitly.  This handles method types formed from
   /// function type typedefs and typename template arguments.
diff --git a/clang/lib/AST/Type.cpp b/clang/lib/AST/Type.cpp
index 3c11b4be12559..01d5002a90c89 100644
--- a/clang/lib/AST/Type.cpp
+++ b/clang/lib/AST/Type.cpp
@@ -579,6 +579,14 @@ template <> const AttributedType *Type::getAs() const {
   return getAsSugar<AttributedType>(this);
 }
 
+template <> const BoundsAttributedType *Type::getAs() const {
+  return getAsSugar<BoundsAttributedType>(this);
+}
+
+template <> const CountAttributedType *Type::getAs() const {
+  return getAsSugar<CountAttributedType>(this);
+}
+
 /// getUnqualifiedDesugaredType - Pull any qualifiers and syntactic
 /// sugar off the given type.  This should produce an object of the
 /// same dynamic type as the canonical type.
diff --git a/clang/lib/CodeGen/CGExpr.cpp b/clang/lib/CodeGen/CGExpr.cpp
index d12e85b48d0b0..9cf0e8984de08 100644
--- a/clang/lib/CodeGen/CGExpr.cpp
+++ b/clang/lib/CodeGen/CGExpr.cpp
@@ -1136,38 +1136,19 @@ llvm::Value *CodeGenFunction::EmitCountedByFieldExpr(
 }
 
 const FieldDecl *CodeGenFunction::FindCountedByField(const FieldDecl *FD) {
-  if (!FD || !FD->hasAttr<CountedByAttr>())
+  if (!FD)
     return nullptr;
 
-  const auto *CBA = FD->getAttr<CountedByAttr>();
-  if (!CBA)
+  auto *CAT = FD->getType()->getAs<CountAttributedType>();
+  if (!CAT)
     return nullptr;
 
-  auto GetNonAnonStructOrUnion =
-      [](const RecordDecl *RD) -> const RecordDecl * {
-    while (RD && RD->isAnonymousStructOrUnion()) {
-      const auto *R = dyn_cast<RecordDecl>(RD->getDeclContext());
-      if (!R)
-        return nullptr;
-      RD = R;
-    }
-    return RD;
-  };
-  const RecordDecl *EnclosingRD = GetNonAnonStructOrUnion(FD->getParent());
-  if (!EnclosingRD)
-    return nullptr;
-
-  DeclarationName DName(CBA->getCountedByField());
-  DeclContext::lookup_result Lookup = EnclosingRD->lookup(DName);
-
-  if (Lookup.empty())
-    return nullptr;
-
-  const NamedDecl *ND = Lookup.front();
-  if (const auto *IFD = dyn_cast<IndirectFieldDecl>(ND))
-    ND = IFD->getAnonField();
+  auto *CountDRE = cast<DeclRefExpr>(CAT->getCountExpr());
+  auto *CountDecl = CountDRE->getDecl();
+  if (const auto *IFD = dyn_cast<IndirectFieldDecl>(CountDecl))
+    CountDecl = IFD->getAnonField();
 
-  return dyn_cast<FieldDecl>(ND);
+  return dyn_cast<FieldDecl>(CountDecl);
 }
 
 void CodeGenFunction::EmitBoundsCheck(const Expr *E, const Expr *Base,
diff --git a/clang/lib/Parse/ParseDecl.cpp b/clang/lib/Parse/ParseDecl.cpp
index ed684c5d57b1e..1bbccfd47c3a8 100644
--- a/clang/lib/Parse/ParseDecl.cpp
+++ b/clang/lib/Parse/ParseDecl.cpp
@@ -629,6 +629,10 @@ void Parser::ParseGNUAttributeArgs(
     ParseAttributeWithTypeArg(*AttrName, AttrNameLoc, Attrs, ScopeName,
                               ScopeLoc, Form);
     return;
+  } else if (AttrKind == ParsedAttr::AT_CountedBy) {
+    ParseBoundsAttribute(*AttrName, AttrNameLoc, Attrs, ScopeName,
+                         ScopeLoc, Form);
+    return;
   }
 
   // These may refer to the function arguments, but need to be parsed early to
@@ -3161,6 +3165,52 @@ void Parser::ParseAlignmentSpecifier(ParsedAttributes &Attrs,
   }
 }
 
+/// Bounds attributes (e.g., counted_by):
+///   AttrName '(' expression ')'
+void Parser::ParseBoundsAttribute(IdentifierInfo &AttrName, SourceLocation AttrNameLoc, ParsedAttributes &Attrs, IdentifierInfo *ScopeName,
+                                       SourceLocation ScopeLoc,
+                                       ParsedAttr::Form Form) {
+  assert(Tok.is(tok::l_paren) && "Attribute arg list not starting with '('");
+ 
+  BalancedDelimiterTracker Parens(*this, tok::l_paren);
+  Parens.consumeOpen();
+
+  if (Tok.is(tok::r_paren)) {
+    Diag(Tok.getLocation(), diag::err_argument_required_after_attribute);
+    Parens.consumeClose();
+    return;
+  }
+
+  ArgsVector ArgExprs;
+  // Don't evaluate argument when the attribute is ignored.
+  using ExpressionKind =
+      Sema::ExpressionEvaluationContextRecord::ExpressionKind;
+  EnterExpressionEvaluationContext EC(
+      Actions, Sema::ExpressionEvaluationContext::PotentiallyEvaluated, nullptr,
+      ExpressionKind::EK_BoundsAttrArgument);
+
+  ExprResult ArgExpr(
+      Actions.CorrectDelayedTyposInExpr(ParseAssignmentExpression()));
+
+  if (ArgExpr.isInvalid()) {
+    Parens.skipToEnd();
+    return;
+  }
+
+  ArgExprs.push_back(ArgExpr.get());
+  Parens.consumeClose();
+
+  ASTContext &Ctx = Actions.getASTContext();
+
+  ArgExprs.push_back(IntegerLiteral::Create(
+      Ctx, llvm::APInt(Ctx.getTypeSize(Ctx.getSizeType()), 0),
+      Ctx.getSizeType(), SourceLocation()));
+
+  Attrs.addNew(&AttrName, SourceRange(AttrNameLoc, Parens.getCloseLocation()),
+               ScopeName, ScopeLoc, ArgExprs.data(),
+               ArgExprs.size(), Form);
+}
+
 ExprResult Parser::ParseExtIntegerArgument() {
   assert(Tok.isOneOf(tok::kw__ExtInt, tok::kw__BitInt) &&
          "Not an extended int type");
diff --git a/clang/lib/Sema/SemaDecl.cpp b/clang/lib/Sema/SemaDecl.cpp
index e92fd104d78eb..8e46c4984d93d 100644
--- a/clang/lib/Sema/SemaDecl.cpp
+++ b/clang/lib/Sema/SemaDecl.cpp
@@ -2315,12 +2315,6 @@ void Sema::ActOnPopScope(SourceLocation Loc, Scope *S) {
       }
       ShadowingDecls.erase(ShadowI);
     }
-
-    if (!getLangOpts().CPlusPlus && S->isClassScope()) {
-      if (auto *FD = dyn_cast<FieldDecl>(TmpD);
-          FD && FD->hasAttr<CountedByAttr>())
-        CheckCountedByAttr(S, FD);
-    }
   }
 
   llvm::sort(DeclDiags,
diff --git a/clang/lib/Sema/SemaDeclAttr.cpp b/clang/lib/Sema/SemaDeclAttr.cpp
index 1a58cfd8e4179..eb4cd97350c6e 100644
--- a/clang/lib/Sema/SemaDeclAttr.cpp
+++ b/clang/lib/Sema/SemaDeclAttr.cpp
@@ -8460,133 +8460,102 @@ static void handleZeroCallUsedRegsAttr(Sema &S, Decl *D, const ParsedAttr &AL) {
   D->addAttr(ZeroCallUsedRegsAttr::Create(S.Context, Kind, AL));
 }
 
-static void handleCountedByAttr(Sema &S, Decl *D, const ParsedAttr &AL) {
-  if (!AL.isArgIdent(0)) {
-    S.Diag(AL.getLoc(), diag::err_attribute_argument_type)
-        << AL << AANT_ArgumentIdentifier;
-    return;
-  }
 
-  IdentifierLoc *IL = AL.getArgAsIdent(0);
-  CountedByAttr *CBA =
-      ::new (S.Context) CountedByAttr(S.Context, AL, IL->Ident);
-  CBA->setCountedByFieldLoc(IL->Loc);
-  D->addAttr(CBA);
+static const RecordDecl *GetEnclosingNamedOrTopAnonRecord(const FieldDecl *FD) {
+  const auto *RD = FD->getParent();
+  while (RD && RD->isAnonymousStructOrUnion()) {
+    const auto *Parent = dyn_cast<RecordDecl>(RD->getParent());
+    if (!Parent)
+      break;
+    RD = Parent;
+  }
+  return RD;
 }
 
-static const FieldDecl *
-FindFieldInTopLevelOrAnonymousStruct(const RecordDecl *RD,
-                                     const IdentifierInfo *FieldName) {
-  for (const Decl *D : RD->decls()) {
-    if (const auto *FD = dyn_cast<FieldDecl>(D))
-      if (FD->getName() == FieldName->getName())
-        return FD;
-
-    if (const auto *R = dyn_cast<RecordDecl>(D))
-      if (const FieldDecl *FD =
-              FindFieldInTopLevelOrAnonymousStruct(R, FieldName))
-        return FD;
+static bool CheckCountExpr(Sema &S, FieldDecl *FD, Expr *E,
+                           llvm::SmallVector<TypeCoupledDeclRefInfo, 1> &Decls) {
+  if (FD->getParent()->isUnion()) {
+    S.Diag(FD->getBeginLoc(), diag::err_counted_by_attr_in_union) << FD->getSourceRange();
+    return true;
   }
 
-  return nullptr;
-}
+  if (!E->getType()->isIntegerType() || E->getType()->isBooleanType()) {
+    S.Diag(E->getBeginLoc(),
+           diag::err_counted_by_attr_argument_not_integer)
+           << E->getSourceRange();
+    return true;
+  }
 
-bool Sema::CheckCountedByAttr(Scope *S, const FieldDecl *FD) {
   LangOptions::StrictFlexArraysLevelKind StrictFlexArraysLevel =
       LangOptions::StrictFlexArraysLevelKind::IncompleteOnly;
-  if (!Decl::isFlexibleArrayMemberLike(Context, FD, FD->getType(),
+
+  if (!Decl::isFlexibleArrayMemberLike(S.getASTContext(), FD, FD->getType(),
                                        StrictFlexArraysLevel, true)) {
     // The "counted_by" attribute must be on a flexible array member.
     SourceRange SR = FD->getLocation();
-    Diag(SR.getBegin(), diag::err_counted_by_attr_not_on_flexible_array_member)
+    S.Diag(SR.getBegin(), diag::err_counted_by_attr_not_on_flexible_array_member)
         << SR;
     return true;
   }
 
-  const auto *CBA = FD->getAttr<CountedByAttr>();
-  const IdentifierInfo *FieldName = CBA->getCountedByField();
-
-  auto GetNonAnonStructOrUnion = [](const RecordDecl *RD) {
-    while (RD && !RD->getDeclName())
-      if (const auto *R = dyn_cast<RecordDecl>(RD->getDeclContext()))
-        RD = R;
-      else
-        break;
-
-    return RD;
-  };
-
-  const RecordDecl *EnclosingRD = GetNonAnonStructOrUnion(FD->getParent());
-  const FieldDecl *CountFD =
-      FindFieldInTopLevelOrAnonymousStruct(EnclosingRD, FieldName);
+  auto *DRE = dyn_cast<DeclRefExpr>(E); 
+  if (!DRE) {
+    S.Diag(E->getBeginLoc(),
+           diag::err_counted_by_attr_only_support_simple_decl_reference)
+           << E->getSourceRange();
+    return true;
+  }
 
+  auto *CountDecl = DRE->getDecl();
+  FieldDecl *CountFD = dyn_cast<FieldDecl>(CountDecl);
+  if (auto *IFD = dyn_cast<IndirectFieldDecl>(CountDecl)) {
+    CountFD = IFD->getAnonField();
+  }
   if (!CountFD) {
-    DeclarationNameInfo NameInfo(FieldName,
-                                 CBA->getCountedByFieldLoc().getBegin());
-    LookupResult MemResult(*this, NameInfo, Sema::LookupMemberName);
-    LookupName(MemResult, S);
-
-    if (!MemResult.empty()) {
-      SourceRange SR = CBA->getCountedByFieldLoc();
-      Diag(SR.getBegin(), diag::err_flexible_array_count_not_in_same_struct)
-          << CBA->getCountedByField() << SR;
-
-      if (auto *ND = MemResult.getAsSingle<NamedDecl>()) {
-        SR = ND->getLocation();
-        Diag(SR.getBegin(), diag::note_flexible_array_counted_by_attr_field)
-            << ND << SR;
-      }
+    S.Diag(E->getBeginLoc(), diag::err_counted_by_must_be_in_structure)
+      << CountDecl << E->getSourceRange();
+    
+    S.Diag(CountDecl->getBeginLoc(), diag::note_flexible_array_counted_by_attr_field)
+            << CountDecl << CountDecl->getSourceRange();
+    return true;
+  }
 
+  if (FD->getParent() != CountFD->getParent()) {
+    if (CountFD->getParent()->isUnion()) {
+      S.Diag(CountFD->getBeginLoc(), diag::err_counted_by_attr_refer_to_union)
+            << CountFD->getSourceRange();
       return true;
-    } else {
-      // The "counted_by" field needs to exist in the struct.
-      LookupResult OrdResult(*this, NameInfo, Sema::LookupOrdinaryName);
-      LookupName(OrdResult, S);
-
-      if (!OrdResult.empty()) {
-        SourceRange SR = FD->getLocation();
-        Diag(SR.getBegin(), diag::err_counted_by_must_be_in_structure)
-            << FieldName << SR;
-
-        if (auto *ND = OrdResult.getAsSingle<NamedDecl>()) {
-          SR = ND->getLocation();
-          Diag(SR.getBegin(), diag::note_flexible_array_counted_by_attr_field)
-              << ND << SR;
-        }
-
-        return true;
-      }
     }
+    // FIXME: whether CountRD is an anonymous struct is not determined at this point.
+    // check later if `FD` and `CountFD` are in the same enclosing named struct.
+    auto *RD = GetEnclosingNamedOrTopAnonRecord(FD);
+    auto *CountRD = GetEnclosingNamedOrTopAnonRecord(CountFD);
 
-    CXXScopeSpec SS;
-    DeclFilterCCC<FieldDecl> Filter(FieldName);
-    return DiagnoseEmptyLookup(S, SS, MemResult, Filter, nullptr, std::nullopt,
-                               const_cast<DeclContext *>(FD->getDeclContext()));
+    if (RD != CountRD) {
+      S.Diag(CountFD->getBeginLoc(), diag::err_flexible_array_count_not_in_same_struct)
+          << CountFD << CountFD->getSourceRange();
+      return true;
+    }
   }
 
-  if (CountFD->hasAttr<CountedByAttr>()) {
-    // The "counted_by" field can't point to the flexible array member.
-    SourceRange SR = CBA->getCountedByFieldLoc();
-    Diag(SR.getBegin(), diag::err_counted_by_attr_refers_to_flexible_array)
-        << CBA->getCountedByField() << SR;
-    return true;
-  }
+  Decls.push_back(TypeCoupledDeclRefInfo(FD, /*IsDref*/false));
+  return false;
+}
 
-  if (!CountFD->getType()->isIntegerType() ||
-      CountFD->getType()->isBooleanType()) {
-    // The "counted_by" field must have an integer type.
-    SourceRange SR = CBA->getCountedByFieldLoc();
-    Diag(SR.getBegin(),
-         diag::err_flexible_array_counted_by_attr_field_not_integer)
-        << CBA->getCountedByField() << SR;
+static void handleCountedByAttrField(Sema &S, Decl *D, const ParsedAttr &AL) {
+  auto *FD = dyn_cast<FieldDecl>(D);
+  assert(FD);
 
-    SR = CountFD->getLocation();
-    Diag(SR.getBegin(), diag::note_flexible_array_counted_by_attr_field)
-        << CountFD << SR;
-    return true;
-  }
+  auto *CountExpr = AL.getArgAsExpr(0);
+  if (!CountExpr)
+    return;
 
-  return false;
+  llvm::SmallVector<TypeCoupledDeclRefInfo, 1> Decls;
+  if (CheckCountExpr(S, FD, CountExpr, Decls))
+    return;
+
+  QualType CAT = S.BuildCountAttributedArrayType(FD->getType(), CountExpr, Decls);
+  FD->setType(CAT);
 }
 
 static void handleFunctionReturnThunksAttr(Sema &S, Decl *D,
@@ -9550,7 +9519,7 @@ ProcessDeclAttribute(Sema &S, Scope *scope, Decl *D, const ParsedAttr &AL,
     break;
 
   case ParsedAttr::AT_CountedBy:
-    handleCountedByAttr(S, D, AL);
+    handleCountedByAttrField(S, D, AL);
     break;
 
   // Microsoft attributes:
diff --git a/clang/lib/Sema/SemaExpr.cpp b/clang/lib/Sema/SemaExpr.cpp
index ca8c26afb48cd..fac27fd6255fc 100644
--- a/clang/lib/Sema/SemaExpr.cpp
+++ b/clang/lib/Sema/SemaExpr.cpp
@@ -2869,6 +2869,23 @@ Sema::ActOnIdExpression(Scope *S, CXXScopeSpec &SS,
   // This is guaranteed from this point on.
   assert(!R.empty() || ADL);
 
+  // BoundsSafety: This specially handles arguments of bounds attributes
+  // appertains to a type of C struct field such that the name lookup
+  // within a struct finds the member name, which is not the case for other
+  // contexts in C.
+  ExpressionEvaluationContextRecord &LastRecord = ExprEvalContexts.back();
+  using ExpressionKind = ExpressionEvaluationContextRecord::ExpressionKind;
+  if (LastRecord.ExprContext == ExpressionKind::EK_BoundsAttrArgument &&
+      !getLangOpts().CPlusPlus && S->isClassScope()) {
+    // See if this uniquely names a field; non-unique results fall through.
+    if (auto *VD = R.getAsSingle<ValueDecl>()) {
+      QualType type = VD->getType().getNonReferenceType();
+      // This will eventually be translated into MemberExpr upon
+      // the use of instantiated struct fields.
+      return BuildDeclRefExpr(VD, type, VK_PRValue, NameLoc);
+    }
+  }
+
   // Check whether this might be a C++ implicit instance member access.
   // C++ [class.mfct.non-static]p3:
   //   When an id-expression that is not part of a class member access
@@ -2893,7 +2910,7 @@ Sema::ActOnIdExpression(Scope *S, CXXScopeSpec &SS,
   // to get this right here so that we don't end up making a
   // spuriously dependent expression if we're inside a dependent
   // instance method.
-  if (!R.empty() && (*R.begin())->isCXXClassMember()) {
+  if (getLangOpts().CPlusPlus && !R.empty() && (*R.begin())->isCXXClassMember()) {
     bool MightBeImplicitMember;
     if (!IsAddressOfOperand)
       MightBeImplicitMember = true;
@@ -3549,7 +3566,8 @@ ExprResult Sema::BuildDeclarationNameExpr(
   case Decl::Field:
   case Decl::IndirectField:
   case Decl::ObjCIvar:
-    assert(getLangOpts().CPlusPlus && "building reference to field in C?");
+    assert((getLangOpts().CPlusPlus || isBoundsAttrArgument())
+           && "building reference to field in C?");
 
     // These can't have reference type in well-formed programs, but
     // for internal consistency we do this anyway.
diff --git a/clang/lib/Sema/SemaType.cpp b/clang/lib/Sema/SemaType.cpp
index a376f20fa4f4e..e613f9d2f98a7 100644
--- a/clang/lib/Sema/SemaType.cpp
+++ b/clang/lib/Sema/SemaType.cpp
@@ -9606,6 +9606,17 @@ QualType Sema::BuildTypeofExprType(Expr *E, TypeOfKind Kind) {
   return Context.getTypeOfExprType(E, Kind);
 }
 
+QualType Sema::BuildCountAttributedArrayType(QualType WrappedTy, Expr *CountExpr,
+                                             const llvm::SmallVector<TypeCoupledDeclRefInfo, 1> &Decls) {
+  assert(WrappedTy->isIncompleteArrayType());
+
+  /// When the resulting expression is invalid, we still create the AST using
+  /// the original count expression for the sake of AST dump.
+  return Context.getCountAttributedType(
+      WrappedTy, CountExpr, /*CountInBytes*/false, /*OrNull*/false,
+      llvm::ArrayRef(Decls.begin(), Decls.end()));
+}
+
 /// getDecltypeForExpr - Given an expr, will return the decltype for
 /// that expression, according to the rules in C++11
 /// [dcl.type.simple]p4 and C++11 [expr.lambda.prim]p18.

>From a47b4f66a2c16bc0bd4a5013b584d29d6ead5770 Mon Sep 17 00:00:00 2001
From: Yeoul Na <yeoul_na at apple.com>
Date: Mon, 22 Jan 2024 16:16:42 -0800
Subject: [PATCH 03/12] clang-format

---
 clang/include/clang/Parse/Parser.h |  6 ++--
 clang/include/clang/Sema/Sema.h    | 16 +++++++----
 clang/lib/Parse/ParseDecl.cpp      | 18 ++++++------
 clang/lib/Sema/SemaDeclAttr.cpp    | 45 +++++++++++++++++-------------
 clang/lib/Sema/SemaExpr.cpp        |  7 +++--
 clang/lib/Sema/SemaType.cpp        |  7 +++--
 6 files changed, 55 insertions(+), 44 deletions(-)

diff --git a/clang/include/clang/Parse/Parser.h b/clang/include/clang/Parse/Parser.h
index f88191ef5f52c..65f2e67fe77ad 100644
--- a/clang/include/clang/Parse/Parser.h
+++ b/clang/include/clang/Parse/Parser.h
@@ -3055,10 +3055,8 @@ class Parser : public CodeCompletionHandler {
                                  ParsedAttr::Form Form);
 
   void ParseBoundsAttribute(IdentifierInfo &AttrName,
-                            SourceLocation AttrNameLoc,
-                            ParsedAttributes &Attrs,
-                            IdentifierInfo *ScopeName,
-                            SourceLocation ScopeLoc,
+                            SourceLocation AttrNameLoc, ParsedAttributes &Attrs,
+                            IdentifierInfo *ScopeName, SourceLocation ScopeLoc,
                             ParsedAttr::Form Form);
 
   void ParseTypeofSpecifier(DeclSpec &DS);
diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h
index 027afb5d703d1..849692079cd89 100644
--- a/clang/include/clang/Sema/Sema.h
+++ b/clang/include/clang/Sema/Sema.h
@@ -1341,7 +1341,10 @@ class Sema final {
     /// \brief Describes whether we are in an expression constext which we have
     /// to handle differently.
     enum ExpressionKind {
-      EK_Decltype, EK_TemplateArgument, EK_BoundsAttrArgument, EK_Other
+      EK_Decltype,
+      EK_TemplateArgument,
+      EK_BoundsAttrArgument,
+      EK_Other
     } ExprContext;
 
     // A context can be nested in both a discarded statement context and
@@ -1437,11 +1440,11 @@ class Sema final {
   getCurrentMangleNumberContext(const DeclContext *DC);
 
   bool isBoundsAttrArgument() const {
-    return ExprEvalContexts.back().ExprContext == 
-        ExpressionEvaluationContextRecord::ExpressionKind::EK_BoundsAttrArgument; 
+    return ExprEvalContexts.back().ExprContext ==
+           ExpressionEvaluationContextRecord::ExpressionKind::
+               EK_BoundsAttrArgument;
   }
 
-
   /// SpecialMemberOverloadResult - The overloading result for a special member
   /// function.
   ///
@@ -2617,8 +2620,9 @@ class Sema final {
   QualType BuiltinChangeSignedness(QualType BaseType, UTTKind UKind,
                                    SourceLocation Loc);
 
-  QualType BuildCountAttributedArrayType(QualType WrappedTy, Expr *CountExpr,
-                              const llvm::SmallVector<TypeCoupledDeclRefInfo, 1> &Decls);
+  QualType BuildCountAttributedArrayType(
+      QualType WrappedTy, Expr *CountExpr,
+      const llvm::SmallVector<TypeCoupledDeclRefInfo, 1> &Decls);
 
   //===--------------------------------------------------------------------===//
   // Symbol table / Decl tracking callbacks: SemaDecl.cpp.
diff --git a/clang/lib/Parse/ParseDecl.cpp b/clang/lib/Parse/ParseDecl.cpp
index 1bbccfd47c3a8..83a1e09ebd95f 100644
--- a/clang/lib/Parse/ParseDecl.cpp
+++ b/clang/lib/Parse/ParseDecl.cpp
@@ -630,8 +630,8 @@ void Parser::ParseGNUAttributeArgs(
                               ScopeLoc, Form);
     return;
   } else if (AttrKind == ParsedAttr::AT_CountedBy) {
-    ParseBoundsAttribute(*AttrName, AttrNameLoc, Attrs, ScopeName,
-                         ScopeLoc, Form);
+    ParseBoundsAttribute(*AttrName, AttrNameLoc, Attrs, ScopeName, ScopeLoc,
+                         Form);
     return;
   }
 
@@ -3167,11 +3167,14 @@ void Parser::ParseAlignmentSpecifier(ParsedAttributes &Attrs,
 
 /// Bounds attributes (e.g., counted_by):
 ///   AttrName '(' expression ')'
-void Parser::ParseBoundsAttribute(IdentifierInfo &AttrName, SourceLocation AttrNameLoc, ParsedAttributes &Attrs, IdentifierInfo *ScopeName,
-                                       SourceLocation ScopeLoc,
-                                       ParsedAttr::Form Form) {
+void Parser::ParseBoundsAttribute(IdentifierInfo &AttrName,
+                                  SourceLocation AttrNameLoc,
+                                  ParsedAttributes &Attrs,
+                                  IdentifierInfo *ScopeName,
+                                  SourceLocation ScopeLoc,
+                                  ParsedAttr::Form Form) {
   assert(Tok.is(tok::l_paren) && "Attribute arg list not starting with '('");
- 
+
   BalancedDelimiterTracker Parens(*this, tok::l_paren);
   Parens.consumeOpen();
 
@@ -3207,8 +3210,7 @@ void Parser::ParseBoundsAttribute(IdentifierInfo &AttrName, SourceLocation AttrN
       Ctx.getSizeType(), SourceLocation()));
 
   Attrs.addNew(&AttrName, SourceRange(AttrNameLoc, Parens.getCloseLocation()),
-               ScopeName, ScopeLoc, ArgExprs.data(),
-               ArgExprs.size(), Form);
+               ScopeName, ScopeLoc, ArgExprs.data(), ArgExprs.size(), Form);
 }
 
 ExprResult Parser::ParseExtIntegerArgument() {
diff --git a/clang/lib/Sema/SemaDeclAttr.cpp b/clang/lib/Sema/SemaDeclAttr.cpp
index eb4cd97350c6e..d87fb830564bd 100644
--- a/clang/lib/Sema/SemaDeclAttr.cpp
+++ b/clang/lib/Sema/SemaDeclAttr.cpp
@@ -8460,7 +8460,6 @@ static void handleZeroCallUsedRegsAttr(Sema &S, Decl *D, const ParsedAttr &AL) {
   D->addAttr(ZeroCallUsedRegsAttr::Create(S.Context, Kind, AL));
 }
 
-
 static const RecordDecl *GetEnclosingNamedOrTopAnonRecord(const FieldDecl *FD) {
   const auto *RD = FD->getParent();
   while (RD && RD->isAnonymousStructOrUnion()) {
@@ -8472,17 +8471,18 @@ static const RecordDecl *GetEnclosingNamedOrTopAnonRecord(const FieldDecl *FD) {
   return RD;
 }
 
-static bool CheckCountExpr(Sema &S, FieldDecl *FD, Expr *E,
-                           llvm::SmallVector<TypeCoupledDeclRefInfo, 1> &Decls) {
+static bool
+CheckCountExpr(Sema &S, FieldDecl *FD, Expr *E,
+               llvm::SmallVector<TypeCoupledDeclRefInfo, 1> &Decls) {
   if (FD->getParent()->isUnion()) {
-    S.Diag(FD->getBeginLoc(), diag::err_counted_by_attr_in_union) << FD->getSourceRange();
+    S.Diag(FD->getBeginLoc(), diag::err_counted_by_attr_in_union)
+        << FD->getSourceRange();
     return true;
   }
 
   if (!E->getType()->isIntegerType() || E->getType()->isBooleanType()) {
-    S.Diag(E->getBeginLoc(),
-           diag::err_counted_by_attr_argument_not_integer)
-           << E->getSourceRange();
+    S.Diag(E->getBeginLoc(), diag::err_counted_by_attr_argument_not_integer)
+        << E->getSourceRange();
     return true;
   }
 
@@ -8493,16 +8493,17 @@ static bool CheckCountExpr(Sema &S, FieldDecl *FD, Expr *E,
                                        StrictFlexArraysLevel, true)) {
     // The "counted_by" attribute must be on a flexible array member.
     SourceRange SR = FD->getLocation();
-    S.Diag(SR.getBegin(), diag::err_counted_by_attr_not_on_flexible_array_member)
+    S.Diag(SR.getBegin(),
+           diag::err_counted_by_attr_not_on_flexible_array_member)
         << SR;
     return true;
   }
 
-  auto *DRE = dyn_cast<DeclRefExpr>(E); 
+  auto *DRE = dyn_cast<DeclRefExpr>(E);
   if (!DRE) {
     S.Diag(E->getBeginLoc(),
            diag::err_counted_by_attr_only_support_simple_decl_reference)
-           << E->getSourceRange();
+        << E->getSourceRange();
     return true;
   }
 
@@ -8513,32 +8514,35 @@ static bool CheckCountExpr(Sema &S, FieldDecl *FD, Expr *E,
   }
   if (!CountFD) {
     S.Diag(E->getBeginLoc(), diag::err_counted_by_must_be_in_structure)
-      << CountDecl << E->getSourceRange();
-    
-    S.Diag(CountDecl->getBeginLoc(), diag::note_flexible_array_counted_by_attr_field)
-            << CountDecl << CountDecl->getSourceRange();
+        << CountDecl << E->getSourceRange();
+
+    S.Diag(CountDecl->getBeginLoc(),
+           diag::note_flexible_array_counted_by_attr_field)
+        << CountDecl << CountDecl->getSourceRange();
     return true;
   }
 
   if (FD->getParent() != CountFD->getParent()) {
     if (CountFD->getParent()->isUnion()) {
       S.Diag(CountFD->getBeginLoc(), diag::err_counted_by_attr_refer_to_union)
-            << CountFD->getSourceRange();
+          << CountFD->getSourceRange();
       return true;
     }
-    // FIXME: whether CountRD is an anonymous struct is not determined at this point.
-    // check later if `FD` and `CountFD` are in the same enclosing named struct.
+    // FIXME: whether CountRD is an anonymous struct is not determined at this
+    // point. check later if `FD` and `CountFD` are in the same enclosing named
+    // struct.
     auto *RD = GetEnclosingNamedOrTopAnonRecord(FD);
     auto *CountRD = GetEnclosingNamedOrTopAnonRecord(CountFD);
 
     if (RD != CountRD) {
-      S.Diag(CountFD->getBeginLoc(), diag::err_flexible_array_count_not_in_same_struct)
+      S.Diag(CountFD->getBeginLoc(),
+             diag::err_flexible_array_count_not_in_same_struct)
           << CountFD << CountFD->getSourceRange();
       return true;
     }
   }
 
-  Decls.push_back(TypeCoupledDeclRefInfo(FD, /*IsDref*/false));
+  Decls.push_back(TypeCoupledDeclRefInfo(CountFD, /*IsDref*/ false));
   return false;
 }
 
@@ -8554,7 +8558,8 @@ static void handleCountedByAttrField(Sema &S, Decl *D, const ParsedAttr &AL) {
   if (CheckCountExpr(S, FD, CountExpr, Decls))
     return;
 
-  QualType CAT = S.BuildCountAttributedArrayType(FD->getType(), CountExpr, Decls);
+  QualType CAT =
+      S.BuildCountAttributedArrayType(FD->getType(), CountExpr, Decls);
   FD->setType(CAT);
 }
 
diff --git a/clang/lib/Sema/SemaExpr.cpp b/clang/lib/Sema/SemaExpr.cpp
index fac27fd6255fc..b20828e6a906d 100644
--- a/clang/lib/Sema/SemaExpr.cpp
+++ b/clang/lib/Sema/SemaExpr.cpp
@@ -2910,7 +2910,8 @@ Sema::ActOnIdExpression(Scope *S, CXXScopeSpec &SS,
   // to get this right here so that we don't end up making a
   // spuriously dependent expression if we're inside a dependent
   // instance method.
-  if (getLangOpts().CPlusPlus && !R.empty() && (*R.begin())->isCXXClassMember()) {
+  if (getLangOpts().CPlusPlus && !R.empty() &&
+      (*R.begin())->isCXXClassMember()) {
     bool MightBeImplicitMember;
     if (!IsAddressOfOperand)
       MightBeImplicitMember = true;
@@ -3566,8 +3567,8 @@ ExprResult Sema::BuildDeclarationNameExpr(
   case Decl::Field:
   case Decl::IndirectField:
   case Decl::ObjCIvar:
-    assert((getLangOpts().CPlusPlus || isBoundsAttrArgument())
-           && "building reference to field in C?");
+    assert((getLangOpts().CPlusPlus || isBoundsAttrArgument()) &&
+           "building reference to field in C?");
 
     // These can't have reference type in well-formed programs, but
     // for internal consistency we do this anyway.
diff --git a/clang/lib/Sema/SemaType.cpp b/clang/lib/Sema/SemaType.cpp
index e613f9d2f98a7..87ff9f79be7ed 100644
--- a/clang/lib/Sema/SemaType.cpp
+++ b/clang/lib/Sema/SemaType.cpp
@@ -9606,14 +9606,15 @@ QualType Sema::BuildTypeofExprType(Expr *E, TypeOfKind Kind) {
   return Context.getTypeOfExprType(E, Kind);
 }
 
-QualType Sema::BuildCountAttributedArrayType(QualType WrappedTy, Expr *CountExpr,
-                                             const llvm::SmallVector<TypeCoupledDeclRefInfo, 1> &Decls) {
+QualType Sema::BuildCountAttributedArrayType(
+    QualType WrappedTy, Expr *CountExpr,
+    const llvm::SmallVector<TypeCoupledDeclRefInfo, 1> &Decls) {
   assert(WrappedTy->isIncompleteArrayType());
 
   /// When the resulting expression is invalid, we still create the AST using
   /// the original count expression for the sake of AST dump.
   return Context.getCountAttributedType(
-      WrappedTy, CountExpr, /*CountInBytes*/false, /*OrNull*/false,
+      WrappedTy, CountExpr, /*CountInBytes*/ false, /*OrNull*/ false,
       llvm::ArrayRef(Decls.begin(), Decls.end()));
 }
 

>From e7fe1d7a61777b4b5111dff5ec0c217096822c78 Mon Sep 17 00:00:00 2001
From: Yeoul Na <yeoul_na at apple.com>
Date: Mon, 22 Jan 2024 16:45:24 -0800
Subject: [PATCH 04/12] Special lookup for counted_by argument in structs

---
 clang/lib/Sema/SemaExpr.cpp | 38 ++++++++++++++++++++-----------------
 1 file changed, 21 insertions(+), 17 deletions(-)

diff --git a/clang/lib/Sema/SemaExpr.cpp b/clang/lib/Sema/SemaExpr.cpp
index b20828e6a906d..2eef46149118f 100644
--- a/clang/lib/Sema/SemaExpr.cpp
+++ b/clang/lib/Sema/SemaExpr.cpp
@@ -2737,6 +2737,27 @@ Sema::ActOnIdExpression(Scope *S, CXXScopeSpec &SS,
     return ActOnDependentIdExpression(SS, TemplateKWLoc, NameInfo,
                                       IsAddressOfOperand, TemplateArgs);
 
+  // BoundsSafety: This specially handles arguments of bounds attributes
+  // that appertain to the type of a C struct field, so that name lookup
+  // within a struct finds the member name, which is not the case in other
+  // contexts in C.
+  ExpressionEvaluationContextRecord &LastRecord = ExprEvalContexts.back();
+  using ExpressionKind = ExpressionEvaluationContextRecord::ExpressionKind;
+  if (LastRecord.ExprContext == ExpressionKind::EK_BoundsAttrArgument &&
+      !getLangOpts().CPlusPlus && S->isClassScope()) {
+    // See if this is reference to a field of struct.
+    LookupResult R(*this, NameInfo, LookupMemberName);
+    if (LookupQualifiedName(R, S->getEntity(), SS)) {
+
+      if (auto *VD = dyn_cast<ValueDecl>(R.getFoundDecl())) {
+        QualType type = VD->getType().getNonReferenceType();
+        // This will eventually be translated into MemberExpr upon
+        // the use of instantiated struct fields.
+        return BuildDeclRefExpr(VD, type, VK_PRValue, NameLoc);
+      }
+    }
+  }
+
   // Perform the required lookup.
   LookupResult R(*this, NameInfo,
                  (Id.getKind() == UnqualifiedIdKind::IK_ImplicitSelfParam)
@@ -2869,23 +2890,6 @@ Sema::ActOnIdExpression(Scope *S, CXXScopeSpec &SS,
   // This is guaranteed from this point on.
   assert(!R.empty() || ADL);
 
-  // BoundsSafety: This specially handles arguments of bounds attributes
-  // appertains to a type of C struct field such that the name lookup
-  // within a struct finds the member name, which is not the case for other
-  // contexts in C.
-  ExpressionEvaluationContextRecord &LastRecord = ExprEvalContexts.back();
-  using ExpressionKind = ExpressionEvaluationContextRecord::ExpressionKind;
-  if (LastRecord.ExprContext == ExpressionKind::EK_BoundsAttrArgument &&
-      !getLangOpts().CPlusPlus && S->isClassScope()) {
-    // See if this is reference to a field of struct.
-    if (auto *VD = dyn_cast<ValueDecl>(R.getFoundDecl())) {
-      QualType type = VD->getType().getNonReferenceType();
-      // This will eventually be translated into MemberExpr upon
-      // the use of instantiated struct fields.
-      return BuildDeclRefExpr(VD, type, VK_PRValue, NameLoc);
-    }
-  }
-
   // Check whether this might be a C++ implicit instance member access.
   // C++ [class.mfct.non-static]p3:
   //   When an id-expression that is not part of a class member access

>From 4dc0386ed93929c29c80131441df3b201b8c32b3 Mon Sep 17 00:00:00 2001
From: Yeoul Na <yeoul_na at apple.com>
Date: Mon, 22 Jan 2024 16:55:14 -0800
Subject: [PATCH 05/12] Adjust test/Sema/attr-counted-by.c to match the default
 diags generated by the expression parser

---
 clang/test/Sema/attr-counted-by.c | 22 +++++++++++-----------
 1 file changed, 11 insertions(+), 11 deletions(-)

diff --git a/clang/test/Sema/attr-counted-by.c b/clang/test/Sema/attr-counted-by.c
index f14da9c77fa8b..c26e8ce91357b 100644
--- a/clang/test/Sema/attr-counted-by.c
+++ b/clang/test/Sema/attr-counted-by.c
@@ -11,16 +11,16 @@ struct not_found {
 
 struct no_found_count_not_in_substruct {
   unsigned long flags;
-  unsigned char count; // expected-note {{field 'count' declared here}}
+  unsigned char count;
   struct A {
     int dummy;
-    int array[] __counted_by(count); // expected-error {{'counted_by' field 'count' isn't within the same struct as the flexible array}}
+    int array[] __counted_by(count); // expected-error {{use of undeclared identifier 'count'}}
   } a;
 };
 
 struct not_found_suggest {
-  int bork; // expected-note {{'bork' declared here}}
-  struct bar *fam[] __counted_by(blork); // expected-error {{use of undeclared identifier 'blork'; did you mean 'bork'?}}
+  int bork;
+  struct bar *fam[] __counted_by(blork); // expected-error {{use of undeclared identifier 'blork'}}
 };
 
 int global; // expected-note {{'global' declared here}}
@@ -32,17 +32,17 @@ struct found_outside_of_struct {
 
 struct self_referrential {
   int bork;
-  struct bar *self[] __counted_by(self); // expected-error {{'counted_by' cannot refer to the flexible array 'self'}}
+  struct bar *self[] __counted_by(self); // expected-error {{use of undeclared identifier 'self'}}
 };
 
 struct non_int_count {
-  double dbl_count; // expected-note {{field 'dbl_count' declared here}}
-  struct bar *fam[] __counted_by(dbl_count); // expected-error {{field 'dbl_count' in 'counted_by' must be a non-boolean integer type}}
+  double dbl_count;
+  struct bar *fam[] __counted_by(dbl_count); // expected-error {{'counted_by' requires a non-boolean integer type argument}}
 };
 
 struct array_of_ints_count {
-  int integers[2]; // expected-note {{field 'integers' declared here}}
-  struct bar *fam[] __counted_by(integers); // expected-error {{field 'integers' in 'counted_by' must be a non-boolean integer type}}
+  int integers[2];
+  struct bar *fam[] __counted_by(integers); // expected-error {{'counted_by' requires a non-boolean integer type argument}}
 };
 
 struct not_a_fam {
@@ -58,7 +58,7 @@ struct not_a_c99_fam {
 struct annotated_with_anon_struct {
   unsigned long flags;
   struct {
-    unsigned char count; // expected-note {{'count' declared here}}
-    int array[] __counted_by(crount); // expected-error {{use of undeclared identifier 'crount'; did you mean 'count'?}}
+    unsigned char count;
+    int array[] __counted_by(crount); // expected-error {{use of undeclared identifier 'crount'}}
   };
 };

>From fa8be19edcffc0b4014bff064e82a978c29eb784 Mon Sep 17 00:00:00 2001
From: Yeoul Na <yeoul_na at apple.com>
Date: Tue, 30 Jan 2024 12:54:47 -0800
Subject: [PATCH 06/12] Fix name lookup issues involving anonymous struct

---
 clang/lib/AST/ASTContext.cpp    |  1 -
 clang/lib/Parse/ParseDecl.cpp   | 24 ++++++++++++++++++++++++
 clang/lib/Sema/SemaDeclAttr.cpp |  4 ++--
 clang/lib/Sema/SemaExpr.cpp     |  4 ++--
 4 files changed, 28 insertions(+), 5 deletions(-)

diff --git a/clang/lib/AST/ASTContext.cpp b/clang/lib/AST/ASTContext.cpp
index 43d2da0e60f36..fd43f89cf4481 100644
--- a/clang/lib/AST/ASTContext.cpp
+++ b/clang/lib/AST/ASTContext.cpp
@@ -3113,7 +3113,6 @@ QualType ASTContext::getCountAttributedType(
     QualType WrappedTy, Expr *CountExpr, bool CountInBytes, bool OrNull,
     ArrayRef<TypeCoupledDeclRefInfo> DependentDecls) const {
   assert(WrappedTy->isPointerType() || WrappedTy->isArrayType());
-  assert(CountExpr->isPRValue());
 
   llvm::FoldingSetNodeID ID;
   CountAttributedType::Profile(ID, WrappedTy, CountExpr, CountInBytes, OrNull);
diff --git a/clang/lib/Parse/ParseDecl.cpp b/clang/lib/Parse/ParseDecl.cpp
index 83a1e09ebd95f..9fbb1270b364a 100644
--- a/clang/lib/Parse/ParseDecl.cpp
+++ b/clang/lib/Parse/ParseDecl.cpp
@@ -4712,6 +4712,30 @@ void Parser::ParseStructDeclaration(
     } else
       DeclaratorInfo.D.SetIdentifier(nullptr, Tok.getLocation());
 
+    if (DS.getTypeSpecType() == DeclSpec::TST_struct) {
+      auto *decl = DS.getRepAsDecl();
+      auto *RD = dyn_cast<RecordDecl>(decl);
+      // Here, we now know that the unnamed struct is not an anonymous struct.
+      // Report an error if a counted_by attribute refers to a field in a different struct.
+      if (RD && RD->getName().empty()) {
+        assert(!RD->isAnonymousStructOrUnion());
+        for (auto *I : RD->decls()) {
+          if (auto *VD = dyn_cast<ValueDecl>(I)) {
+            if (auto *CAT = VD->getType()->getAs<CountAttributedType>()) {
+              for (const auto &dd : CAT->dependent_decls()) {
+                if (!RD->containsDecl(dd.getDecl())) {
+                  Diag(VD->getBeginLoc(), diag::err_flexible_array_count_not_in_same_struct)
+                    << dd.getDecl();
+                  Diag(dd.getDecl()->getBeginLoc(), diag::note_flexible_array_counted_by_attr_field)
+                    << dd.getDecl();
+                }
+              }
+            }
+          }
+        }
+      }
+    }
+
     if (TryConsumeToken(tok::colon)) {
       ExprResult Res(ParseConstantExpression());
       if (Res.isInvalid())
diff --git a/clang/lib/Sema/SemaDeclAttr.cpp b/clang/lib/Sema/SemaDeclAttr.cpp
index d87fb830564bd..eef474f5fec9c 100644
--- a/clang/lib/Sema/SemaDeclAttr.cpp
+++ b/clang/lib/Sema/SemaDeclAttr.cpp
@@ -8462,7 +8462,7 @@ static void handleZeroCallUsedRegsAttr(Sema &S, Decl *D, const ParsedAttr &AL) {
 
 static const RecordDecl *GetEnclosingNamedOrTopAnonRecord(const FieldDecl *FD) {
   const auto *RD = FD->getParent();
-  while (RD && RD->isAnonymousStructOrUnion()) {
+  while (RD && (RD->isAnonymousStructOrUnion() || RD->getName().empty())) {
     const auto *Parent = dyn_cast<RecordDecl>(RD->getParent());
     if (!Parent)
       break;
@@ -8542,7 +8542,7 @@ CheckCountExpr(Sema &S, FieldDecl *FD, Expr *E,
     }
   }
 
-  Decls.push_back(TypeCoupledDeclRefInfo(FD, /*IsDref*/ false));
+  Decls.push_back(TypeCoupledDeclRefInfo(CountFD, /*IsDref*/ false));
   return false;
 }
 
diff --git a/clang/lib/Sema/SemaExpr.cpp b/clang/lib/Sema/SemaExpr.cpp
index 2eef46149118f..a49f39a64f212 100644
--- a/clang/lib/Sema/SemaExpr.cpp
+++ b/clang/lib/Sema/SemaExpr.cpp
@@ -2747,8 +2747,8 @@ Sema::ActOnIdExpression(Scope *S, CXXScopeSpec &SS,
       !getLangOpts().CPlusPlus && S->isClassScope()) {
     // See if this is reference to a field of struct.
     LookupResult R(*this, NameInfo, LookupMemberName);
-    if (LookupQualifiedName(R, S->getEntity(), SS)) {
-
+    // LookupParsedName handles name lookup from within an anonymous struct.
+    if (LookupParsedName(R, S, &SS)) {
       if (auto *VD = dyn_cast<ValueDecl>(R.getFoundDecl())) {
         QualType type = VD->getType().getNonReferenceType();
         // This will eventually be translated into MemberExpr upon

>From b8ab70741d5c8e28c5bd8111582fe5e8844c8f35 Mon Sep 17 00:00:00 2001
From: Yeoul Na <yeoul_na at apple.com>
Date: Tue, 30 Jan 2024 13:06:26 -0800
Subject: [PATCH 07/12] Fix diags for counted_by fields outside the enclosing struct

---
 clang/lib/Sema/SemaDeclAttr.cpp   | 5 ++++-
 clang/test/Sema/attr-counted-by.c | 4 ++--
 2 files changed, 6 insertions(+), 3 deletions(-)

diff --git a/clang/lib/Sema/SemaDeclAttr.cpp b/clang/lib/Sema/SemaDeclAttr.cpp
index eef474f5fec9c..9dfc412e14aef 100644
--- a/clang/lib/Sema/SemaDeclAttr.cpp
+++ b/clang/lib/Sema/SemaDeclAttr.cpp
@@ -8535,8 +8535,11 @@ CheckCountExpr(Sema &S, FieldDecl *FD, Expr *E,
     auto *CountRD = GetEnclosingNamedOrTopAnonRecord(CountFD);
 
     if (RD != CountRD) {
-      S.Diag(CountFD->getBeginLoc(),
+      S.Diag(E->getBeginLoc(),
              diag::err_flexible_array_count_not_in_same_struct)
+          << CountFD << E->getSourceRange();
+      S.Diag(CountFD->getBeginLoc(),
+             diag::note_flexible_array_counted_by_attr_field)
           << CountFD << CountFD->getSourceRange();
       return true;
     }
diff --git a/clang/test/Sema/attr-counted-by.c b/clang/test/Sema/attr-counted-by.c
index c26e8ce91357b..a3b08df730994 100644
--- a/clang/test/Sema/attr-counted-by.c
+++ b/clang/test/Sema/attr-counted-by.c
@@ -11,10 +11,10 @@ struct not_found {
 
 struct no_found_count_not_in_substruct {
   unsigned long flags;
-  unsigned char count;
+  unsigned char count; // expected-note {{'count' declared here}}
   struct A {
     int dummy;
-    int array[] __counted_by(count); // expected-error {{use of undeclared identifier 'count'}}
+    int array[] __counted_by(count); // expected-error {{'counted_by' field 'count' isn't within the same struct as the flexible array}}
   } a;
 };
 

>From 3b73e2c3ed5d6bc0785b360004d43285bda64cf9 Mon Sep 17 00:00:00 2001
From: Yeoul Na <yeoul_na at apple.com>
Date: Tue, 30 Jan 2024 13:34:24 -0800
Subject: [PATCH 08/12] Adjust __bdos and array-checks sanitizer to use
 CountAttributedType

---
 clang/lib/CodeGen/CGBuiltin.cpp | 6 +++---
 clang/lib/CodeGen/CGExpr.cpp    | 2 +-
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 998fcc3af5817..8357574bbffe4 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -857,7 +857,7 @@ static unsigned CountCountedByAttrs(const RecordDecl *RD) {
 
   for (const Decl *D : RD->decls()) {
     if (const auto *FD = dyn_cast<FieldDecl>(D);
-        FD && FD->hasAttr<CountedByAttr>()) {
+        FD && FD->getType()->getAs<CountAttributedType>()) {
       return ++Num;
     }
 
@@ -955,7 +955,7 @@ CodeGenFunction::emitFlexibleArrayMemberSize(const Expr *E, unsigned Type,
     //         };
     //    };
     //
-    // We don't konw which 'count' to use in this scenario:
+    // We don't know which 'count' to use in this scenario:
     //
     //     size_t get_size(struct union_of_fams *p) {
     //         return __builtin_dynamic_object_size(p, 1);
@@ -974,7 +974,7 @@ CodeGenFunction::emitFlexibleArrayMemberSize(const Expr *E, unsigned Type,
       FindFlexibleArrayMemberField(Ctx, OuterRD, FAMName, Offset);
   Offset = Ctx.toCharUnitsFromBits(Offset).getQuantity();
 
-  if (!FAMDecl || !FAMDecl->hasAttr<CountedByAttr>())
+  if (!FAMDecl || !FAMDecl->getType()->getAs<CountAttributedType>())
     // No flexible array member found or it doesn't have the "counted_by"
     // attribute.
     return nullptr;
diff --git a/clang/lib/CodeGen/CGExpr.cpp b/clang/lib/CodeGen/CGExpr.cpp
index 9cf0e8984de08..27cd9daab30bd 100644
--- a/clang/lib/CodeGen/CGExpr.cpp
+++ b/clang/lib/CodeGen/CGExpr.cpp
@@ -4232,7 +4232,7 @@ LValue CodeGenFunction::EmitArraySubscriptExpr(const ArraySubscriptExpr *E,
       if (const auto *ME = dyn_cast<MemberExpr>(Array);
           ME &&
           ME->isFlexibleArrayMemberLike(getContext(), StrictFlexArraysLevel) &&
-          ME->getMemberDecl()->hasAttr<CountedByAttr>()) {
+          ME->getMemberDecl()->getType()->getAs<CountAttributedType>()) {
         const FieldDecl *FAMDecl = dyn_cast<FieldDecl>(ME->getMemberDecl());
         if (const FieldDecl *CountFD = FindCountedByField(FAMDecl)) {
           if (std::optional<int64_t> Diff =

>From 66605b6e0c9d955e7550abb04e114f8ebaf24967 Mon Sep 17 00:00:00 2001
From: Yeoul Na <yeoul_na at apple.com>
Date: Tue, 30 Jan 2024 15:40:40 -0800
Subject: [PATCH 09/12] clang-format

---
 clang/lib/Parse/ParseDecl.cpp | 13 ++++++++-----
 1 file changed, 8 insertions(+), 5 deletions(-)

diff --git a/clang/lib/Parse/ParseDecl.cpp b/clang/lib/Parse/ParseDecl.cpp
index 9fbb1270b364a..2e4512e82e86e 100644
--- a/clang/lib/Parse/ParseDecl.cpp
+++ b/clang/lib/Parse/ParseDecl.cpp
@@ -4716,7 +4716,8 @@ void Parser::ParseStructDeclaration(
       auto *decl = DS.getRepAsDecl();
       auto *RD = dyn_cast<RecordDecl>(decl);
       // Here, we now know that the unnamed struct is not an anonymous struct.
-      // Report an error if a counted_by attribute refers to a field in a different struct.
+      // Report an error if a counted_by attribute refers to a field in a
+      // different struct.
       if (RD && RD->getName().empty()) {
         assert(!RD->isAnonymousStructOrUnion());
         for (auto *I : RD->decls()) {
@@ -4724,10 +4725,12 @@ void Parser::ParseStructDeclaration(
             if (auto *CAT = VD->getType()->getAs<CountAttributedType>()) {
               for (const auto &dd : CAT->dependent_decls()) {
                 if (!RD->containsDecl(dd.getDecl())) {
-                  Diag(VD->getBeginLoc(), diag::err_flexible_array_count_not_in_same_struct)
-                    << dd.getDecl();
-                  Diag(dd.getDecl()->getBeginLoc(), diag::note_flexible_array_counted_by_attr_field)
-                    << dd.getDecl();
+                  Diag(VD->getBeginLoc(),
+                       diag::err_flexible_array_count_not_in_same_struct)
+                      << dd.getDecl();
+                  Diag(dd.getDecl()->getBeginLoc(),
+                       diag::note_flexible_array_counted_by_attr_field)
+                      << dd.getDecl();
                 }
               }
             }

>From d0829a6c5a51c4f6fc172f0e8b2647449c923e82 Mon Sep 17 00:00:00 2001
From: Yeoul Na <yeoul_na at apple.com>
Date: Thu, 1 Feb 2024 10:24:46 -0800
Subject: [PATCH 10/12] Fix ASTNodeImporter::VisitCountAttributedType

---
 clang/lib/AST/ASTImporter.cpp | 16 +++++++++++++---
 1 file changed, 13 insertions(+), 3 deletions(-)

diff --git a/clang/lib/AST/ASTImporter.cpp b/clang/lib/AST/ASTImporter.cpp
index 292ee69818df7..d05347e72aeb5 100644
--- a/clang/lib/AST/ASTImporter.cpp
+++ b/clang/lib/AST/ASTImporter.cpp
@@ -1533,10 +1533,20 @@ ASTNodeImporter::VisitCountAttributedType(const CountAttributedType *T) {
   if (!ToWrappedTypeOrErr)
     return ToWrappedTypeOrErr.takeError();
 
-  // FIXME: Handle CoupledDecls correctly
+  Error Err = Error::success();
+  Expr *CountExpr = importChecked(Err, T->getCountExpr());
+
+  SmallVector<TypeCoupledDeclRefInfo, 1> CoupledDecls;
+  for (auto TI : T->dependent_decls()) {
+    Expected<ValueDecl *> ToDeclOrErr = import(TI.getDecl());
+    if (!ToDeclOrErr)
+      return ToDeclOrErr.takeError();
+    CoupledDecls.emplace_back(*ToDeclOrErr, TI.isDeref());
+  }
+
   return Importer.getToContext().getCountAttributedType(
-      *ToWrappedTypeOrErr, T->getCountExpr(), T->isCountInBytes(),
-      T->isOrNull(), T->getCoupledDecls());
+      *ToWrappedTypeOrErr, CountExpr, T->isCountInBytes(),
+      T->isOrNull(), ArrayRef(CoupledDecls.data(), CoupledDecls.size()));
 }
 
 ExpectedType ASTNodeImporter::VisitTemplateTypeParmType(

>From d76a7f608b8d07b2bdaaf0ff6e74bf6432f2f9a4 Mon Sep 17 00:00:00 2001
From: Yeoul Na <yeoul_na at apple.com>
Date: Thu, 1 Feb 2024 10:30:35 -0800
Subject: [PATCH 11/12] clang-format

---
 clang/lib/AST/ASTImporter.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/clang/lib/AST/ASTImporter.cpp b/clang/lib/AST/ASTImporter.cpp
index d05347e72aeb5..7c315ec20adae 100644
--- a/clang/lib/AST/ASTImporter.cpp
+++ b/clang/lib/AST/ASTImporter.cpp
@@ -1545,8 +1545,8 @@ ASTNodeImporter::VisitCountAttributedType(const CountAttributedType *T) {
   }
 
   return Importer.getToContext().getCountAttributedType(
-      *ToWrappedTypeOrErr, CountExpr, T->isCountInBytes(),
-      T->isOrNull(), ArrayRef(CoupledDecls.data(), CoupledDecls.size()));
+      *ToWrappedTypeOrErr, CountExpr, T->isCountInBytes(), T->isOrNull(),
+      ArrayRef(CoupledDecls.data(), CoupledDecls.size()));
 }
 
 ExpectedType ASTNodeImporter::VisitTemplateTypeParmType(

>From 5fd456d316077461ac24e4cbda9514a66dff4be7 Mon Sep 17 00:00:00 2001
From: Yeoul Na <yeoul_na at apple.com>
Date: Thu, 1 Feb 2024 11:09:34 -0800
Subject: [PATCH 12/12] Print counted_by attribute after the array type

---
 clang/lib/AST/TypePrinter.cpp | 15 +++++++++++----
 1 file changed, 11 insertions(+), 4 deletions(-)

diff --git a/clang/lib/AST/TypePrinter.cpp b/clang/lib/AST/TypePrinter.cpp
index 4a01d7ae1fa96..958d639ddef87 100644
--- a/clang/lib/AST/TypePrinter.cpp
+++ b/clang/lib/AST/TypePrinter.cpp
@@ -1715,10 +1715,8 @@ void TypePrinter::printPackExpansionAfter(const PackExpansionType *T,
   OS << "...";
 }
 
-void TypePrinter::printCountAttributedBefore(const CountAttributedType *T,
-                                             raw_ostream &OS) {
-  printBefore(T->desugar(), OS);
-
+static void printCountAttributedImpl(const CountAttributedType *T,
+                                     raw_ostream &OS, PrintingPolicy Policy) {
   if (T->isCountInBytes() && T->isOrNull())
     OS << " __sized_by_or_null(";
   else if (T->isCountInBytes())
@@ -1732,9 +1730,18 @@ void TypePrinter::printCountAttributedBefore(const CountAttributedType *T,
   OS << ')';
 }
 
+void TypePrinter::printCountAttributedBefore(const CountAttributedType *T,
+                                             raw_ostream &OS) {
+  printBefore(T->desugar(), OS);
+  if (!T->desugar()->isArrayType())
+    printCountAttributedImpl(T, OS, Policy);
+}
+
 void TypePrinter::printCountAttributedAfter(const CountAttributedType *T,
                                             raw_ostream &OS) {
   printAfter(T->desugar(), OS);
+  if (T->desugar()->isArrayType())
+    printCountAttributedImpl(T, OS, Policy);
 }
 
 void TypePrinter::printAttributedBefore(const AttributedType *T,



More information about the llvm-commits mailing list