[clang] Turn 'counted_by' into a type attribute and parse it into 'CountAttributedType' (PR #78000)

Yeoul Na via cfe-commits cfe-commits at lists.llvm.org
Thu Mar 7 10:05:11 PST 2024


https://github.com/rapidsna updated https://github.com/llvm/llvm-project/pull/78000

>From 3d2716ad6088f600c14d6ff724aa90453130a1f7 Mon Sep 17 00:00:00 2001
From: Yeoul Na <yeoul_na at apple.com>
Date: Mon, 18 Dec 2023 10:58:16 +0900
Subject: [PATCH 1/9] [BoundsSafety] Introduce CountAttributedType

CountAttributedType is a sugar type to represent a type with
a 'counted_by' attribute and the like, which provides bounds
information to the underlying type. The type contains
the argument of the attribute as an expression. Additionally, the type
holds metadata about declarations referenced by the expression in
order to make it easier for Sema to access declarations on which
the type depends.

This also adjusts the CountedBy attribute definition and implements
parsing CountAttributedType.

__bdos and array-checks sanitizer use CountAttributedType instead
of hasAttr<CountedByAttr>.

Implements special lookup for counted_by argument in structs.

Adjust test/Sema/attr-counted-by.c to match the default diags
generated by the expression parser.
---
 clang/include/clang/AST/ASTContext.h          |   7 +
 clang/include/clang/AST/PropertiesBase.td     |   1 +
 clang/include/clang/AST/RecursiveASTVisitor.h |   9 +
 clang/include/clang/AST/Type.h                | 152 +++++++++++++++
 clang/include/clang/AST/TypeLoc.h             |  26 +++
 clang/include/clang/AST/TypeProperties.td     |  19 ++
 clang/include/clang/Basic/Attr.td             |   7 +-
 .../clang/Basic/DiagnosticSemaKinds.td        |  14 +-
 clang/include/clang/Basic/TypeNodes.td        |   2 +
 clang/include/clang/Parse/Parser.h            |   5 +
 .../clang/Serialization/ASTRecordReader.h     |   2 +
 .../clang/Serialization/ASTRecordWriter.h     |   5 +
 .../clang/Serialization/TypeBitCodes.def      |   2 +-
 clang/lib/AST/ASTContext.cpp                  |  55 ++++++
 clang/lib/AST/ASTImporter.cpp                 |  22 +++
 clang/lib/AST/ASTStructuralEquivalence.cpp    |   7 +
 clang/lib/AST/ItaniumMangle.cpp               |   1 +
 clang/lib/AST/Type.cpp                        |  65 +++++++
 clang/lib/AST/TypeLoc.cpp                     |   4 +
 clang/lib/AST/TypePrinter.cpp                 |  30 +++
 clang/lib/CodeGen/CGBuiltin.cpp               |   6 +-
 clang/lib/CodeGen/CGDebugInfo.cpp             |   1 +
 clang/lib/CodeGen/CGExpr.cpp                  |  37 +---
 clang/lib/CodeGen/CodeGenFunction.cpp         |   1 +
 clang/lib/Parse/ParseDecl.cpp                 |  90 +++++++++
 clang/lib/Sema/SemaDecl.cpp                   |   6 -
 clang/lib/Sema/SemaDeclAttr.cpp               | 183 ++++++++----------
 clang/lib/Sema/SemaExpr.cpp                   |  25 ++-
 clang/lib/Sema/SemaType.cpp                   |  12 ++
 clang/lib/Sema/TreeTransform.h                |   7 +
 clang/lib/Serialization/ASTReader.cpp         |   8 +
 clang/lib/Serialization/ASTWriter.cpp         |   4 +
 clang/test/Sema/attr-counted-by.c             |  20 +-
 clang/tools/libclang/CIndex.cpp               |   4 +
 34 files changed, 679 insertions(+), 160 deletions(-)

diff --git a/clang/include/clang/AST/ASTContext.h b/clang/include/clang/AST/ASTContext.h
index ff6b64c7f72d57..002f36ecbbaa3f 100644
--- a/clang/include/clang/AST/ASTContext.h
+++ b/clang/include/clang/AST/ASTContext.h
@@ -250,6 +250,8 @@ class ASTContext : public RefCountedBase<ASTContext> {
       DependentBitIntTypes;
   llvm::FoldingSet<BTFTagAttributedType> BTFTagAttributedTypes;
 
+  mutable llvm::FoldingSet<CountAttributedType> CountAttributedTypes;
+
   mutable llvm::FoldingSet<QualifiedTemplateName> QualifiedTemplateNames;
   mutable llvm::FoldingSet<DependentTemplateName> DependentTemplateNames;
   mutable llvm::FoldingSet<SubstTemplateTemplateParmStorage>
@@ -1341,6 +1343,11 @@ class ASTContext : public RefCountedBase<ASTContext> {
     return CanQualType::CreateUnsafe(getPointerType((QualType) T));
   }
 
+  QualType
+  getCountAttributedType(QualType T, Expr *CountExpr, bool CountInBytes,
+                         bool OrNull,
+                         ArrayRef<TypeCoupledDeclRefInfo> DependentDecls) const;
+
   /// Return the uniqued reference to a type adjusted from the original
   /// type to a new type.
   QualType getAdjustedType(QualType Orig, QualType New) const;
diff --git a/clang/include/clang/AST/PropertiesBase.td b/clang/include/clang/AST/PropertiesBase.td
index 0270c086d06b6a..6df1d93a7ba2eb 100644
--- a/clang/include/clang/AST/PropertiesBase.td
+++ b/clang/include/clang/AST/PropertiesBase.td
@@ -143,6 +143,7 @@ def UInt32 : CountPropertyType<"uint32_t">;
 def UInt64 : CountPropertyType<"uint64_t">;
 def UnaryTypeTransformKind : EnumPropertyType<"UnaryTransformType::UTTKind">;
 def VectorKind : EnumPropertyType<"VectorKind">;
+def TypeCoupledDeclRefInfo : PropertyType;
 
 def ExceptionSpecInfo : PropertyType<"FunctionProtoType::ExceptionSpecInfo"> {
   let BufferElementTypes = [ QualType ];
diff --git a/clang/include/clang/AST/RecursiveASTVisitor.h b/clang/include/clang/AST/RecursiveASTVisitor.h
index 5080551ada4fc6..4a1ff222ecadcd 100644
--- a/clang/include/clang/AST/RecursiveASTVisitor.h
+++ b/clang/include/clang/AST/RecursiveASTVisitor.h
@@ -1110,6 +1110,12 @@ DEF_TRAVERSE_TYPE(InjectedClassNameType, {})
 DEF_TRAVERSE_TYPE(AttributedType,
                   { TRY_TO(TraverseType(T->getModifiedType())); })
 
+DEF_TRAVERSE_TYPE(CountAttributedType, {
+  if (T->getCountExpr())
+    TRY_TO(TraverseStmt(T->getCountExpr()));
+  TRY_TO(TraverseType(T->desugar()));
+})
+
 DEF_TRAVERSE_TYPE(BTFTagAttributedType,
                   { TRY_TO(TraverseType(T->getWrappedType())); })
 
@@ -1401,6 +1407,9 @@ DEF_TRAVERSE_TYPELOC(MacroQualifiedType,
 DEF_TRAVERSE_TYPELOC(AttributedType,
                      { TRY_TO(TraverseTypeLoc(TL.getModifiedLoc())); })
 
+DEF_TRAVERSE_TYPELOC(CountAttributedType,
+                     { TRY_TO(TraverseTypeLoc(TL.getInnerLoc())); })
+
 DEF_TRAVERSE_TYPELOC(BTFTagAttributedType,
                      { TRY_TO(TraverseTypeLoc(TL.getWrappedLoc())); })
 
diff --git a/clang/include/clang/AST/Type.h b/clang/include/clang/AST/Type.h
index 1942b0e67f65a3..3e336f524ca4f6 100644
--- a/clang/include/clang/AST/Type.h
+++ b/clang/include/clang/AST/Type.h
@@ -61,6 +61,7 @@ class BTFTypeTagAttr;
 class ExtQuals;
 class QualType;
 class ConceptDecl;
+class ValueDecl;
 class TagDecl;
 class TemplateParameterList;
 class Type;
@@ -2000,6 +2001,21 @@ class alignas(TypeAlignment) Type : public ExtQualsTypeCommonBase {
     unsigned NumExpansions;
   };
 
+  class CountAttributedTypeBitfields {
+    friend class CountAttributedType;
+
+    LLVM_PREFERRED_TYPE(TypeBitfields)
+    unsigned : NumTypeBits;
+
+    /// The limit is 15.
+    static constexpr unsigned NumCoupledDeclsBits = 4;
+    unsigned NumCoupledDecls : NumCoupledDeclsBits;
+    LLVM_PREFERRED_TYPE(bool)
+    unsigned CountInBytes : 1;
+    LLVM_PREFERRED_TYPE(bool)
+    unsigned OrNull : 1;
+  };
+
   union {
     TypeBitfields TypeBits;
     ArrayTypeBitfields ArrayTypeBits;
@@ -2022,6 +2038,7 @@ class alignas(TypeAlignment) Type : public ExtQualsTypeCommonBase {
     DependentTemplateSpecializationTypeBitfields
       DependentTemplateSpecializationTypeBits;
     PackExpansionTypeBitfields PackExpansionTypeBits;
+    CountAttributedTypeBitfields CountAttributedTypeBits;
   };
 
 private:
@@ -2722,6 +2739,14 @@ template <> const TemplateSpecializationType *Type::getAs() const;
 /// until it reaches an AttributedType or a non-sugared type.
 template <> const AttributedType *Type::getAs() const;
 
+/// This will check for a BoundsAttributedType by removing any existing
+/// sugar until it reaches a BoundsAttributedType or a non-sugared type.
+template <> const BoundsAttributedType *Type::getAs() const;
+
+/// This will check for a CountAttributedType by removing any existing
+/// sugar until it reaches a CountAttributedType or a non-sugared type.
+template <> const CountAttributedType *Type::getAs() const;
+
 // We can do canonical leaf types faster, because we don't have to
 // worry about preserving child type decoration.
 #define TYPE(Class, Base)
@@ -2920,6 +2945,133 @@ class PointerType : public Type, public llvm::FoldingSetNode {
   static bool classof(const Type *T) { return T->getTypeClass() == Pointer; }
 };
 
+/// [BoundsSafety] Represents information of declarations referenced by the
+/// arguments of the `counted_by` attribute and the like.
+class TypeCoupledDeclRefInfo {
+public:
+  using BaseTy = llvm::PointerIntPair<ValueDecl *, 1, unsigned>;
+
+private:
+  enum {
+    DerefShift = 0,
+    DerefMask = 1,
+  };
+  BaseTy Data;
+
+public:
+  /// \p D is a declaration referenced by the argument of the attribute.
+  /// \p Deref indicates whether \p D is referenced in a dereferenced form,
+  /// e.g., \p Deref is true for `*n` in `int *__counted_by(*n)`.
+  TypeCoupledDeclRefInfo(ValueDecl *D = nullptr, bool Deref = false);
+
+  bool isDeref() const;
+  ValueDecl *getDecl() const;
+  unsigned getInt() const;
+  void *getOpaqueValue() const;
+  bool operator==(const TypeCoupledDeclRefInfo &Other) const;
+  void setFromOpaqueValue(void *V);
+};
+
+/// [BoundsSafety] Represents a parent type class for CountAttributedType and
+/// similar sugar types that will be introduced to represent a type with a
+/// bounds attribute.
+///
+/// Provides a common interface to navigate declarations referred to by the
+/// bounds expression.
+
+class BoundsAttributedType : public Type, public llvm::FoldingSetNode {
+  QualType WrappedTy;
+
+protected:
+  ArrayRef<TypeCoupledDeclRefInfo> Decls; // stored in trailing objects
+
+  BoundsAttributedType(TypeClass TC, QualType Wrapped, QualType Canon);
+
+public:
+  bool isSugared() const { return true; }
+  QualType desugar() const { return WrappedTy; }
+
+  using decl_iterator = const TypeCoupledDeclRefInfo *;
+  using decl_range = llvm::iterator_range<decl_iterator>;
+
+  decl_iterator dependent_decl_begin() const { return Decls.begin(); }
+  decl_iterator dependent_decl_end() const { return Decls.end(); }
+
+  unsigned getNumCoupledDecls() const { return Decls.size(); }
+
+  decl_range dependent_decls() const {
+    return decl_range(dependent_decl_begin(), dependent_decl_end());
+  }
+
+  ArrayRef<TypeCoupledDeclRefInfo> getCoupledDecls() const {
+    return {dependent_decl_begin(), dependent_decl_end()};
+  }
+
+  bool referencesFieldDecls() const;
+
+  static bool classof(const Type *T) {
+    switch (T->getTypeClass()) {
+    case CountAttributed:
+      return true;
+    default:
+      return false;
+    }
+  }
+};
+
+/// Represents a sugar type with `__counted_by` or `__sized_by` annotations,
+/// including their `_or_null` variants.
+class CountAttributedType final
+    : public BoundsAttributedType,
+      public llvm::TrailingObjects<CountAttributedType,
+                                   TypeCoupledDeclRefInfo> {
+  friend class ASTContext;
+
+  Expr *CountExpr;
+  /// \p CountExpr represents the argument of __counted_by or the like. \p
+  /// CountInBytes indicates that \p CountExpr is a byte count (i.e.,
+  /// __sized_by(_or_null)). \p OrNull means it's an or_null variant (i.e.,
+  /// __counted_by_or_null or __sized_by_or_null). \p CoupledDecls contains the
+  /// list of declarations referenced by \p CountExpr, which the type depends on
+  /// for the bounds information.
+  CountAttributedType(QualType Wrapped, QualType Canon, Expr *CountExpr,
+                      bool CountInBytes, bool OrNull,
+                      ArrayRef<TypeCoupledDeclRefInfo> CoupledDecls);
+
+  unsigned numTrailingObjects(OverloadToken<TypeCoupledDeclRefInfo>) const {
+    return CountAttributedTypeBits.NumCoupledDecls;
+  }
+
+public:
+  enum DynamicCountPointerKind {
+    CountedBy = 0,
+    SizedBy,
+    CountedByOrNull,
+    SizedByOrNull,
+  };
+
+  Expr *getCountExpr() const { return CountExpr; }
+  bool isCountInBytes() const { return CountAttributedTypeBits.CountInBytes; }
+  bool isOrNull() const { return CountAttributedTypeBits.OrNull; }
+
+  DynamicCountPointerKind getKind() const {
+    if (isOrNull())
+      return isCountInBytes() ? SizedByOrNull : CountedByOrNull;
+    return isCountInBytes() ? SizedBy : CountedBy;
+  }
+
+  void Profile(llvm::FoldingSetNodeID &ID) {
+    Profile(ID, desugar(), CountExpr, isCountInBytes(), isOrNull());
+  }
+
+  static void Profile(llvm::FoldingSetNodeID &ID, QualType WrappedTy,
+                      Expr *CountExpr, bool CountInBytes, bool Nullable);
+
+  static bool classof(const Type *T) {
+    return T->getTypeClass() == CountAttributed;
+  }
+};
+
 /// Represents a type which was implicitly adjusted by the semantic
 /// engine for arbitrary reasons.  For example, array and function types can
 /// decay, and function types can have their calling conventions adjusted.
diff --git a/clang/include/clang/AST/TypeLoc.h b/clang/include/clang/AST/TypeLoc.h
index 32028e877fc8dc..6415a3bb0b4e4f 100644
--- a/clang/include/clang/AST/TypeLoc.h
+++ b/clang/include/clang/AST/TypeLoc.h
@@ -1120,6 +1120,32 @@ class ObjCInterfaceTypeLoc : public ConcreteTypeLoc<ObjCObjectTypeLoc,
   }
 };
 
+struct BoundsAttributedLocInfo {};
+class BoundsAttributedTypeLoc
+    : public ConcreteTypeLoc<UnqualTypeLoc, BoundsAttributedTypeLoc,
+                             BoundsAttributedType, BoundsAttributedLocInfo> {
+public:
+  TypeLoc getInnerLoc() const { return this->getInnerTypeLoc(); }
+  QualType getInnerType() const { return this->getTypePtr()->desugar(); }
+  void initializeLocal(ASTContext &Context, SourceLocation Loc) {
+    // nothing to do
+  }
+  unsigned getLocalDataAlignment() const { return 1; }
+  unsigned getLocalDataSize() const { return 0; }
+};
+
+class CountAttributedTypeLoc final
+    : public InheritingConcreteTypeLoc<BoundsAttributedTypeLoc,
+                                       CountAttributedTypeLoc,
+                                       CountAttributedType> {
+public:
+  Expr *getCountExpr() const { return getTypePtr()->getCountExpr(); }
+  bool isCountInBytes() const { return getTypePtr()->isCountInBytes(); }
+  bool isOrNull() const { return getTypePtr()->isOrNull(); }
+
+  SourceRange getLocalSourceRange() const;
+};
+
 struct MacroQualifiedLocInfo {
   SourceLocation ExpansionLoc;
 };
diff --git a/clang/include/clang/AST/TypeProperties.td b/clang/include/clang/AST/TypeProperties.td
index 0ba172a4035fdb..f3906cb7f86fbf 100644
--- a/clang/include/clang/AST/TypeProperties.td
+++ b/clang/include/clang/AST/TypeProperties.td
@@ -25,6 +25,25 @@ let Class = PointerType in {
   def : Creator<[{ return ctx.getPointerType(pointeeType); }]>;
 }
 
+let Class = CountAttributedType in {
+  def : Property<"WrappedTy", QualType> {
+    let Read = [{ node->desugar() }];
+  }
+  def : Property<"CountExpr", ExprRef> {
+    let Read = [{ node->getCountExpr() }];
+  }
+  def : Property<"CountInBytes", Bool> {
+    let Read = [{ node->isCountInBytes() }];
+  }
+  def : Property<"OrNull", Bool> {
+    let Read = [{ node->isOrNull() }];
+  }
+  def : Property<"CoupledDecls", Array<TypeCoupledDeclRefInfo>> {
+    let Read = [{ node->getCoupledDecls() }];
+  }
+  def : Creator<[{ return ctx.getCountAttributedType(WrappedTy, CountExpr, CountInBytes, OrNull, CoupledDecls); }]>;
+}
+
 let Class = AdjustedType in {
   def : Property<"originalType", QualType> {
     let Read = [{ node->getOriginalType() }];
diff --git a/clang/include/clang/Basic/Attr.td b/clang/include/clang/Basic/Attr.td
index fa191c7378dba4..d7933001516ebe 100644
--- a/clang/include/clang/Basic/Attr.td
+++ b/clang/include/clang/Basic/Attr.td
@@ -4466,10 +4466,11 @@ def CodeAlign: StmtAttr {
   }];
 }
 
-def CountedBy : InheritableAttr {
+def CountedBy : DeclOrTypeAttr {
   let Spellings = [Clang<"counted_by">];
-  let Subjects = SubjectList<[Field]>;
-  let Args = [IdentifierArgument<"CountedByField">];
+  let Subjects = SubjectList<[Field], ErrorDiag>;
+  let Args = [ExprArgument<"Count">, IntArgument<"NestedLevel">];
+  let ParseArgumentsAsUnevaluated = 1;
   let Documentation = [CountedByDocs];
   let LangOpts = [COnly];
   // FIXME: This is ugly. Let using a DeclArgument would be nice, but a Decl
diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td
index c8dfdc08f5ea07..2d8f904eb1efac 100644
--- a/clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -6502,12 +6502,18 @@ def err_flexible_array_count_not_in_same_struct : Error<
   "'counted_by' field %0 isn't within the same struct as the flexible array">;
 def err_counted_by_attr_not_on_flexible_array_member : Error<
   "'counted_by' only applies to C99 flexible array members">;
-def err_counted_by_attr_refers_to_flexible_array : Error<
-  "'counted_by' cannot refer to the flexible array %0">;
+def err_counted_by_attr_refer_to_itself : Error<
+  "'counted_by' cannot refer to the flexible array member %0">;
 def err_counted_by_must_be_in_structure : Error<
   "field %0 in 'counted_by' not inside structure">;
-def err_flexible_array_counted_by_attr_field_not_integer : Error<
-  "field %0 in 'counted_by' must be a non-boolean integer type">;
+def err_counted_by_attr_argument_not_integer : Error<
+  "'counted_by' requires a non-boolean integer type argument">;
+def err_counted_by_attr_only_support_simple_decl_reference : Error<
+  "'counted_by' argument must be a simple declaration reference">;
+def err_counted_by_attr_in_union : Error<
+  "'counted_by' cannot be applied to a union member">;
+def err_counted_by_attr_refer_to_union : Error<
+  "'counted_by' argument cannot refer to a union member">;
 def note_flexible_array_counted_by_attr_field : Note<
   "field %0 declared here">;
 
diff --git a/clang/include/clang/Basic/TypeNodes.td b/clang/include/clang/Basic/TypeNodes.td
index 6419762417371b..3625f063758915 100644
--- a/clang/include/clang/Basic/TypeNodes.td
+++ b/clang/include/clang/Basic/TypeNodes.td
@@ -108,6 +108,8 @@ def ObjCTypeParamType : TypeNode<Type>, NeverCanonical;
 def ObjCObjectType : TypeNode<Type>;
 def ObjCInterfaceType : TypeNode<ObjCObjectType>, LeafType;
 def ObjCObjectPointerType : TypeNode<Type>;
+def BoundsAttributedType : TypeNode<Type, 1>;
+def CountAttributedType : TypeNode<BoundsAttributedType>, NeverCanonical;
 def PipeType : TypeNode<Type>;
 def AtomicType : TypeNode<Type>;
 def BitIntType : TypeNode<Type>;
diff --git a/clang/include/clang/Parse/Parser.h b/clang/include/clang/Parse/Parser.h
index 071520f535bc95..b84fe7714218a1 100644
--- a/clang/include/clang/Parse/Parser.h
+++ b/clang/include/clang/Parse/Parser.h
@@ -3066,6 +3066,11 @@ class Parser : public CodeCompletionHandler {
                                  SourceLocation ScopeLoc,
                                  ParsedAttr::Form Form);
 
+  void ParseBoundsAttribute(IdentifierInfo &AttrName,
+                            SourceLocation AttrNameLoc, ParsedAttributes &Attrs,
+                            IdentifierInfo *ScopeName, SourceLocation ScopeLoc,
+                            ParsedAttr::Form Form);
+
   void ParseTypeofSpecifier(DeclSpec &DS);
   SourceLocation ParseDecltypeSpecifier(DeclSpec &DS);
   void AnnotateExistingDecltypeSpecifier(const DeclSpec &DS,
diff --git a/clang/include/clang/Serialization/ASTRecordReader.h b/clang/include/clang/Serialization/ASTRecordReader.h
index 80a1359fd16aa0..5d3e95cb5d630f 100644
--- a/clang/include/clang/Serialization/ASTRecordReader.h
+++ b/clang/include/clang/Serialization/ASTRecordReader.h
@@ -221,6 +221,8 @@ class ASTRecordReader
     return Reader->ReadSelector(*F, Record, Idx);
   }
 
+  TypeCoupledDeclRefInfo readTypeCoupledDeclRefInfo();
+
   /// Read a declaration name, advancing Idx.
   // DeclarationName readDeclarationName(); (inherited)
   DeclarationNameLoc readDeclarationNameLoc(DeclarationName Name);
diff --git a/clang/include/clang/Serialization/ASTRecordWriter.h b/clang/include/clang/Serialization/ASTRecordWriter.h
index 9a64735c9fa55d..e007d4a70843a1 100644
--- a/clang/include/clang/Serialization/ASTRecordWriter.h
+++ b/clang/include/clang/Serialization/ASTRecordWriter.h
@@ -141,6 +141,11 @@ class ASTRecordWriter
     AddSourceLocation(Loc);
   }
 
+  void writeTypeCoupledDeclRefInfo(TypeCoupledDeclRefInfo Info) {
+    writeDeclRef(Info.getDecl());
+    writeBool(Info.isDeref());
+  }
+
   /// Emit a source range.
   void AddSourceRange(SourceRange Range, LocSeq *Seq = nullptr) {
     return Writer->AddSourceRange(Range, *Record, Seq);
diff --git a/clang/include/clang/Serialization/TypeBitCodes.def b/clang/include/clang/Serialization/TypeBitCodes.def
index adbd5ec5d4b544..3c82dfed9497d5 100644
--- a/clang/include/clang/Serialization/TypeBitCodes.def
+++ b/clang/include/clang/Serialization/TypeBitCodes.def
@@ -65,6 +65,6 @@ TYPE_BIT_CODE(DependentSizedMatrix, DEPENDENT_SIZE_MATRIX, 53)
 TYPE_BIT_CODE(Using, USING, 54)
 TYPE_BIT_CODE(BTFTagAttributed, BTFTAG_ATTRIBUTED, 55)
 TYPE_BIT_CODE(PackIndexing, PACK_INDEXING, 56)
-
+TYPE_BIT_CODE(CountAttributed, COUNT_ATTRIBUTED, 57)
 
 #undef TYPE_BIT_CODE
diff --git a/clang/lib/AST/ASTContext.cpp b/clang/lib/AST/ASTContext.cpp
index 5a8fae76a43a4d..5247e29ea30204 100644
--- a/clang/lib/AST/ASTContext.cpp
+++ b/clang/lib/AST/ASTContext.cpp
@@ -2348,6 +2348,9 @@ TypeInfo ASTContext::getTypeInfoImpl(const Type *T) const {
     return getTypeInfo(
                   cast<AttributedType>(T)->getEquivalentType().getTypePtr());
 
+  case Type::CountAttributed:
+    return getTypeInfo(cast<CountAttributedType>(T)->desugar().getTypePtr());
+
   case Type::BTFTagAttributed:
     return getTypeInfo(
         cast<BTFTagAttributedType>(T)->getWrappedType().getTypePtr());
@@ -3122,6 +3125,32 @@ QualType ASTContext::removePtrSizeAddrSpace(QualType T) const {
   return T;
 }
 
+QualType ASTContext::getCountAttributedType(
+    QualType WrappedTy, Expr *CountExpr, bool CountInBytes, bool OrNull,
+    ArrayRef<TypeCoupledDeclRefInfo> DependentDecls) const {
+  assert(WrappedTy->isPointerType() || WrappedTy->isArrayType());
+
+  llvm::FoldingSetNodeID ID;
+  CountAttributedType::Profile(ID, WrappedTy, CountExpr, CountInBytes, OrNull);
+
+  void *InsertPos = nullptr;
+  CountAttributedType *CATy =
+      CountAttributedTypes.FindNodeOrInsertPos(ID, InsertPos);
+  if (CATy)
+    return QualType(CATy, 0);
+
+  QualType CanonTy = getCanonicalType(WrappedTy);
+  size_t Size = CountAttributedType::totalSizeToAlloc<TypeCoupledDeclRefInfo>(
+      DependentDecls.size());
+  CATy = (CountAttributedType *)Allocate(Size, TypeAlignment);
+  new (CATy) CountAttributedType(WrappedTy, CanonTy, CountExpr, CountInBytes,
+                                 OrNull, DependentDecls);
+  Types.push_back(CATy);
+  CountAttributedTypes.InsertNode(CATy, InsertPos);
+
+  return QualType(CATy, 0);
+}
+
 const FunctionType *ASTContext::adjustFunctionType(const FunctionType *T,
                                                    FunctionType::ExtInfo Info) {
   if (T->getExtInfo() == Info)
@@ -13234,6 +13263,32 @@ static QualType getCommonSugarTypeNode(ASTContext &Ctx, const Type *X,
       return QualType();
     return Ctx.getUsingType(CD, Ctx.getQualifiedType(Underlying));
   }
+  case Type::CountAttributed: {
+    const auto *DX = cast<CountAttributedType>(X),
+               *DY = cast<CountAttributedType>(Y);
+    if (DX->isCountInBytes() != DY->isCountInBytes())
+      return QualType();
+    if (DX->isOrNull() != DY->isOrNull())
+      return QualType();
+    const auto CEX = DX->getCountExpr();
+    const auto CEY = DY->getCountExpr();
+    const auto CDX = DX->getCoupledDecls();
+    if (Ctx.hasSameExpr(CEX, CEY))
+      return Ctx.getCountAttributedType(Ctx.getQualifiedType(Underlying), CEX,
+                                        DX->isCountInBytes(), DX->isOrNull(),
+                                        CDX);
+    if (!CEX->isIntegerConstantExpr(Ctx) || !CEY->isIntegerConstantExpr(Ctx))
+      return QualType();
+    // Two declarations with the same integer constant may still differ in their
+    // expression pointers, so we need to evaluate them.
+    const auto VX = CEX->getIntegerConstantExpr(Ctx);
+    const auto VY = CEY->getIntegerConstantExpr(Ctx);
+    if (*VX != *VY)
+      return QualType();
+    return Ctx.getCountAttributedType(Ctx.getQualifiedType(Underlying), CEX,
+                                      DX->isCountInBytes(), DX->isOrNull(),
+                                      CDX);
+  }
   }
   llvm_unreachable("Unhandled Type Class");
 }
diff --git a/clang/lib/AST/ASTImporter.cpp b/clang/lib/AST/ASTImporter.cpp
index 023aaa7f0572b4..7c315ec20adae3 100644
--- a/clang/lib/AST/ASTImporter.cpp
+++ b/clang/lib/AST/ASTImporter.cpp
@@ -1527,6 +1527,28 @@ ExpectedType ASTNodeImporter::VisitAttributedType(const AttributedType *T) {
       *ToModifiedTypeOrErr, *ToEquivalentTypeOrErr);
 }
 
+ExpectedType
+ASTNodeImporter::VisitCountAttributedType(const CountAttributedType *T) {
+  ExpectedType ToWrappedTypeOrErr = import(T->desugar());
+  if (!ToWrappedTypeOrErr)
+    return ToWrappedTypeOrErr.takeError();
+
+  Error Err = Error::success();
+  Expr *CountExpr = importChecked(Err, T->getCountExpr());
+
+  SmallVector<TypeCoupledDeclRefInfo, 1> CoupledDecls;
+  for (auto TI : T->dependent_decls()) {
+    Expected<ValueDecl *> ToDeclOrErr = import(TI.getDecl());
+    if (!ToDeclOrErr)
+      return ToDeclOrErr.takeError();
+    CoupledDecls.emplace_back(*ToDeclOrErr, TI.isDeref());
+  }
+
+  return Importer.getToContext().getCountAttributedType(
+      *ToWrappedTypeOrErr, CountExpr, T->isCountInBytes(), T->isOrNull(),
+      ArrayRef(CoupledDecls.data(), CoupledDecls.size()));
+}
+
 ExpectedType ASTNodeImporter::VisitTemplateTypeParmType(
     const TemplateTypeParmType *T) {
   Expected<TemplateTypeParmDecl *> ToDeclOrErr = import(T->getDecl());
diff --git a/clang/lib/AST/ASTStructuralEquivalence.cpp b/clang/lib/AST/ASTStructuralEquivalence.cpp
index fe6e03ce174e59..226e0aa38ece70 100644
--- a/clang/lib/AST/ASTStructuralEquivalence.cpp
+++ b/clang/lib/AST/ASTStructuralEquivalence.cpp
@@ -1069,6 +1069,13 @@ static bool IsStructurallyEquivalent(StructuralEquivalenceContext &Context,
       return false;
     break;
 
+  case Type::CountAttributed:
+    if (!IsStructurallyEquivalent(Context,
+                                  cast<CountAttributedType>(T1)->desugar(),
+                                  cast<CountAttributedType>(T2)->desugar()))
+      return false;
+    break;
+
   case Type::BTFTagAttributed:
     if (!IsStructurallyEquivalent(
             Context, cast<BTFTagAttributedType>(T1)->getWrappedType(),
diff --git a/clang/lib/AST/ItaniumMangle.cpp b/clang/lib/AST/ItaniumMangle.cpp
index 1b6141580a81d2..f619d657ae9f50 100644
--- a/clang/lib/AST/ItaniumMangle.cpp
+++ b/clang/lib/AST/ItaniumMangle.cpp
@@ -2431,6 +2431,7 @@ bool CXXNameMangler::mangleUnresolvedTypeOrSimpleId(QualType Ty,
   case Type::MacroQualified:
   case Type::BitInt:
   case Type::DependentBitInt:
+  case Type::CountAttributed:
     llvm_unreachable("type is illegal as a nested name specifier");
 
   case Type::SubstTemplateTypeParmPack:
diff --git a/clang/lib/AST/Type.cpp b/clang/lib/AST/Type.cpp
index 78dcd3f4007a5a..89947f0e901323 100644
--- a/clang/lib/AST/Type.cpp
+++ b/clang/lib/AST/Type.cpp
@@ -385,6 +385,26 @@ void DependentBitIntType::Profile(llvm::FoldingSetNodeID &ID,
   NumBitsExpr->Profile(ID, Context, true);
 }
 
+bool BoundsAttributedType::referencesFieldDecls() const {
+  for (const auto &Decl : dependent_decls())
+    if (isa<FieldDecl>(Decl.getDecl()))
+      return true;
+  return false;
+}
+
+void CountAttributedType::Profile(llvm::FoldingSetNodeID &ID,
+                                  QualType WrappedTy, Expr *CountExpr,
+                                  bool CountInBytes, bool OrNull) {
+  ID.AddPointer(WrappedTy.getAsOpaquePtr());
+  ID.AddBoolean(CountInBytes);
+  ID.AddBoolean(OrNull);
+  // We profile it as a pointer as the StmtProfiler considers parameter
+  // expressions on function declaration and function definition as the
+  // same, resulting in count expression being evaluated with ParamDecl
+  // not in the function scope.
+  ID.AddPointer(CountExpr);
+}
+
 /// getArrayElementTypeNoTypeQual - If this is an array type, return the
 /// element type of the array, potentially with type qualifiers missing.
 /// This method should never be used when type qualifiers are meaningful.
@@ -559,6 +579,14 @@ template <> const AttributedType *Type::getAs() const {
   return getAsSugar<AttributedType>(this);
 }
 
+template <> const BoundsAttributedType *Type::getAs() const {
+  return getAsSugar<BoundsAttributedType>(this);
+}
+
+template <> const CountAttributedType *Type::getAs() const {
+  return getAsSugar<CountAttributedType>(this);
+}
+
 /// getUnqualifiedDesugaredType - Pull any qualifiers and syntactic
 /// sugar off the given type.  This should produce an object of the
 /// same dynamic type as the canonical type.
@@ -3707,6 +3735,43 @@ void FunctionProtoType::Profile(llvm::FoldingSetNodeID &ID,
           getExtProtoInfo(), Ctx, isCanonicalUnqualified());
 }
 
+TypeCoupledDeclRefInfo::TypeCoupledDeclRefInfo(ValueDecl *D, bool Deref)
+    : Data(D, Deref << DerefShift) {}
+
+bool TypeCoupledDeclRefInfo::isDeref() const {
+  return Data.getInt() & DerefMask;
+}
+ValueDecl *TypeCoupledDeclRefInfo::getDecl() const { return Data.getPointer(); }
+unsigned TypeCoupledDeclRefInfo::getInt() const { return Data.getInt(); }
+void *TypeCoupledDeclRefInfo::getOpaqueValue() const {
+  return Data.getOpaqueValue();
+}
+bool TypeCoupledDeclRefInfo::operator==(
+    const TypeCoupledDeclRefInfo &Other) const {
+  return getOpaqueValue() == Other.getOpaqueValue();
+}
+void TypeCoupledDeclRefInfo::setFromOpaqueValue(void *V) {
+  Data.setFromOpaqueValue(V);
+}
+
+BoundsAttributedType::BoundsAttributedType(TypeClass TC, QualType Wrapped,
+                                           QualType Canon)
+    : Type(TC, Canon, Wrapped->getDependence()), WrappedTy(Wrapped) {}
+
+CountAttributedType::CountAttributedType(
+    QualType Wrapped, QualType Canon, Expr *CountExpr, bool CountInBytes,
+    bool OrNull, ArrayRef<TypeCoupledDeclRefInfo> CoupledDecls)
+    : BoundsAttributedType(CountAttributed, Wrapped, Canon),
+      CountExpr(CountExpr) {
+  CountAttributedTypeBits.NumCoupledDecls = CoupledDecls.size();
+  CountAttributedTypeBits.CountInBytes = CountInBytes;
+  CountAttributedTypeBits.OrNull = OrNull;
+  auto *DeclSlot = getTrailingObjects<TypeCoupledDeclRefInfo>();
+  Decls = llvm::ArrayRef(DeclSlot, CoupledDecls.size());
+  for (unsigned i = 0; i != CoupledDecls.size(); ++i)
+    DeclSlot[i] = CoupledDecls[i];
+}
+
 TypedefType::TypedefType(TypeClass tc, const TypedefNameDecl *D,
                          QualType Underlying, QualType can)
     : Type(tc, can, toSemanticDependence(can->getDependence())),
diff --git a/clang/lib/AST/TypeLoc.cpp b/clang/lib/AST/TypeLoc.cpp
index b0acb182308755..21e152f6aea8a0 100644
--- a/clang/lib/AST/TypeLoc.cpp
+++ b/clang/lib/AST/TypeLoc.cpp
@@ -516,6 +516,10 @@ SourceRange AttributedTypeLoc::getLocalSourceRange() const {
   return getAttr() ? getAttr()->getRange() : SourceRange();
 }
 
+SourceRange CountAttributedTypeLoc::getLocalSourceRange() const {
+  return getCountExpr() ? getCountExpr()->getSourceRange() : SourceRange();
+}
+
 SourceRange BTFTagAttributedTypeLoc::getLocalSourceRange() const {
   return getAttr() ? getAttr()->getRange() : SourceRange();
 }
diff --git a/clang/lib/AST/TypePrinter.cpp b/clang/lib/AST/TypePrinter.cpp
index 7dcc4348f8e036..2c61fa05506434 100644
--- a/clang/lib/AST/TypePrinter.cpp
+++ b/clang/lib/AST/TypePrinter.cpp
@@ -286,6 +286,7 @@ bool TypePrinter::canPrefixQualifiers(const Type *T,
     case Type::PackExpansion:
     case Type::SubstTemplateTypeParm:
     case Type::MacroQualified:
+    case Type::CountAttributed:
       CanPrefixQualifiers = false;
       break;
 
@@ -1717,6 +1718,35 @@ void TypePrinter::printPackExpansionAfter(const PackExpansionType *T,
   OS << "...";
 }
 
+static void printCountAttributedImpl(const CountAttributedType *T,
+                                     raw_ostream &OS, PrintingPolicy Policy) {
+  if (T->isCountInBytes() && T->isOrNull())
+    OS << " __sized_by_or_null(";
+  else if (T->isCountInBytes())
+    OS << " __sized_by(";
+  else if (T->isOrNull())
+    OS << " __counted_by_or_null(";
+  else
+    OS << " __counted_by(";
+  if (T->getCountExpr())
+    T->getCountExpr()->printPretty(OS, nullptr, Policy);
+  OS << ')';
+}
+
+void TypePrinter::printCountAttributedBefore(const CountAttributedType *T,
+                                             raw_ostream &OS) {
+  printBefore(T->desugar(), OS);
+  if (!T->desugar()->isArrayType())
+    printCountAttributedImpl(T, OS, Policy);
+}
+
+void TypePrinter::printCountAttributedAfter(const CountAttributedType *T,
+                                            raw_ostream &OS) {
+  printAfter(T->desugar(), OS);
+  if (T->desugar()->isArrayType())
+    printCountAttributedImpl(T, OS, Policy);
+}
+
 void TypePrinter::printAttributedBefore(const AttributedType *T,
                                         raw_ostream &OS) {
   // FIXME: Generate this with TableGen.
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 20c35757939152..8a2810fea21054 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -858,7 +858,7 @@ static unsigned CountCountedByAttrs(const RecordDecl *RD) {
 
   for (const Decl *D : RD->decls()) {
     if (const auto *FD = dyn_cast<FieldDecl>(D);
-        FD && FD->hasAttr<CountedByAttr>()) {
+        FD && FD->getType()->getAs<CountAttributedType>()) {
       return ++Num;
     }
 
@@ -956,7 +956,7 @@ CodeGenFunction::emitFlexibleArrayMemberSize(const Expr *E, unsigned Type,
     //         };
     //    };
     //
-    // We don't konw which 'count' to use in this scenario:
+    // We don't know which 'count' to use in this scenario:
     //
     //     size_t get_size(struct union_of_fams *p) {
     //         return __builtin_dynamic_object_size(p, 1);
@@ -975,7 +975,7 @@ CodeGenFunction::emitFlexibleArrayMemberSize(const Expr *E, unsigned Type,
       FindFlexibleArrayMemberField(Ctx, OuterRD, FAMName, Offset);
   Offset = Ctx.toCharUnitsFromBits(Offset).getQuantity();
 
-  if (!FAMDecl || !FAMDecl->hasAttr<CountedByAttr>())
+  if (!FAMDecl || !FAMDecl->getType()->getAs<CountAttributedType>())
     // No flexible array member found or it doesn't have the "counted_by"
     // attribute.
     return nullptr;
diff --git a/clang/lib/CodeGen/CGDebugInfo.cpp b/clang/lib/CodeGen/CGDebugInfo.cpp
index c2c01439f2dc99..07ecaa81c47d84 100644
--- a/clang/lib/CodeGen/CGDebugInfo.cpp
+++ b/clang/lib/CodeGen/CGDebugInfo.cpp
@@ -3658,6 +3658,7 @@ llvm::DIType *CGDebugInfo::CreateTypeNode(QualType Ty, llvm::DIFile *Unit) {
   case Type::TemplateSpecialization:
     return CreateType(cast<TemplateSpecializationType>(Ty), Unit);
 
+  case Type::CountAttributed:
   case Type::Auto:
   case Type::Attributed:
   case Type::BTFTagAttributed:
diff --git a/clang/lib/CodeGen/CGExpr.cpp b/clang/lib/CodeGen/CGExpr.cpp
index 59a7fe8925001c..08d5f935917908 100644
--- a/clang/lib/CodeGen/CGExpr.cpp
+++ b/clang/lib/CodeGen/CGExpr.cpp
@@ -1136,38 +1136,19 @@ llvm::Value *CodeGenFunction::EmitCountedByFieldExpr(
 }
 
 const FieldDecl *CodeGenFunction::FindCountedByField(const FieldDecl *FD) {
-  if (!FD || !FD->hasAttr<CountedByAttr>())
+  if (!FD)
     return nullptr;
 
-  const auto *CBA = FD->getAttr<CountedByAttr>();
-  if (!CBA)
+  auto *CAT = FD->getType()->getAs<CountAttributedType>();
+  if (!CAT)
     return nullptr;
 
-  auto GetNonAnonStructOrUnion =
-      [](const RecordDecl *RD) -> const RecordDecl * {
-    while (RD && RD->isAnonymousStructOrUnion()) {
-      const auto *R = dyn_cast<RecordDecl>(RD->getDeclContext());
-      if (!R)
-        return nullptr;
-      RD = R;
-    }
-    return RD;
-  };
-  const RecordDecl *EnclosingRD = GetNonAnonStructOrUnion(FD->getParent());
-  if (!EnclosingRD)
-    return nullptr;
-
-  DeclarationName DName(CBA->getCountedByField());
-  DeclContext::lookup_result Lookup = EnclosingRD->lookup(DName);
-
-  if (Lookup.empty())
-    return nullptr;
-
-  const NamedDecl *ND = Lookup.front();
-  if (const auto *IFD = dyn_cast<IndirectFieldDecl>(ND))
-    ND = IFD->getAnonField();
+  auto *CountDRE = cast<DeclRefExpr>(CAT->getCountExpr());
+  auto *CountDecl = CountDRE->getDecl();
+  if (const auto *IFD = dyn_cast<IndirectFieldDecl>(CountDecl))
+    CountDecl = IFD->getAnonField();
 
-  return dyn_cast<FieldDecl>(ND);
+  return dyn_cast<FieldDecl>(CountDecl);
 }
 
 void CodeGenFunction::EmitBoundsCheck(const Expr *E, const Expr *Base,
@@ -4259,7 +4240,7 @@ LValue CodeGenFunction::EmitArraySubscriptExpr(const ArraySubscriptExpr *E,
       if (const auto *ME = dyn_cast<MemberExpr>(Array);
           ME &&
           ME->isFlexibleArrayMemberLike(getContext(), StrictFlexArraysLevel) &&
-          ME->getMemberDecl()->hasAttr<CountedByAttr>()) {
+          ME->getMemberDecl()->getType()->getAs<CountAttributedType>()) {
         const FieldDecl *FAMDecl = dyn_cast<FieldDecl>(ME->getMemberDecl());
         if (const FieldDecl *CountFD = FindCountedByField(FAMDecl)) {
           if (std::optional<int64_t> Diff =
diff --git a/clang/lib/CodeGen/CodeGenFunction.cpp b/clang/lib/CodeGen/CodeGenFunction.cpp
index b87fc86f4e635f..e245ad6afaa5e7 100644
--- a/clang/lib/CodeGen/CodeGenFunction.cpp
+++ b/clang/lib/CodeGen/CodeGenFunction.cpp
@@ -2406,6 +2406,7 @@ void CodeGenFunction::EmitVariablyModifiedType(QualType type) {
     case Type::BTFTagAttributed:
     case Type::SubstTemplateTypeParm:
     case Type::MacroQualified:
+    case Type::CountAttributed:
       // Keep walking after single level desugaring.
       type = type.getSingleStepDesugaredType(getContext());
       break;
diff --git a/clang/lib/Parse/ParseDecl.cpp b/clang/lib/Parse/ParseDecl.cpp
index 64b234eb460d24..b820807bc94f6a 100644
--- a/clang/lib/Parse/ParseDecl.cpp
+++ b/clang/lib/Parse/ParseDecl.cpp
@@ -629,6 +629,10 @@ void Parser::ParseGNUAttributeArgs(
     ParseAttributeWithTypeArg(*AttrName, AttrNameLoc, Attrs, ScopeName,
                               ScopeLoc, Form);
     return;
+  } else if (AttrKind == ParsedAttr::AT_CountedBy) {
+    ParseBoundsAttribute(*AttrName, AttrNameLoc, Attrs, ScopeName, ScopeLoc,
+                         Form);
+    return;
   }
 
   // These may refer to the function arguments, but need to be parsed early to
@@ -3247,6 +3251,54 @@ void Parser::ParseAlignmentSpecifier(ParsedAttributes &Attrs,
   }
 }
 
+/// Bounds attributes (e.g., counted_by):
+///   AttrName '(' expression ')'
+void Parser::ParseBoundsAttribute(IdentifierInfo &AttrName,
+                                  SourceLocation AttrNameLoc,
+                                  ParsedAttributes &Attrs,
+                                  IdentifierInfo *ScopeName,
+                                  SourceLocation ScopeLoc,
+                                  ParsedAttr::Form Form) {
+  assert(Tok.is(tok::l_paren) && "Attribute arg list not starting with '('");
+
+  BalancedDelimiterTracker Parens(*this, tok::l_paren);
+  Parens.consumeOpen();
+
+  if (Tok.is(tok::r_paren)) {
+    Diag(Tok.getLocation(), diag::err_argument_required_after_attribute);
+    Parens.consumeClose();
+    return;
+  }
+
+  ArgsVector ArgExprs;
+  // Don't evaluate argument when the attribute is ignored.
+  using ExpressionKind =
+      Sema::ExpressionEvaluationContextRecord::ExpressionKind;
+  EnterExpressionEvaluationContext EC(
+      Actions, Sema::ExpressionEvaluationContext::PotentiallyEvaluated, nullptr,
+      ExpressionKind::EK_BoundsAttrArgument);
+
+  ExprResult ArgExpr(
+      Actions.CorrectDelayedTyposInExpr(ParseAssignmentExpression()));
+
+  if (ArgExpr.isInvalid()) {
+    Parens.skipToEnd();
+    return;
+  }
+
+  ArgExprs.push_back(ArgExpr.get());
+  Parens.consumeClose();
+
+  ASTContext &Ctx = Actions.getASTContext();
+
+  ArgExprs.push_back(IntegerLiteral::Create(
+      Ctx, llvm::APInt(Ctx.getTypeSize(Ctx.getSizeType()), 0),
+      Ctx.getSizeType(), SourceLocation()));
+
+  Attrs.addNew(&AttrName, SourceRange(AttrNameLoc, Parens.getCloseLocation()),
+               ScopeName, ScopeLoc, ArgExprs.data(), ArgExprs.size(), Form);
+}
+
 ExprResult Parser::ParseExtIntegerArgument() {
   assert(Tok.isOneOf(tok::kw__ExtInt, tok::kw__BitInt) &&
          "Not an extended int type");
@@ -4676,6 +4728,39 @@ void Parser::ParseDeclarationSpecifiers(
   }
 }
 
+static void DiagnoseCountAttributedTypeInUnnamedAnon(ParsingDeclSpec &DS,
+                                                     Parser &P) {
+
+  if (DS.getTypeSpecType() != DeclSpec::TST_struct)
+    return;
+
+  auto *RD = dyn_cast<RecordDecl>(DS.getRepAsDecl());
+  // We're only interested in unnamed, non-anonymous structs.
+  if (!RD || !RD->getName().empty() || RD->isAnonymousStructOrUnion())
+    return;
+
+  for (auto *I : RD->decls()) {
+    auto *VD = dyn_cast<ValueDecl>(I);
+    if (!VD)
+      continue;
+
+    auto *CAT = VD->getType()->getAs<CountAttributedType>();
+    if (!CAT)
+      continue;
+
+    for (const auto &DD : CAT->dependent_decls()) {
+      if (!RD->containsDecl(DD.getDecl())) {
+        P.Diag(VD->getBeginLoc(),
+               diag::err_flexible_array_count_not_in_same_struct)
+            << DD.getDecl();
+        P.Diag(DD.getDecl()->getBeginLoc(),
+               diag::note_flexible_array_counted_by_attr_field)
+            << DD.getDecl();
+      }
+    }
+  }
+}
+
 /// ParseStructDeclaration - Parse a struct declaration without the terminating
 /// semicolon.
 ///
@@ -4756,6 +4841,11 @@ void Parser::ParseStructDeclaration(
     } else
       DeclaratorInfo.D.SetIdentifier(nullptr, Tok.getLocation());
 
+    // Here, we now know that the unnamed struct is not an anonymous struct.
+    // Report an error if a counted_by attribute refers to a field in a
+    // different named struct.
+    DiagnoseCountAttributedTypeInUnnamedAnon(DS, *this);
+
     if (TryConsumeToken(tok::colon)) {
       ExprResult Res(ParseConstantExpression());
       if (Res.isInvalid())
diff --git a/clang/lib/Sema/SemaDecl.cpp b/clang/lib/Sema/SemaDecl.cpp
index 67e56a917a51de..b7aacfa0830e67 100644
--- a/clang/lib/Sema/SemaDecl.cpp
+++ b/clang/lib/Sema/SemaDecl.cpp
@@ -2270,12 +2270,6 @@ void Sema::ActOnPopScope(SourceLocation Loc, Scope *S) {
       }
       ShadowingDecls.erase(ShadowI);
     }
-
-    if (!getLangOpts().CPlusPlus && S->isClassScope()) {
-      if (auto *FD = dyn_cast<FieldDecl>(TmpD);
-          FD && FD->hasAttr<CountedByAttr>())
-        CheckCountedByAttr(S, FD);
-    }
   }
 
   llvm::sort(DeclDiags,
diff --git a/clang/lib/Sema/SemaDeclAttr.cpp b/clang/lib/Sema/SemaDeclAttr.cpp
index e6943efb345ce0..fdf9dea6fbcc5a 100644
--- a/clang/lib/Sema/SemaDeclAttr.cpp
+++ b/clang/lib/Sema/SemaDeclAttr.cpp
@@ -8472,133 +8472,110 @@ static void handleZeroCallUsedRegsAttr(Sema &S, Decl *D, const ParsedAttr &AL) {
   D->addAttr(ZeroCallUsedRegsAttr::Create(S.Context, Kind, AL));
 }
 
-static void handleCountedByAttr(Sema &S, Decl *D, const ParsedAttr &AL) {
-  if (!AL.isArgIdent(0)) {
-    S.Diag(AL.getLoc(), diag::err_attribute_argument_type)
-        << AL << AANT_ArgumentIdentifier;
-    return;
+static const RecordDecl *GetEnclosingNamedOrTopAnonRecord(const FieldDecl *FD) {
+  const auto *RD = FD->getParent();
+  while (RD && (RD->isAnonymousStructOrUnion() || RD->getName().empty())) {
+    const auto *Parent = dyn_cast<RecordDecl>(RD->getParent());
+    if (!Parent)
+      break;
+    RD = Parent;
   }
-
-  IdentifierLoc *IL = AL.getArgAsIdent(0);
-  CountedByAttr *CBA =
-      ::new (S.Context) CountedByAttr(S.Context, AL, IL->Ident);
-  CBA->setCountedByFieldLoc(IL->Loc);
-  D->addAttr(CBA);
+  return RD;
 }
 
-static const FieldDecl *
-FindFieldInTopLevelOrAnonymousStruct(const RecordDecl *RD,
-                                     const IdentifierInfo *FieldName) {
-  for (const Decl *D : RD->decls()) {
-    if (const auto *FD = dyn_cast<FieldDecl>(D))
-      if (FD->getName() == FieldName->getName())
-        return FD;
-
-    if (const auto *R = dyn_cast<RecordDecl>(D))
-      if (const FieldDecl *FD =
-              FindFieldInTopLevelOrAnonymousStruct(R, FieldName))
-        return FD;
+static bool
+CheckCountExpr(Sema &S, FieldDecl *FD, Expr *E,
+               llvm::SmallVector<TypeCoupledDeclRefInfo, 1> &Decls) {
+  if (FD->getParent()->isUnion()) {
+    S.Diag(FD->getBeginLoc(), diag::err_counted_by_attr_in_union)
+        << FD->getSourceRange();
+    return true;
   }
 
-  return nullptr;
-}
+  if (!E->getType()->isIntegerType() || E->getType()->isBooleanType()) {
+    S.Diag(E->getBeginLoc(), diag::err_counted_by_attr_argument_not_integer)
+        << E->getSourceRange();
+    return true;
+  }
 
-bool Sema::CheckCountedByAttr(Scope *S, const FieldDecl *FD) {
   LangOptions::StrictFlexArraysLevelKind StrictFlexArraysLevel =
       LangOptions::StrictFlexArraysLevelKind::IncompleteOnly;
-  if (!Decl::isFlexibleArrayMemberLike(Context, FD, FD->getType(),
+
+  if (!Decl::isFlexibleArrayMemberLike(S.getASTContext(), FD, FD->getType(),
                                        StrictFlexArraysLevel, true)) {
     // The "counted_by" attribute must be on a flexible array member.
     SourceRange SR = FD->getLocation();
-    Diag(SR.getBegin(), diag::err_counted_by_attr_not_on_flexible_array_member)
+    S.Diag(SR.getBegin(),
+           diag::err_counted_by_attr_not_on_flexible_array_member)
         << SR;
     return true;
   }
 
-  const auto *CBA = FD->getAttr<CountedByAttr>();
-  const IdentifierInfo *FieldName = CBA->getCountedByField();
-
-  auto GetNonAnonStructOrUnion = [](const RecordDecl *RD) {
-    while (RD && !RD->getDeclName())
-      if (const auto *R = dyn_cast<RecordDecl>(RD->getDeclContext()))
-        RD = R;
-      else
-        break;
-
-    return RD;
-  };
-
-  const RecordDecl *EnclosingRD = GetNonAnonStructOrUnion(FD->getParent());
-  const FieldDecl *CountFD =
-      FindFieldInTopLevelOrAnonymousStruct(EnclosingRD, FieldName);
+  auto *DRE = dyn_cast<DeclRefExpr>(E);
+  if (!DRE) {
+    S.Diag(E->getBeginLoc(),
+           diag::err_counted_by_attr_only_support_simple_decl_reference)
+        << E->getSourceRange();
+    return true;
+  }
 
+  auto *CountDecl = DRE->getDecl();
+  FieldDecl *CountFD = dyn_cast<FieldDecl>(CountDecl);
+  if (auto *IFD = dyn_cast<IndirectFieldDecl>(CountDecl)) {
+    CountFD = IFD->getAnonField();
+  }
   if (!CountFD) {
-    DeclarationNameInfo NameInfo(FieldName,
-                                 CBA->getCountedByFieldLoc().getBegin());
-    LookupResult MemResult(*this, NameInfo, Sema::LookupMemberName);
-    LookupName(MemResult, S);
-
-    if (!MemResult.empty()) {
-      SourceRange SR = CBA->getCountedByFieldLoc();
-      Diag(SR.getBegin(), diag::err_flexible_array_count_not_in_same_struct)
-          << CBA->getCountedByField() << SR;
-
-      if (auto *ND = MemResult.getAsSingle<NamedDecl>()) {
-        SR = ND->getLocation();
-        Diag(SR.getBegin(), diag::note_flexible_array_counted_by_attr_field)
-            << ND << SR;
-      }
+    S.Diag(E->getBeginLoc(), diag::err_counted_by_must_be_in_structure)
+        << CountDecl << E->getSourceRange();
 
-      return true;
-    } else {
-      // The "counted_by" field needs to exist in the struct.
-      LookupResult OrdResult(*this, NameInfo, Sema::LookupOrdinaryName);
-      LookupName(OrdResult, S);
-
-      if (!OrdResult.empty()) {
-        SourceRange SR = FD->getLocation();
-        Diag(SR.getBegin(), diag::err_counted_by_must_be_in_structure)
-            << FieldName << SR;
-
-        if (auto *ND = OrdResult.getAsSingle<NamedDecl>()) {
-          SR = ND->getLocation();
-          Diag(SR.getBegin(), diag::note_flexible_array_counted_by_attr_field)
-              << ND << SR;
-        }
+    S.Diag(CountDecl->getBeginLoc(),
+           diag::note_flexible_array_counted_by_attr_field)
+        << CountDecl << CountDecl->getSourceRange();
+    return true;
+  }
 
-        return true;
-      }
+  if (FD->getParent() != CountFD->getParent()) {
+    if (CountFD->getParent()->isUnion()) {
+      S.Diag(CountFD->getBeginLoc(), diag::err_counted_by_attr_refer_to_union)
+          << CountFD->getSourceRange();
+      return true;
+    }
+    // Whether CountRD is an anonymous struct is not determined at this
+    // point. Thus, an additional check is done later in
+    // `Parser::ParseStructDeclaration`.
+    auto *RD = GetEnclosingNamedOrTopAnonRecord(FD);
+    auto *CountRD = GetEnclosingNamedOrTopAnonRecord(CountFD);
+
+    if (RD != CountRD) {
+      S.Diag(E->getBeginLoc(),
+             diag::err_flexible_array_count_not_in_same_struct)
+          << CountFD << E->getSourceRange();
+      S.Diag(CountFD->getBeginLoc(),
+             diag::note_flexible_array_counted_by_attr_field)
+          << CountFD << CountFD->getSourceRange();
+      return true;
     }
-
-    CXXScopeSpec SS;
-    DeclFilterCCC<FieldDecl> Filter(FieldName);
-    return DiagnoseEmptyLookup(S, SS, MemResult, Filter, nullptr, std::nullopt,
-                               const_cast<DeclContext *>(FD->getDeclContext()));
   }
 
-  if (CountFD->hasAttr<CountedByAttr>()) {
-    // The "counted_by" field can't point to the flexible array member.
-    SourceRange SR = CBA->getCountedByFieldLoc();
-    Diag(SR.getBegin(), diag::err_counted_by_attr_refers_to_flexible_array)
-        << CBA->getCountedByField() << SR;
-    return true;
-  }
+  Decls.push_back(TypeCoupledDeclRefInfo(CountFD, /*IsDref*/ false));
+  return false;
+}
 
-  if (!CountFD->getType()->isIntegerType() ||
-      CountFD->getType()->isBooleanType()) {
-    // The "counted_by" field must have an integer type.
-    SourceRange SR = CBA->getCountedByFieldLoc();
-    Diag(SR.getBegin(),
-         diag::err_flexible_array_counted_by_attr_field_not_integer)
-        << CBA->getCountedByField() << SR;
+static void handleCountedByAttrField(Sema &S, Decl *D, const ParsedAttr &AL) {
+  auto *FD = dyn_cast<FieldDecl>(D);
+  assert(FD);
 
-    SR = CountFD->getLocation();
-    Diag(SR.getBegin(), diag::note_flexible_array_counted_by_attr_field)
-        << CountFD << SR;
-    return true;
-  }
+  auto *CountExpr = AL.getArgAsExpr(0);
+  if (!CountExpr)
+    return;
 
-  return false;
+  llvm::SmallVector<TypeCoupledDeclRefInfo, 1> Decls;
+  if (CheckCountExpr(S, FD, CountExpr, Decls))
+    return;
+
+  QualType CAT =
+      S.BuildCountAttributedArrayType(FD->getType(), CountExpr, Decls);
+  FD->setType(CAT);
 }
 
 static void handleFunctionReturnThunksAttr(Sema &S, Decl *D,
@@ -9617,7 +9594,7 @@ ProcessDeclAttribute(Sema &S, Scope *scope, Decl *D, const ParsedAttr &AL,
     break;
 
   case ParsedAttr::AT_CountedBy:
-    handleCountedByAttr(S, D, AL);
+    handleCountedByAttrField(S, D, AL);
     break;
 
   // Microsoft attributes:
diff --git a/clang/lib/Sema/SemaExpr.cpp b/clang/lib/Sema/SemaExpr.cpp
index 47bb263f56aade..04ccdf81a29c21 100644
--- a/clang/lib/Sema/SemaExpr.cpp
+++ b/clang/lib/Sema/SemaExpr.cpp
@@ -2738,6 +2738,24 @@ Sema::ActOnIdExpression(Scope *S, CXXScopeSpec &SS,
     return ActOnDependentIdExpression(SS, TemplateKWLoc, NameInfo,
                                       IsAddressOfOperand, TemplateArgs);
 
+  // BoundsSafety: This specially handles arguments of bounds attributes that
+  // appertain to the type of a C struct field, such that name lookup
+  // within a struct finds the member name, which is not the case for other
+  // contexts in C.
+  if (isBoundsAttrContext() && !getLangOpts().CPlusPlus && S->isClassScope()) {
+    // See if this is a reference to a field of a struct.
+    LookupResult R(*this, NameInfo, LookupMemberName);
+    // LookupParsedName handles a name lookup from within an anonymous struct.
+    if (LookupParsedName(R, S, &SS)) {
+      if (auto *VD = dyn_cast<ValueDecl>(R.getFoundDecl())) {
+        QualType type = VD->getType().getNonReferenceType();
+        // This will eventually be translated into MemberExpr upon
+        // the use of instantiated struct fields.
+        return BuildDeclRefExpr(VD, type, VK_PRValue, NameLoc);
+      }
+    }
+  }
+
   // Perform the required lookup.
   LookupResult R(*this, NameInfo,
                  (Id.getKind() == UnqualifiedIdKind::IK_ImplicitSelfParam)
@@ -2894,7 +2912,8 @@ Sema::ActOnIdExpression(Scope *S, CXXScopeSpec &SS,
   // to get this right here so that we don't end up making a
   // spuriously dependent expression if we're inside a dependent
   // instance method.
-  if (!R.empty() && (*R.begin())->isCXXClassMember()) {
+  if (getLangOpts().CPlusPlus && !R.empty() &&
+      (*R.begin())->isCXXClassMember()) {
     bool MightBeImplicitMember;
     if (!IsAddressOfOperand)
       MightBeImplicitMember = true;
@@ -3550,7 +3569,8 @@ ExprResult Sema::BuildDeclarationNameExpr(
   case Decl::Field:
   case Decl::IndirectField:
   case Decl::ObjCIvar:
-    assert(getLangOpts().CPlusPlus && "building reference to field in C?");
+    assert((getLangOpts().CPlusPlus || isBoundsAttrContext()) &&
+           "building reference to field in C?");
 
     // These can't have reference type in well-formed programs, but
     // for internal consistency we do this anyway.
@@ -4707,6 +4727,7 @@ static void captureVariablyModifiedType(ASTContext &Context, QualType T,
     case Type::BTFTagAttributed:
     case Type::SubstTemplateTypeParm:
     case Type::MacroQualified:
+    case Type::CountAttributed:
       // Keep walking after single level desugaring.
       T = T.getSingleStepDesugaredType(Context);
       break;
diff --git a/clang/lib/Sema/SemaType.cpp b/clang/lib/Sema/SemaType.cpp
index 3148299f6467af..5038f721a91a1d 100644
--- a/clang/lib/Sema/SemaType.cpp
+++ b/clang/lib/Sema/SemaType.cpp
@@ -9745,6 +9745,18 @@ QualType Sema::BuildTypeofExprType(Expr *E, TypeOfKind Kind) {
   return Context.getTypeOfExprType(E, Kind);
 }
 
+QualType Sema::BuildCountAttributedArrayType(
+    QualType WrappedTy, Expr *CountExpr,
+    const llvm::SmallVector<TypeCoupledDeclRefInfo, 1> &Decls) {
+  assert(WrappedTy->isIncompleteArrayType());
+
+  /// When the resulting expression is invalid, we still create the AST using
+  /// the original count expression for the sake of AST dump.
+  return Context.getCountAttributedType(
+      WrappedTy, CountExpr, /*CountInBytes*/ false, /*OrNull*/ false,
+      llvm::ArrayRef(Decls.begin(), Decls.end()));
+}
+
 /// getDecltypeForExpr - Given an expr, will return the decltype for
 /// that expression, according to the rules in C++11
 /// [dcl.type.simple]p4 and C++11 [expr.lambda.prim]p18.
diff --git a/clang/lib/Sema/TreeTransform.h b/clang/lib/Sema/TreeTransform.h
index 409aee73d960eb..f474c61b3b64ab 100644
--- a/clang/lib/Sema/TreeTransform.h
+++ b/clang/lib/Sema/TreeTransform.h
@@ -7261,6 +7261,13 @@ QualType TreeTransform<Derived>::TransformAttributedType(TypeLocBuilder &TLB,
       });
 }
 
+template <typename Derived>
+QualType TreeTransform<Derived>::TransformCountAttributedType(
+    TypeLocBuilder &TLB, CountAttributedTypeLoc TL) {
+  // TODO
+  llvm_unreachable("Unexpected TreeTransform for CountAttributedType");
+}
+
 template <typename Derived>
 QualType TreeTransform<Derived>::TransformBTFTagAttributedType(
     TypeLocBuilder &TLB, BTFTagAttributedTypeLoc TL) {
diff --git a/clang/lib/Serialization/ASTReader.cpp b/clang/lib/Serialization/ASTReader.cpp
index 683a076e6bc399..b52149d63ab819 100644
--- a/clang/lib/Serialization/ASTReader.cpp
+++ b/clang/lib/Serialization/ASTReader.cpp
@@ -6995,6 +6995,10 @@ void TypeLocReader::VisitAttributedTypeLoc(AttributedTypeLoc TL) {
   TL.setAttr(ReadAttr());
 }
 
+void TypeLocReader::VisitCountAttributedTypeLoc(CountAttributedTypeLoc TL) {
+  // Nothing to do.
+}
+
 void TypeLocReader::VisitBTFTagAttributedTypeLoc(BTFTagAttributedTypeLoc TL) {
   // Nothing to do.
 }
@@ -9116,6 +9120,10 @@ DeclarationNameInfo ASTRecordReader::readDeclarationNameInfo() {
   return NameInfo;
 }
 
+TypeCoupledDeclRefInfo ASTRecordReader::readTypeCoupledDeclRefInfo() {
+  return TypeCoupledDeclRefInfo(readDeclAs<ValueDecl>(), readBool());
+}
+
 void ASTRecordReader::readQualifierInfo(QualifierInfo &Info) {
   Info.QualifierLoc = readNestedNameSpecifierLoc();
   unsigned NumTPLists = readInt();
diff --git a/clang/lib/Serialization/ASTWriter.cpp b/clang/lib/Serialization/ASTWriter.cpp
index a9edc7e68b53b3..f00959ab4b2a70 100644
--- a/clang/lib/Serialization/ASTWriter.cpp
+++ b/clang/lib/Serialization/ASTWriter.cpp
@@ -514,6 +514,10 @@ void TypeLocWriter::VisitAttributedTypeLoc(AttributedTypeLoc TL) {
   Record.AddAttr(TL.getAttr());
 }
 
+void TypeLocWriter::VisitCountAttributedTypeLoc(CountAttributedTypeLoc TL) {
+  // Nothing to do.
+}
+
 void TypeLocWriter::VisitBTFTagAttributedTypeLoc(BTFTagAttributedTypeLoc TL) {
   // Nothing to do.
 }
diff --git a/clang/test/Sema/attr-counted-by.c b/clang/test/Sema/attr-counted-by.c
index f14da9c77fa8b4..a3b08df7309944 100644
--- a/clang/test/Sema/attr-counted-by.c
+++ b/clang/test/Sema/attr-counted-by.c
@@ -11,7 +11,7 @@ struct not_found {
 
 struct no_found_count_not_in_substruct {
   unsigned long flags;
-  unsigned char count; // expected-note {{field 'count' declared here}}
+  unsigned char count; // expected-note {{'count' declared here}}
   struct A {
     int dummy;
     int array[] __counted_by(count); // expected-error {{'counted_by' field 'count' isn't within the same struct as the flexible array}}
@@ -19,8 +19,8 @@ struct no_found_count_not_in_substruct {
 };
 
 struct not_found_suggest {
-  int bork; // expected-note {{'bork' declared here}}
-  struct bar *fam[] __counted_by(blork); // expected-error {{use of undeclared identifier 'blork'; did you mean 'bork'?}}
+  int bork;
+  struct bar *fam[] __counted_by(blork); // expected-error {{use of undeclared identifier 'blork'}}
 };
 
 int global; // expected-note {{'global' declared here}}
@@ -32,17 +32,17 @@ struct found_outside_of_struct {
 
 struct self_referrential {
   int bork;
-  struct bar *self[] __counted_by(self); // expected-error {{'counted_by' cannot refer to the flexible array 'self'}}
+  struct bar *self[] __counted_by(self); // expected-error {{use of undeclared identifier 'self'}}
 };
 
 struct non_int_count {
-  double dbl_count; // expected-note {{field 'dbl_count' declared here}}
-  struct bar *fam[] __counted_by(dbl_count); // expected-error {{field 'dbl_count' in 'counted_by' must be a non-boolean integer type}}
+  double dbl_count;
+  struct bar *fam[] __counted_by(dbl_count); // expected-error {{'counted_by' requires a non-boolean integer type argument}}
 };
 
 struct array_of_ints_count {
-  int integers[2]; // expected-note {{field 'integers' declared here}}
-  struct bar *fam[] __counted_by(integers); // expected-error {{field 'integers' in 'counted_by' must be a non-boolean integer type}}
+  int integers[2];
+  struct bar *fam[] __counted_by(integers); // expected-error {{'counted_by' requires a non-boolean integer type argument}}
 };
 
 struct not_a_fam {
@@ -58,7 +58,7 @@ struct not_a_c99_fam {
 struct annotated_with_anon_struct {
   unsigned long flags;
   struct {
-    unsigned char count; // expected-note {{'count' declared here}}
-    int array[] __counted_by(crount); // expected-error {{use of undeclared identifier 'crount'; did you mean 'count'?}}
+    unsigned char count;
+    int array[] __counted_by(crount); // expected-error {{use of undeclared identifier 'crount'}}
   };
 };
diff --git a/clang/tools/libclang/CIndex.cpp b/clang/tools/libclang/CIndex.cpp
index 418b152ba4a13f..454ee1e42aed1d 100644
--- a/clang/tools/libclang/CIndex.cpp
+++ b/clang/tools/libclang/CIndex.cpp
@@ -1778,6 +1778,10 @@ bool CursorVisitor::VisitAttributedTypeLoc(AttributedTypeLoc TL) {
   return Visit(TL.getModifiedLoc());
 }
 
+bool CursorVisitor::VisitCountAttributedTypeLoc(CountAttributedTypeLoc TL) {
+  return Visit(TL.getInnerLoc());
+}
+
 bool CursorVisitor::VisitBTFTagAttributedTypeLoc(BTFTagAttributedTypeLoc TL) {
   return Visit(TL.getWrappedLoc());
 }

>From 39f4c0560b67fc74882c8f6741dd1c7e7756fb0e Mon Sep 17 00:00:00 2001
From: Yeoul Na <yeoul_na at apple.com>
Date: Thu, 1 Feb 2024 16:29:05 -0800
Subject: [PATCH 2/9] Update/add tests

---
 .../Modules/bounds-safety-attributed-type.c   | 27 +++++++++++
 .../Inputs/bounds-safety-attributed-type.h    |  4 ++
 .../test/PCH/bounds-safety-attributed-type.c  | 17 +++++++
 clang/test/Sema/attr-counted-by.c             | 48 +++++++++++++++++++
 4 files changed, 96 insertions(+)
 create mode 100644 clang/test/Modules/bounds-safety-attributed-type.c
 create mode 100644 clang/test/PCH/Inputs/bounds-safety-attributed-type.h
 create mode 100644 clang/test/PCH/bounds-safety-attributed-type.c

diff --git a/clang/test/Modules/bounds-safety-attributed-type.c b/clang/test/Modules/bounds-safety-attributed-type.c
new file mode 100644
index 00000000000000..aa657ee09699dd
--- /dev/null
+++ b/clang/test/Modules/bounds-safety-attributed-type.c
@@ -0,0 +1,27 @@
+// RUN: rm -rf %t
+// RUN: %clang_cc1 -fmodules -fmodules-cache-path=%t -verify %s
+// RUN: %clang_cc1 -fmodules -fmodules-cache-path=%t -ast-dump-all %s | FileCheck %s
+// expected-no-diagnostics
+
+#pragma clang module build bounds_safety
+module bounds_safety {}
+#pragma clang module contents
+#pragma clang module begin bounds_safety
+struct Test {
+  int count;
+  int fam[] __attribute__((counted_by(count)));
+};
+#pragma clang module end
+#pragma clang module endbuild
+
+#pragma clang module import bounds_safety
+
+struct Test *p;
+
+// CHECK: |-RecordDecl {{.*/}}bounds_safety.map:4:1, line:7:1> line:4:8 imported in bounds_safety <undeserialized declarations> struct Test definition
+// CHECK: | |-FieldDecl {{.*}} imported in bounds_safety referenced count 'int'
+// CHECK: | `-FieldDecl {{.*}} imported in bounds_safety fam 'int[] __counted_by(count)':'int[]'
+
+// CHECK: |-ImportDecl {{.*}}bounds-safety-attributed-type.c:17:22> col:22 implicit bounds_safety
+// CHECK: |-RecordDecl {{.*}} struct Test
+// CHECK: `-VarDecl {{.*}} p 'struct Test *'
diff --git a/clang/test/PCH/Inputs/bounds-safety-attributed-type.h b/clang/test/PCH/Inputs/bounds-safety-attributed-type.h
new file mode 100644
index 00000000000000..9bfd9615c2bfc1
--- /dev/null
+++ b/clang/test/PCH/Inputs/bounds-safety-attributed-type.h
@@ -0,0 +1,4 @@
+struct Test {
+  int count;
+  int fam[] __attribute__((counted_by(count)));
+};
diff --git a/clang/test/PCH/bounds-safety-attributed-type.c b/clang/test/PCH/bounds-safety-attributed-type.c
new file mode 100644
index 00000000000000..c2dc0272bbd605
--- /dev/null
+++ b/clang/test/PCH/bounds-safety-attributed-type.c
@@ -0,0 +1,17 @@
+// RUN: %clang_cc1 -include %S/Inputs/bounds-safety-attributed-type.h -fsyntax-only -verify %s
+
+// Test with pch.
+// RUN: %clang_cc1 -emit-pch -o %t %S/Inputs/bounds-safety-attributed-type.h
+// RUN: %clang_cc1 -include-pch %t -fsyntax-only -verify %s
+// RUN: %clang_cc1 -include-pch %t -ast-print %s | FileCheck %s --check-prefix PRINT
+// RUN: %clang_cc1 -include-pch %t -ast-dump-all %s | FileCheck %s --check-prefix DUMP
+// expected-no-diagnostics
+
+// PRINT: 	   struct Test {
+// PRINT-NEXT:   int count;
+// PRINT-NEXT:   int fam[] __counted_by(count);
+// PRINT-NEXT: };
+
+// DUMP:        RecordDecl {{.*}} imported <undeserialized declarations> struct Test definition
+// DUMP-NEXT:   |-FieldDecl {{.*}} imported referenced count 'int'
+// DUMP-NEXT:   `-FieldDecl {{.*}} imported fam 'int[] __counted_by(count)':'int[]'
diff --git a/clang/test/Sema/attr-counted-by.c b/clang/test/Sema/attr-counted-by.c
index a3b08df7309944..1b646254911cd2 100644
--- a/clang/test/Sema/attr-counted-by.c
+++ b/clang/test/Sema/attr-counted-by.c
@@ -18,6 +18,54 @@ struct no_found_count_not_in_substruct {
   } a;
 };
 
+struct not_found_count_not_in_unnamed_substruct {
+  unsigned char count; // expected-note {{'count' declared here}}
+  struct {
+    int dummy;
+    int array[] __counted_by(count); // expected-error {{'counted_by' field 'count' isn't within the same struct as the flexible array}}
+  } a;
+};
+
+struct not_found_count_not_in_unnamed_substruct_2 {
+  struct {
+    unsigned char count; // expected-note {{'count' declared here}}
+  };
+  struct {
+    int dummy;
+    int array[] __counted_by(count); // expected-error {{'counted_by' field 'count' isn't within the same struct as the flexible array}}
+  } a;
+};
+
+struct not_found_count_in_other_unnamed_substruct {
+  struct {
+    unsigned char count;
+  } a1;
+
+  struct {
+    int dummy;
+    int array[] __counted_by(count); // expected-error {{'counted_by' field 'count' isn't within the same struct as the flexible array}}
+  };
+};
+
+struct not_found_count_in_other_substruct {
+  struct _a1 {
+    unsigned char count; // expected-note {{'count' declared here}}
+  } a1;
+
+  struct {
+    int dummy;
+    int array[] __counted_by(count); // expected-error {{'counted_by' field 'count' isn't within the same struct as the flexible array}}
+  };
+};
+
+struct not_found_count_in_other_substruct_2 {
+  struct _a1 {
+    unsigned char count; // expected-note {{'count' declared here}}
+  } a1;
+
+  int array[] __counted_by(count); // expected-error {{'counted_by' field 'count' isn't within the same struct as the flexible array}}
+};
+
 struct not_found_suggest {
   int bork;
   struct bar *fam[] __counted_by(blork); // expected-error {{use of undeclared identifier 'blork'}}

>From c501bfd4d42b122b77f0e1535a6af1e6768b43e1 Mon Sep 17 00:00:00 2001
From: Yeoul Na <yeoul_na at apple.com>
Date: Mon, 19 Feb 2024 16:46:49 -0800
Subject: [PATCH 3/9] Fix misuse of SmallVector as parameters; fix logic to
 check anonymous structs

---
 clang/lib/Sema/SemaDeclAttr.cpp | 17 +++++++++++------
 clang/lib/Sema/SemaType.cpp     |  5 ++---
 2 files changed, 13 insertions(+), 9 deletions(-)

diff --git a/clang/lib/Sema/SemaDeclAttr.cpp b/clang/lib/Sema/SemaDeclAttr.cpp
index fdf9dea6fbcc5a..29529b6fcc01f5 100644
--- a/clang/lib/Sema/SemaDeclAttr.cpp
+++ b/clang/lib/Sema/SemaDeclAttr.cpp
@@ -8474,7 +8474,12 @@ static void handleZeroCallUsedRegsAttr(Sema &S, Decl *D, const ParsedAttr &AL) {
 
 static const RecordDecl *GetEnclosingNamedOrTopAnonRecord(const FieldDecl *FD) {
   const auto *RD = FD->getParent();
-  while (RD && (RD->isAnonymousStructOrUnion() || RD->getName().empty())) {
+  // An unnamed struct is an anonymous struct only if it's not instantiated.
+  // However, the struct may not be fully processed yet to determine
+  // whether it's anonymous or not. In that case, this function treats it as
+  // an anonymous struct and tries to find a named parent.
+  while (RD && (RD->isAnonymousStructOrUnion() ||
+                (!RD->isCompleteDefinition() && RD->getName().empty()))) {
     const auto *Parent = dyn_cast<RecordDecl>(RD->getParent());
     if (!Parent)
       break;
@@ -8485,7 +8490,7 @@ static const RecordDecl *GetEnclosingNamedOrTopAnonRecord(const FieldDecl *FD) {
 
 static bool
 CheckCountExpr(Sema &S, FieldDecl *FD, Expr *E,
-               llvm::SmallVector<TypeCoupledDeclRefInfo, 1> &Decls) {
+               llvm::SmallVectorImpl<TypeCoupledDeclRefInfo> &Decls) {
   if (FD->getParent()->isUnion()) {
     S.Diag(FD->getBeginLoc(), diag::err_counted_by_attr_in_union)
         << FD->getSourceRange();
@@ -8541,8 +8546,8 @@ CheckCountExpr(Sema &S, FieldDecl *FD, Expr *E,
       return true;
     }
     // Whether CountRD is an anonymous struct is not determined at this
-    // point. Thus, an additional check is done later in
-    // `Parser::ParseStructDeclaration`.
+    // point. Thus, an additional diagnostic, in case it is not an anonymous
+    // struct, is issued later in `Parser::ParseStructDeclaration`.
     auto *RD = GetEnclosingNamedOrTopAnonRecord(FD);
     auto *CountRD = GetEnclosingNamedOrTopAnonRecord(CountFD);
 
@@ -8573,8 +8578,8 @@ static void handleCountedByAttrField(Sema &S, Decl *D, const ParsedAttr &AL) {
   if (CheckCountExpr(S, FD, CountExpr, Decls))
     return;
 
-  QualType CAT =
-      S.BuildCountAttributedArrayType(FD->getType(), CountExpr, Decls);
+  QualType CAT = S.BuildCountAttributedArrayType(
+      FD->getType(), CountExpr, llvm::ArrayRef(Decls.begin(), Decls.end()));
   FD->setType(CAT);
 }
 
diff --git a/clang/lib/Sema/SemaType.cpp b/clang/lib/Sema/SemaType.cpp
index 5038f721a91a1d..b25be5e5584a26 100644
--- a/clang/lib/Sema/SemaType.cpp
+++ b/clang/lib/Sema/SemaType.cpp
@@ -9747,14 +9747,13 @@ QualType Sema::BuildTypeofExprType(Expr *E, TypeOfKind Kind) {
 
 QualType Sema::BuildCountAttributedArrayType(
     QualType WrappedTy, Expr *CountExpr,
-    const llvm::SmallVector<TypeCoupledDeclRefInfo, 1> &Decls) {
+    llvm::ArrayRef<TypeCoupledDeclRefInfo> Decls) {
   assert(WrappedTy->isIncompleteArrayType());
 
   /// When the resulting expression is invalid, we still create the AST using
   /// the original count expression for the sake of AST dump.
   return Context.getCountAttributedType(
-      WrappedTy, CountExpr, /*CountInBytes*/ false, /*OrNull*/ false,
-      llvm::ArrayRef(Decls.begin(), Decls.end()));
+      WrappedTy, CountExpr, /*CountInBytes*/ false, /*OrNull*/ false, Decls);
 }
 
 /// getDecltypeForExpr - Given an expr, will return the decltype for

>From 863375df3a09851a859721672802708f022fe184 Mon Sep 17 00:00:00 2001
From: Yeoul Na <yeoul_na at apple.com>
Date: Mon, 19 Feb 2024 16:50:19 -0800
Subject: [PATCH 4/9] Add comment on BoundsAttributedType::classof

---
 clang/include/clang/AST/Type.h | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/clang/include/clang/AST/Type.h b/clang/include/clang/AST/Type.h
index 3e336f524ca4f6..2454c2b011ac15 100644
--- a/clang/include/clang/AST/Type.h
+++ b/clang/include/clang/AST/Type.h
@@ -3010,6 +3010,9 @@ class BoundsAttributedType : public Type, public llvm::FoldingSetNode {
   bool referencesFieldDecls() const;
 
   static bool classof(const Type *T) {
+    // Currently, only `class CountAttributedType` inherits from
+    // `BoundsAttributedType`, but the set of subclasses will grow as we add
+    // more bounds annotations.
     switch (T->getTypeClass()) {
     case CountAttributed:
       return true;

>From 939f6d58c0b5e6dfb9fc3e76ed8ba4edd027cfcd Mon Sep 17 00:00:00 2001
From: Yeoul Na <yeoul_na at apple.com>
Date: Mon, 26 Feb 2024 13:47:17 -0800
Subject: [PATCH 5/9] Update test/Sema/attr-counted-by.c

---
 clang/test/Sema/attr-counted-by.c | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/clang/test/Sema/attr-counted-by.c b/clang/test/Sema/attr-counted-by.c
index 1b646254911cd2..d5d4ebf5573922 100644
--- a/clang/test/Sema/attr-counted-by.c
+++ b/clang/test/Sema/attr-counted-by.c
@@ -43,27 +43,27 @@ struct not_found_count_in_other_unnamed_substruct {
 
   struct {
     int dummy;
-    int array[] __counted_by(count); // expected-error {{'counted_by' field 'count' isn't within the same struct as the flexible array}}
+    int array[] __counted_by(count); // expected-error {{use of undeclared identifier 'count'}}
   };
 };
 
 struct not_found_count_in_other_substruct {
   struct _a1 {
-    unsigned char count; // expected-note {{'count' declared here}}
+    unsigned char count;
   } a1;
 
   struct {
     int dummy;
-    int array[] __counted_by(count); // expected-error {{'counted_by' field 'count' isn't within the same struct as the flexible array}}
+    int array[] __counted_by(count); // expected-error {{use of undeclared identifier 'count'}}
   };
 };
 
 struct not_found_count_in_other_substruct_2 {
-  struct _a1 {
-    unsigned char count; // expected-note {{'count' declared here}}
-  } a1;
+  struct _a2 {
+    unsigned char count;
+  } a2;
 
-  int array[] __counted_by(count); // expected-error {{'counted_by' field 'count' isn't within the same struct as the flexible array}}
+  int array[] __counted_by(count); // expected-error {{use of undeclared identifier 'count'}}
 };
 
 struct not_found_suggest {

>From 3e40ddf007483a0de205dd9f3d2d995e69191ef6 Mon Sep 17 00:00:00 2001
From: Yeoul Na <yeoul_na at apple.com>
Date: Mon, 26 Feb 2024 13:53:47 -0800
Subject: [PATCH 6/9] Handle 'attr::CountedBy' in TypePrinter

---
 clang/lib/AST/TypePrinter.cpp | 1 +
 1 file changed, 1 insertion(+)

diff --git a/clang/lib/AST/TypePrinter.cpp b/clang/lib/AST/TypePrinter.cpp
index 2c61fa05506434..ad97478ad9c418 100644
--- a/clang/lib/AST/TypePrinter.cpp
+++ b/clang/lib/AST/TypePrinter.cpp
@@ -1877,6 +1877,7 @@ void TypePrinter::printAttributedAfter(const AttributedType *T,
     // AttributedType nodes for them.
     break;
 
+  case attr::CountedBy:
   case attr::LifetimeBound:
   case attr::TypeNonNull:
   case attr::TypeNullable:

>From 0f4424ed42202eba7563a982c39799951c30d7c7 Mon Sep 17 00:00:00 2001
From: Yeoul Na <yeoul_na at apple.com>
Date: Wed, 6 Mar 2024 17:18:14 -0800
Subject: [PATCH 7/9] Implement TreeTransform for CountAttributedType; address
 other style suggestions

---
 clang/include/clang/AST/Type.h        |  3 ++-
 clang/include/clang/AST/TypeLoc.h     |  6 ++----
 clang/lib/AST/ASTContext.cpp          | 12 ++++++------
 clang/lib/AST/Type.cpp                | 10 ++++++----
 clang/lib/CodeGen/CGBuiltin.cpp       |  4 ++--
 clang/lib/CodeGen/CGExpr.cpp          |  8 ++++----
 clang/lib/Sema/SemaDeclAttr.cpp       |  3 +--
 clang/lib/Sema/SemaType.cpp           | 16 +++++++++++++---
 clang/lib/Sema/TreeTransform.h        | 25 +++++++++++++++++++++++--
 clang/lib/Serialization/ASTReader.cpp |  2 +-
 clang/lib/Serialization/ASTWriter.cpp |  2 +-
 11 files changed, 61 insertions(+), 30 deletions(-)

diff --git a/clang/include/clang/AST/Type.h b/clang/include/clang/AST/Type.h
index 2454c2b011ac15..86b58b38f5b439 100644
--- a/clang/include/clang/AST/Type.h
+++ b/clang/include/clang/AST/Type.h
@@ -2007,7 +2007,6 @@ class alignas(TypeAlignment) Type : public ExtQualsTypeCommonBase {
     LLVM_PREFERRED_TYPE(TypeBitfields)
     unsigned : NumTypeBits;
 
-    /// The limit is 15.
     static constexpr unsigned NumCoupledDeclsBits = 4;
     unsigned NumCoupledDecls : NumCoupledDeclsBits;
     LLVM_PREFERRED_TYPE(bool)
@@ -2015,6 +2014,7 @@ class alignas(TypeAlignment) Type : public ExtQualsTypeCommonBase {
     LLVM_PREFERRED_TYPE(bool)
     unsigned OrNull : 1;
   };
+  static_assert(sizeof(CountAttributedTypeBitfields) <= sizeof(unsigned));
 
   union {
     TypeBitfields TypeBits;
@@ -2279,6 +2279,7 @@ class alignas(TypeAlignment) Type : public ExtQualsTypeCommonBase {
   bool isFunctionProtoType() const { return getAs<FunctionProtoType>(); }
   bool isPointerType() const;
   bool isAnyPointerType() const;   // Any C pointer or ObjC object pointer
+  bool isCountAttributedType() const;
   bool isBlockPointerType() const;
   bool isVoidPointerType() const;
   bool isReferenceType() const;
diff --git a/clang/include/clang/AST/TypeLoc.h b/clang/include/clang/AST/TypeLoc.h
index 6415a3bb0b4e4f..861da531b9a664 100644
--- a/clang/include/clang/AST/TypeLoc.h
+++ b/clang/include/clang/AST/TypeLoc.h
@@ -1125,13 +1125,11 @@ class BoundsAttributedTypeLoc
     : public ConcreteTypeLoc<UnqualTypeLoc, BoundsAttributedTypeLoc,
                              BoundsAttributedType, BoundsAttributedLocInfo> {
 public:
-  TypeLoc getInnerLoc() const { return this->getInnerTypeLoc(); }
-  QualType getInnerType() const { return this->getTypePtr()->desugar(); }
+  TypeLoc getInnerLoc() const { return getInnerTypeLoc(); }
+  QualType getInnerType() const { return getTypePtr()->desugar(); }
   void initializeLocal(ASTContext &Context, SourceLocation Loc) {
     // nothing to do
   }
-  unsigned getLocalDataAlignment() const { return 1; }
-  unsigned getLocalDataSize() const { return 0; }
 };
 
 class CountAttributedTypeLoc final
diff --git a/clang/lib/AST/ASTContext.cpp b/clang/lib/AST/ASTContext.cpp
index 5247e29ea30204..fcf801adeaf5ef 100644
--- a/clang/lib/AST/ASTContext.cpp
+++ b/clang/lib/AST/ASTContext.cpp
@@ -13270,9 +13270,9 @@ static QualType getCommonSugarTypeNode(ASTContext &Ctx, const Type *X,
       return QualType();
     if (DX->isOrNull() != DY->isOrNull())
       return QualType();
-    const auto CEX = DX->getCountExpr();
-    const auto CEY = DY->getCountExpr();
-    const auto CDX = DX->getCoupledDecls();
+    Expr *CEX = DX->getCountExpr();
+    Expr *CEY = DY->getCountExpr();
+    llvm::ArrayRef<clang::TypeCoupledDeclRefInfo> CDX = DX->getCoupledDecls();
     if (Ctx.hasSameExpr(CEX, CEY))
       return Ctx.getCountAttributedType(Ctx.getQualifiedType(Underlying), CEX,
                                         DX->isCountInBytes(), DX->isOrNull(),
@@ -13281,9 +13281,9 @@ static QualType getCommonSugarTypeNode(ASTContext &Ctx, const Type *X,
       return QualType();
     // Two declarations with the same integer constant may still differ in their
     // expression pointers, so we need to evaluate them.
-    const auto VX = CEX->getIntegerConstantExpr(Ctx);
-    const auto VY = CEY->getIntegerConstantExpr(Ctx);
-    if (*VX != *VY)
+    llvm::APSInt VX = *CEX->getIntegerConstantExpr(Ctx);
+    llvm::APSInt VY = *CEY->getIntegerConstantExpr(Ctx);
+    if (VX != VY)
       return QualType();
     return Ctx.getCountAttributedType(Ctx.getQualifiedType(Underlying), CEX,
                                       DX->isCountInBytes(), DX->isOrNull(),
diff --git a/clang/lib/AST/Type.cpp b/clang/lib/AST/Type.cpp
index 89947f0e901323..b033cd7fa1ba44 100644
--- a/clang/lib/AST/Type.cpp
+++ b/clang/lib/AST/Type.cpp
@@ -386,10 +386,8 @@ void DependentBitIntType::Profile(llvm::FoldingSetNodeID &ID,
 }
 
 bool BoundsAttributedType::referencesFieldDecls() const {
-  for (const auto &Decl : dependent_decls())
-    if (isa<FieldDecl>(Decl.getDecl()))
-      return true;
-  return false;
+  return llvm::any_of(dependent_decls(), [](const TypeCoupledDeclRefInfo &Info) {
+    return isa<FieldDecl>(Info.getDecl()); });
 }
 
 void CountAttributedType::Profile(llvm::FoldingSetNodeID &ID,
@@ -669,6 +667,10 @@ bool Type::isScopedEnumeralType() const {
   return false;
 }
 
+bool Type::isCountAttributedType() const {
+  return getAs<CountAttributedType>();
+}
+
 const ComplexType *Type::getAsComplexIntegerType() const {
   if (const auto *Complex = getAs<ComplexType>())
     if (Complex->getElementType()->isIntegerType())
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 8a2810fea21054..25a0df880db508 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -858,7 +858,7 @@ static unsigned CountCountedByAttrs(const RecordDecl *RD) {
 
   for (const Decl *D : RD->decls()) {
     if (const auto *FD = dyn_cast<FieldDecl>(D);
-        FD && FD->getType()->getAs<CountAttributedType>()) {
+        FD && FD->getType()->isCountAttributedType()) {
       return ++Num;
     }
 
@@ -975,7 +975,7 @@ CodeGenFunction::emitFlexibleArrayMemberSize(const Expr *E, unsigned Type,
       FindFlexibleArrayMemberField(Ctx, OuterRD, FAMName, Offset);
   Offset = Ctx.toCharUnitsFromBits(Offset).getQuantity();
 
-  if (!FAMDecl || !FAMDecl->getType()->getAs<CountAttributedType>())
+  if (!FAMDecl || !FAMDecl->getType()->isCountAttributedType())
     // No flexible array member found or it doesn't have the "counted_by"
     // attribute.
     return nullptr;
diff --git a/clang/lib/CodeGen/CGExpr.cpp b/clang/lib/CodeGen/CGExpr.cpp
index 08d5f935917908..85f5d739cef457 100644
--- a/clang/lib/CodeGen/CGExpr.cpp
+++ b/clang/lib/CodeGen/CGExpr.cpp
@@ -1139,12 +1139,12 @@ const FieldDecl *CodeGenFunction::FindCountedByField(const FieldDecl *FD) {
   if (!FD)
     return nullptr;
 
-  auto *CAT = FD->getType()->getAs<CountAttributedType>();
+  const auto *CAT = FD->getType()->getAs<CountAttributedType>();
   if (!CAT)
     return nullptr;
 
-  auto *CountDRE = cast<DeclRefExpr>(CAT->getCountExpr());
-  auto *CountDecl = CountDRE->getDecl();
+  const auto *CountDRE = cast<DeclRefExpr>(CAT->getCountExpr());
+  const auto *CountDecl = CountDRE->getDecl();
   if (const auto *IFD = dyn_cast<IndirectFieldDecl>(CountDecl))
     CountDecl = IFD->getAnonField();
 
@@ -4240,7 +4240,7 @@ LValue CodeGenFunction::EmitArraySubscriptExpr(const ArraySubscriptExpr *E,
       if (const auto *ME = dyn_cast<MemberExpr>(Array);
           ME &&
           ME->isFlexibleArrayMemberLike(getContext(), StrictFlexArraysLevel) &&
-          ME->getMemberDecl()->getType()->getAs<CountAttributedType>()) {
+          ME->getMemberDecl()->getType()->isCountAttributedType()) {
         const FieldDecl *FAMDecl = dyn_cast<FieldDecl>(ME->getMemberDecl());
         if (const FieldDecl *CountFD = FindCountedByField(FAMDecl)) {
           if (std::optional<int64_t> Diff =
diff --git a/clang/lib/Sema/SemaDeclAttr.cpp b/clang/lib/Sema/SemaDeclAttr.cpp
index 29529b6fcc01f5..3198f52a15b80d 100644
--- a/clang/lib/Sema/SemaDeclAttr.cpp
+++ b/clang/lib/Sema/SemaDeclAttr.cpp
@@ -8578,8 +8578,7 @@ static void handleCountedByAttrField(Sema &S, Decl *D, const ParsedAttr &AL) {
   if (CheckCountExpr(S, FD, CountExpr, Decls))
     return;
 
-  QualType CAT = S.BuildCountAttributedArrayType(
-      FD->getType(), CountExpr, llvm::ArrayRef(Decls.begin(), Decls.end()));
+  QualType CAT = S.BuildCountAttributedArrayType(FD->getType(), CountExpr);
   FD->setType(CAT);
 }
 
diff --git a/clang/lib/Sema/SemaType.cpp b/clang/lib/Sema/SemaType.cpp
index b25be5e5584a26..467a5d7f91faf9 100644
--- a/clang/lib/Sema/SemaType.cpp
+++ b/clang/lib/Sema/SemaType.cpp
@@ -6494,6 +6494,9 @@ namespace {
     void VisitAttributedTypeLoc(AttributedTypeLoc TL) {
       fillAttributedTypeLoc(TL, State);
     }
+    void VisitCountAttributedTypeLoc(CountAttributedTypeLoc TL) {
+      // nothing
+    }
     void VisitBTFTagAttributedTypeLoc(BTFTagAttributedTypeLoc TL) {
       // nothing
     }
@@ -9745,11 +9748,18 @@ QualType Sema::BuildTypeofExprType(Expr *E, TypeOfKind Kind) {
   return Context.getTypeOfExprType(E, Kind);
 }
 
-QualType Sema::BuildCountAttributedArrayType(
-    QualType WrappedTy, Expr *CountExpr,
-    llvm::ArrayRef<TypeCoupledDeclRefInfo> Decls) {
+static void BuildTypeCoupledDecls(Expr *E,
+                                  llvm::SmallVectorImpl<TypeCoupledDeclRefInfo> &Decls) {
+  // Currently, 'counted_by' only allows direct DeclRefExpr to FieldDecl.
+  auto *CountDecl = cast<DeclRefExpr>(E)->getDecl();
+  Decls.push_back(TypeCoupledDeclRefInfo(CountDecl, /*IsDref*/ false));
+}
+
+QualType Sema::BuildCountAttributedArrayType(QualType WrappedTy, Expr *CountExpr) {
   assert(WrappedTy->isIncompleteArrayType());
 
+  llvm::SmallVector<TypeCoupledDeclRefInfo, 1> Decls;
+  BuildTypeCoupledDecls(CountExpr, Decls);
   /// When the resulting expression is invalid, we still create the AST using
   /// the original count expression for the sake of AST dump.
   return Context.getCountAttributedType(
diff --git a/clang/lib/Sema/TreeTransform.h b/clang/lib/Sema/TreeTransform.h
index f474c61b3b64ab..e1246bb48a69ee 100644
--- a/clang/lib/Sema/TreeTransform.h
+++ b/clang/lib/Sema/TreeTransform.h
@@ -7264,8 +7264,29 @@ QualType TreeTransform<Derived>::TransformAttributedType(TypeLocBuilder &TLB,
 template <typename Derived>
 QualType TreeTransform<Derived>::TransformCountAttributedType(
     TypeLocBuilder &TLB, CountAttributedTypeLoc TL) {
-  // TODO
-  llvm_unreachable("Unexpected TreeTransform for CountAttributedType");
+  const CountAttributedType *OldTy = TL.getTypePtr();
+  QualType InnerTy = getDerived().TransformType(TLB, TL.getInnerLoc());
+  if (InnerTy.isNull())
+    return QualType();
+
+  Expr *OldCount = TL.getCountExpr();
+  Expr *NewCount = nullptr;
+  if (OldCount) {
+    ExprResult CountResult = getDerived().TransformExpr(OldCount);
+    if (CountResult.isInvalid())
+      return QualType();
+    NewCount = CountResult.get();
+  }
+
+  QualType Result = TL.getType();
+  if (getDerived().AlwaysRebuild() || InnerTy != OldTy->desugar() ||
+      OldCount != NewCount) {
+    // Currently, CountAttributedType can only wrap incomplete array types.
+    Result = SemaRef.BuildCountAttributedArrayType(InnerTy, NewCount);
+  }
+
+  TLB.push<CountAttributedTypeLoc>(Result);
+  return Result;
 }
 
 template <typename Derived>
diff --git a/clang/lib/Serialization/ASTReader.cpp b/clang/lib/Serialization/ASTReader.cpp
index b52149d63ab819..575d535da3d6de 100644
--- a/clang/lib/Serialization/ASTReader.cpp
+++ b/clang/lib/Serialization/ASTReader.cpp
@@ -6996,7 +6996,7 @@ void TypeLocReader::VisitAttributedTypeLoc(AttributedTypeLoc TL) {
 }
 
 void TypeLocReader::VisitCountAttributedTypeLoc(CountAttributedTypeLoc TL) {
-  // nothing to do
+  // Nothing to do
 }
 
 void TypeLocReader::VisitBTFTagAttributedTypeLoc(BTFTagAttributedTypeLoc TL) {
diff --git a/clang/lib/Serialization/ASTWriter.cpp b/clang/lib/Serialization/ASTWriter.cpp
index f00959ab4b2a70..1221087129a4c8 100644
--- a/clang/lib/Serialization/ASTWriter.cpp
+++ b/clang/lib/Serialization/ASTWriter.cpp
@@ -515,7 +515,7 @@ void TypeLocWriter::VisitAttributedTypeLoc(AttributedTypeLoc TL) {
 }
 
 void TypeLocWriter::VisitCountAttributedTypeLoc(CountAttributedTypeLoc TL) {
-  // nothing to do
+  // Nothing to do
 }
 
 void TypeLocWriter::VisitBTFTagAttributedTypeLoc(BTFTagAttributedTypeLoc TL) {

>From ce6480414c9effd307dd4bd657e14bd2d14dac97 Mon Sep 17 00:00:00 2001
From: Yeoul Na <yeoul_na at apple.com>
Date: Wed, 6 Mar 2024 17:27:17 -0800
Subject: [PATCH 8/9] clang-format

---
 clang/lib/AST/Type.cpp      | 6 ++++--
 clang/lib/Sema/SemaType.cpp | 8 +++++---
 2 files changed, 9 insertions(+), 5 deletions(-)

diff --git a/clang/lib/AST/Type.cpp b/clang/lib/AST/Type.cpp
index b033cd7fa1ba44..8edb76149ea5b4 100644
--- a/clang/lib/AST/Type.cpp
+++ b/clang/lib/AST/Type.cpp
@@ -386,8 +386,10 @@ void DependentBitIntType::Profile(llvm::FoldingSetNodeID &ID,
 }
 
 bool BoundsAttributedType::referencesFieldDecls() const {
-  return llvm::any_of(dependent_decls(), [](const TypeCoupledDeclRefInfo &Info) {
-    return isa<FieldDecl>(Info.getDecl()); });
+  return llvm::any_of(dependent_decls(),
+                      [](const TypeCoupledDeclRefInfo &Info) {
+                        return isa<FieldDecl>(Info.getDecl());
+                      });
 }
 
 void CountAttributedType::Profile(llvm::FoldingSetNodeID &ID,
diff --git a/clang/lib/Sema/SemaType.cpp b/clang/lib/Sema/SemaType.cpp
index 467a5d7f91faf9..d28e4b19b56b20 100644
--- a/clang/lib/Sema/SemaType.cpp
+++ b/clang/lib/Sema/SemaType.cpp
@@ -9748,14 +9748,16 @@ QualType Sema::BuildTypeofExprType(Expr *E, TypeOfKind Kind) {
   return Context.getTypeOfExprType(E, Kind);
 }
 
-static void BuildTypeCoupledDecls(Expr *E,
-                                  llvm::SmallVectorImpl<TypeCoupledDeclRefInfo> &Decls) {
+static void
+BuildTypeCoupledDecls(Expr *E,
+                      llvm::SmallVectorImpl<TypeCoupledDeclRefInfo> &Decls) {
   // Currently, 'counted_by' only allows direct DeclRefExpr to FieldDecl.
   auto *CountDecl = cast<DeclRefExpr>(E)->getDecl();
   Decls.push_back(TypeCoupledDeclRefInfo(CountDecl, /*IsDref*/ false));
 }
 
-QualType Sema::BuildCountAttributedArrayType(QualType WrappedTy, Expr *CountExpr) {
+QualType Sema::BuildCountAttributedArrayType(QualType WrappedTy,
+                                             Expr *CountExpr) {
   assert(WrappedTy->isIncompleteArrayType());
 
   llvm::SmallVector<TypeCoupledDeclRefInfo, 1> Decls;

>From 8f184166b2e1bb992120fb48d28624d1baa0495c Mon Sep 17 00:00:00 2001
From: Yeoul Na <yeoul_na at apple.com>
Date: Thu, 7 Mar 2024 10:04:09 -0800
Subject: [PATCH 9/9] Reapply Sema.h changes after rebase

---
 clang/include/clang/Sema/Sema.h | 11 +++++++++--
 1 file changed, 9 insertions(+), 2 deletions(-)

diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h
index 592c7871a4a55d..a2d136eef0b80b 100644
--- a/clang/include/clang/Sema/Sema.h
+++ b/clang/include/clang/Sema/Sema.h
@@ -3918,8 +3918,6 @@ class Sema final {
                                             bool BestCase,
                                             MSInheritanceModel Model);
 
-  bool CheckCountedByAttr(Scope *Scope, const FieldDecl *FD);
-
   EnforceTCBAttr *mergeEnforceTCBAttr(Decl *D, const EnforceTCBAttr &AL);
   EnforceTCBLeafAttr *mergeEnforceTCBLeafAttr(Decl *D,
                                               const EnforceTCBLeafAttr &AL);
@@ -5232,6 +5230,7 @@ class Sema final {
     enum ExpressionKind {
       EK_Decltype,
       EK_TemplateArgument,
+      EK_BoundsAttrArgument,
       EK_Other
     } ExprContext;
 
@@ -5349,6 +5348,12 @@ class Sema final {
     return ExprEvalContexts.back();
   };
 
+  bool isBoundsAttrContext() const {
+    return ExprEvalContexts.back().ExprContext ==
+           ExpressionEvaluationContextRecord::ExpressionKind::
+               EK_BoundsAttrArgument;
+  }
+
   /// Increment when we find a reference; decrement when we find an ignored
   /// assignment.  Ultimately the value is 0 if every reference is an ignored
   /// assignment.
@@ -11715,6 +11720,8 @@ class Sema final {
   QualType BuildMatrixType(QualType T, Expr *NumRows, Expr *NumColumns,
                            SourceLocation AttrLoc);
 
+  QualType BuildCountAttributedArrayType(QualType WrappedTy, Expr *CountExpr);
+
   QualType BuildAddressSpaceAttr(QualType &T, LangAS ASIdx, Expr *AddrSpace,
                                  SourceLocation AttrLoc);
 



More information about the cfe-commits mailing list