[clang] [clang/AST] Make it possible to use SwiftAttr in type context (PR #108631)

Pavel Yaskevich via cfe-commits cfe-commits at lists.llvm.org
Wed Oct 30 14:10:31 PDT 2024


https://github.com/xedin updated https://github.com/llvm/llvm-project/pull/108631

>From a66d820647c927dee23498cd12338748f4079688 Mon Sep 17 00:00:00 2001
From: Pavel Yaskevich <xedin at apache.org>
Date: Wed, 20 Dec 2023 14:04:22 -0800
Subject: [PATCH] [clang/AST] Make it possible to use SwiftAttr in type context

Swift ClangImporter now supports concurrency annotations on imported declarations
and their parameters/results. To make it possible to use imported APIs in Swift
safely, there has to be a way to annotate individual parameters and result types
with relevant attributes that indicate that e.g. a block is called on a particular
actor or that it accepts a `Sendable` parameter.

To facilitate that, `SwiftAttr` is switched from `InheritableAttr`, which is a declaration
attribute, to `DeclOrTypeAttr`. To support this attribute in type context we need access
to its "Attribute" argument, which requires `AttributedType` to be extended to include
`Attr *` when available instead of just `attr::Kind`; otherwise it won't be possible to
determine what attribute should be imported.
---
 clang/docs/ReleaseNotes.rst                   | 13 ++++
 clang/include/clang/AST/ASTContext.h          |  7 ++
 clang/include/clang/AST/PropertiesBase.td     |  1 +
 clang/include/clang/AST/Type.h                | 42 +++++-------
 clang/include/clang/AST/TypeProperties.td     |  8 ++-
 clang/include/clang/Basic/Attr.td             |  2 +-
 clang/include/clang/Basic/AttrDocs.td         |  4 +-
 .../clang/Serialization/ASTRecordWriter.h     |  2 +
 clang/lib/AST/ASTContext.cpp                  | 48 ++++++++++---
 clang/lib/AST/ASTDiagnostic.cpp               |  6 +-
 clang/lib/AST/ASTImporter.cpp                 |  5 +-
 clang/lib/AST/Type.cpp                        | 20 +++++-
 clang/lib/AST/TypePrinter.cpp                 |  9 +++
 clang/lib/Sema/SemaDecl.cpp                   |  4 +-
 clang/lib/Sema/SemaDeclObjC.cpp               |  4 +-
 clang/lib/Sema/SemaExpr.cpp                   |  3 +-
 clang/lib/Sema/SemaExprObjC.cpp               | 29 +++-----
 clang/lib/Sema/SemaObjCProperty.cpp           |  4 +-
 clang/lib/Sema/SemaSwift.cpp                  |  7 +-
 clang/lib/Sema/SemaType.cpp                   | 67 +++++++++++++++++--
 clang/lib/Sema/TreeTransform.h                |  3 +-
 clang/test/AST/attr-swift_attr.m              | 47 ++++++++++++-
 .../test/SemaObjC/validate-attr-swift_attr.m  |  4 ++
 23 files changed, 255 insertions(+), 84 deletions(-)

diff --git a/clang/docs/ReleaseNotes.rst b/clang/docs/ReleaseNotes.rst
index 1a179e63f902f3..5d4d5b2319688f 100644
--- a/clang/docs/ReleaseNotes.rst
+++ b/clang/docs/ReleaseNotes.rst
@@ -338,6 +338,19 @@ Removed Compiler Flags
 Attribute Changes in Clang
 --------------------------
 
+- The ``swift_attr`` attribute can now be applied to types. To make it possible to use imported
+  APIs in Swift safely, there has to be a way to annotate individual parameters and result types
+  with relevant attributes that indicate that e.g. a block is called on a particular actor
+  or that it accepts a Sendable or global-actor (e.g. ``@MainActor``) isolated parameter.
+
+  For example:
+
+  .. code-block:: objc
+
+     @interface MyService
+       -(void) handle: (void (^ __attribute__((swift_attr("@Sendable"))))(id)) handler;
+     @end
+
 - Clang now disallows more than one ``__attribute__((ownership_returns(class, idx)))`` with
   different class names attached to one function.
 
diff --git a/clang/include/clang/AST/ASTContext.h b/clang/include/clang/AST/ASTContext.h
index 07b4e36f3ef05e..4c1455a3e1bbf2 100644
--- a/clang/include/clang/AST/ASTContext.h
+++ b/clang/include/clang/AST/ASTContext.h
@@ -1719,8 +1719,15 @@ class ASTContext : public RefCountedBase<ASTContext> {
   QualType getInjectedClassNameType(CXXRecordDecl *Decl, QualType TST) const;
 
   QualType getAttributedType(attr::Kind attrKind, QualType modifiedType,
+                             QualType equivalentType,
+                             const Attr *attr = nullptr) const;
+
+  QualType getAttributedType(const Attr *attr, QualType modifiedType,
                              QualType equivalentType) const;
 
+  QualType getAttributedType(NullabilityKind nullability, QualType modifiedType,
+                             QualType equivalentType);
+
   QualType getBTFTagAttributedType(const BTFTypeTagAttr *BTFAttr,
                                    QualType Wrapped) const;
 
diff --git a/clang/include/clang/AST/PropertiesBase.td b/clang/include/clang/AST/PropertiesBase.td
index 3057669e3758b5..5f3a885832e2e4 100644
--- a/clang/include/clang/AST/PropertiesBase.td
+++ b/clang/include/clang/AST/PropertiesBase.td
@@ -76,6 +76,7 @@ def APValue : PropertyType { let PassByReference = 1; }
 def APValueKind : EnumPropertyType<"APValue::ValueKind">;
 def ArraySizeModifier : EnumPropertyType<"ArraySizeModifier">;
 def AttrKind : EnumPropertyType<"attr::Kind">;
+def Attr : PropertyType<"const Attr *">;
 def AutoTypeKeyword : EnumPropertyType;
 def Bool : PropertyType<"bool">;
 def BuiltinTypeKind : EnumPropertyType<"BuiltinType::Kind">;
diff --git a/clang/include/clang/AST/Type.h b/clang/include/clang/AST/Type.h
index ba3161c366f4d9..1bcc7ee0b70dee 100644
--- a/clang/include/clang/AST/Type.h
+++ b/clang/include/clang/AST/Type.h
@@ -69,6 +69,7 @@ class ValueDecl;
 class TagDecl;
 class TemplateParameterList;
 class Type;
+class Attr;
 
 enum {
   TypeAlignmentInBits = 4,
@@ -6130,21 +6131,29 @@ class AttributedType : public Type, public llvm::FoldingSetNode {
 private:
   friend class ASTContext; // ASTContext creates these
 
+  const Attr *Attribute;
+
   QualType ModifiedType;
   QualType EquivalentType;
 
   AttributedType(QualType canon, attr::Kind attrKind, QualType modified,
                  QualType equivalent)
-      : Type(Attributed, canon, equivalent->getDependence()),
-        ModifiedType(modified), EquivalentType(equivalent) {
-    AttributedTypeBits.AttrKind = attrKind;
-  }
+      : AttributedType(canon, attrKind, nullptr, modified, equivalent) {}
+
+  AttributedType(QualType canon, const Attr *attr, QualType modified,
+                 QualType equivalent);
+
+private:
+  AttributedType(QualType canon, attr::Kind attrKind, const Attr *attr,
+                 QualType modified, QualType equivalent);
 
 public:
   Kind getAttrKind() const {
     return static_cast<Kind>(AttributedTypeBits.AttrKind);
   }
 
+  const Attr *getAttr() const { return Attribute; }
+
   QualType getModifiedType() const { return ModifiedType; }
   QualType getEquivalentType() const { return EquivalentType; }
 
@@ -6176,25 +6185,6 @@ class AttributedType : public Type, public llvm::FoldingSetNode {
 
   std::optional<NullabilityKind> getImmediateNullability() const;
 
-  /// Retrieve the attribute kind corresponding to the given
-  /// nullability kind.
-  static Kind getNullabilityAttrKind(NullabilityKind kind) {
-    switch (kind) {
-    case NullabilityKind::NonNull:
-      return attr::TypeNonNull;
-
-    case NullabilityKind::Nullable:
-      return attr::TypeNullable;
-
-    case NullabilityKind::NullableResult:
-      return attr::TypeNullableResult;
-
-    case NullabilityKind::Unspecified:
-      return attr::TypeNullUnspecified;
-    }
-    llvm_unreachable("Unknown nullability kind.");
-  }
-
   /// Strip off the top-level nullability annotation on the given
   /// type, if it's there.
   ///
@@ -6207,14 +6197,16 @@ class AttributedType : public Type, public llvm::FoldingSetNode {
   static std::optional<NullabilityKind> stripOuterNullability(QualType &T);
 
   void Profile(llvm::FoldingSetNodeID &ID) {
-    Profile(ID, getAttrKind(), ModifiedType, EquivalentType);
+    Profile(ID, getAttrKind(), ModifiedType, EquivalentType, Attribute);
   }
 
   static void Profile(llvm::FoldingSetNodeID &ID, Kind attrKind,
-                      QualType modified, QualType equivalent) {
+                      QualType modified, QualType equivalent,
+                      const Attr *attr) {
     ID.AddInteger(attrKind);
     ID.AddPointer(modified.getAsOpaquePtr());
     ID.AddPointer(equivalent.getAsOpaquePtr());
+    ID.AddPointer(attr);
   }
 
   static bool classof(const Type *T) {
diff --git a/clang/include/clang/AST/TypeProperties.td b/clang/include/clang/AST/TypeProperties.td
index d05072607e949c..42f62695963a2d 100644
--- a/clang/include/clang/AST/TypeProperties.td
+++ b/clang/include/clang/AST/TypeProperties.td
@@ -668,12 +668,16 @@ let Class = AttributedType in {
   def : Property<"equivalentType", QualType> {
     let Read = [{ node->getEquivalentType() }];
   }
-  def : Property<"attribute", AttrKind> {
+  def : Property<"attrKind", AttrKind> {
     let Read = [{ node->getAttrKind() }];
   }
+  def : Property<"attribute", Attr> {
+    let Read = [{ node->getAttr() }];
+  }
 
   def : Creator<[{
-    return ctx.getAttributedType(attribute, modifiedType, equivalentType);
+    return ctx.getAttributedType(attrKind, modifiedType,
+                                 equivalentType, attribute);
   }]>;
 }
 
diff --git a/clang/include/clang/Basic/Attr.td b/clang/include/clang/Basic/Attr.td
index 47c93b48175fc8..156fbd1c4442eb 100644
--- a/clang/include/clang/Basic/Attr.td
+++ b/clang/include/clang/Basic/Attr.td
@@ -2838,7 +2838,7 @@ def SwiftAsyncName : InheritableAttr {
   let Documentation = [SwiftAsyncNameDocs];
 }
 
-def SwiftAttr : InheritableAttr {
+def SwiftAttr : DeclOrTypeAttr {
   let Spellings = [GNU<"swift_attr">];
   let Args = [StringArgument<"Attribute">];
   let Documentation = [SwiftAttrDocs];
diff --git a/clang/include/clang/Basic/AttrDocs.td b/clang/include/clang/Basic/AttrDocs.td
index fbbfc4acdf391e..b497cce37625c9 100644
--- a/clang/include/clang/Basic/AttrDocs.td
+++ b/clang/include/clang/Basic/AttrDocs.td
@@ -4507,8 +4507,8 @@ def SwiftAttrDocs : Documentation {
   let Heading = "swift_attr";
   let Content = [{
 The ``swift_attr`` provides a Swift-specific annotation for the declaration
-to which the attribute appertains to. It can be used on any declaration
-in Clang. This kind of annotation is ignored by Clang as it doesn't have any
+or type to which the attribute appertains. It can be used on any declaration
+or type in Clang. This kind of annotation is ignored by Clang as it doesn't have any
 semantic meaning in languages supported by Clang. The Swift compiler can
 interpret these annotations according to its own rules when importing C or
 Objective-C declarations.
diff --git a/clang/include/clang/Serialization/ASTRecordWriter.h b/clang/include/clang/Serialization/ASTRecordWriter.h
index 0c8ac75fc40f40..d6090ba1a6c690 100644
--- a/clang/include/clang/Serialization/ASTRecordWriter.h
+++ b/clang/include/clang/Serialization/ASTRecordWriter.h
@@ -128,6 +128,8 @@ class ASTRecordWriter
     AddStmt(const_cast<Stmt*>(S));
   }
 
+  void writeAttr(const Attr *A) { AddAttr(A); }
+
   /// Write an BTFTypeTagAttr object.
   void writeBTFTypeTagAttr(const BTFTypeTagAttr *A) { AddAttr(A); }
 
diff --git a/clang/lib/AST/ASTContext.cpp b/clang/lib/AST/ASTContext.cpp
index 1c3f771f417ccf..d248084666d1b0 100644
--- a/clang/lib/AST/ASTContext.cpp
+++ b/clang/lib/AST/ASTContext.cpp
@@ -3552,7 +3552,8 @@ ASTContext::adjustType(QualType Orig,
     const auto *AT = dyn_cast<AttributedType>(Orig);
     return getAttributedType(AT->getAttrKind(),
                              adjustType(AT->getModifiedType(), Adjust),
-                             adjustType(AT->getEquivalentType(), Adjust));
+                             adjustType(AT->getEquivalentType(), Adjust),
+                             AT->getAttr());
   }
 
   case Type::BTFTagAttributed: {
@@ -5197,17 +5198,20 @@ QualType ASTContext::getUnresolvedUsingType(
 
 QualType ASTContext::getAttributedType(attr::Kind attrKind,
                                        QualType modifiedType,
-                                       QualType equivalentType) const {
+                                       QualType equivalentType,
+                                       const Attr *attr) const {
   llvm::FoldingSetNodeID id;
-  AttributedType::Profile(id, attrKind, modifiedType, equivalentType);
+  AttributedType::Profile(id, attrKind, modifiedType, equivalentType, attr);
 
   void *insertPos = nullptr;
   AttributedType *type = AttributedTypes.FindNodeOrInsertPos(id, insertPos);
   if (type) return QualType(type, 0);
 
+  assert(!attr || attr->getKind() == attrKind);
+
   QualType canon = getCanonicalType(equivalentType);
-  type = new (*this, alignof(AttributedType))
-      AttributedType(canon, attrKind, modifiedType, equivalentType);
+	type = new (*this, alignof(AttributedType))
+      AttributedType(canon, attrKind, attr, modifiedType, equivalentType);
 
   Types.push_back(type);
   AttributedTypes.InsertNode(type, insertPos);
@@ -5215,6 +5219,33 @@ QualType ASTContext::getAttributedType(attr::Kind attrKind,
   return QualType(type, 0);
 }
 
+QualType ASTContext::getAttributedType(const Attr *attr, QualType modifiedType,
+                                       QualType equivalentType) const {
+  return getAttributedType(attr->getKind(), modifiedType, equivalentType, attr);
+}
+
+QualType ASTContext::getAttributedType(NullabilityKind nullability,
+                                       QualType modifiedType,
+                                       QualType equivalentType) {
+  switch (nullability) {
+  case NullabilityKind::NonNull:
+    return getAttributedType(attr::TypeNonNull, modifiedType, equivalentType);
+
+  case NullabilityKind::Nullable:
+    return getAttributedType(attr::TypeNullable, modifiedType, equivalentType);
+
+  case NullabilityKind::NullableResult:
+    return getAttributedType(attr::TypeNullableResult, modifiedType,
+                             equivalentType);
+
+  case NullabilityKind::Unspecified:
+    return getAttributedType(attr::TypeNullUnspecified, modifiedType,
+                             equivalentType);
+  }
+
+  llvm_unreachable("Unknown nullability kind");
+}
+
 QualType ASTContext::getBTFTagAttributedType(const BTFTypeTagAttr *BTFAttr,
                                              QualType Wrapped) const {
   llvm::FoldingSetNodeID ID;
@@ -7537,8 +7568,8 @@ QualType ASTContext::getArrayDecayedType(QualType Ty) const {
 
   // int x[_Nullable] -> int * _Nullable
   if (auto Nullability = Ty->getNullability()) {
-    Result = const_cast<ASTContext *>(this)->getAttributedType(
-        AttributedType::getNullabilityAttrKind(*Nullability), Result, Result);
+    Result = const_cast<ASTContext *>(this)->getAttributedType(*Nullability,
+                                                               Result, Result);
   }
   return Result;
 }
@@ -13773,7 +13804,8 @@ static QualType getCommonSugarTypeNode(ASTContext &Ctx, const Type *X,
       return QualType();
     // FIXME: It's inefficient to have to unify the modified types.
     return Ctx.getAttributedType(Kind, Ctx.getCommonSugaredType(MX, MY),
-                                 Ctx.getQualifiedType(Underlying));
+                                 Ctx.getQualifiedType(Underlying),
+                                 AX->getAttr());
   }
   case Type::BTFTagAttributed: {
     const auto *BX = cast<BTFTagAttributedType>(X);
diff --git a/clang/lib/AST/ASTDiagnostic.cpp b/clang/lib/AST/ASTDiagnostic.cpp
index 15c3efe4212719..4f677b60e60dae 100644
--- a/clang/lib/AST/ASTDiagnostic.cpp
+++ b/clang/lib/AST/ASTDiagnostic.cpp
@@ -85,8 +85,7 @@ QualType clang::desugarForDiagnostic(ASTContext &Context, QualType QT,
       QualType SugarRT = FT->getReturnType();
       QualType RT = desugarForDiagnostic(Context, SugarRT, DesugarReturn);
       if (auto nullability = AttributedType::stripOuterNullability(SugarRT)) {
-        RT = Context.getAttributedType(
-            AttributedType::getNullabilityAttrKind(*nullability), RT, RT);
+        RT = Context.getAttributedType(*nullability, RT, RT);
       }
 
       bool DesugarArgument = false;
@@ -97,8 +96,7 @@ QualType clang::desugarForDiagnostic(ASTContext &Context, QualType QT,
           QualType PT = desugarForDiagnostic(Context, SugarPT, DesugarArgument);
           if (auto nullability =
                   AttributedType::stripOuterNullability(SugarPT)) {
-            PT = Context.getAttributedType(
-                AttributedType::getNullabilityAttrKind(*nullability), PT, PT);
+            PT = Context.getAttributedType(*nullability, PT, PT);
           }
           Args.push_back(PT);
         }
diff --git a/clang/lib/AST/ASTImporter.cpp b/clang/lib/AST/ASTImporter.cpp
index e7a6509167f0a0..6e31df691fa104 100644
--- a/clang/lib/AST/ASTImporter.cpp
+++ b/clang/lib/AST/ASTImporter.cpp
@@ -1580,8 +1580,9 @@ ExpectedType ASTNodeImporter::VisitAttributedType(const AttributedType *T) {
   if (!ToEquivalentTypeOrErr)
     return ToEquivalentTypeOrErr.takeError();
 
-  return Importer.getToContext().getAttributedType(T->getAttrKind(),
-      *ToModifiedTypeOrErr, *ToEquivalentTypeOrErr);
+  return Importer.getToContext().getAttributedType(
+      T->getAttrKind(), *ToModifiedTypeOrErr, *ToEquivalentTypeOrErr,
+      T->getAttr());
 }
 
 ExpectedType
diff --git a/clang/lib/AST/Type.cpp b/clang/lib/AST/Type.cpp
index 113d4a100528f8..229721aeae8114 100644
--- a/clang/lib/AST/Type.cpp
+++ b/clang/lib/AST/Type.cpp
@@ -1241,8 +1241,8 @@ struct SimpleTransformVisitor : public TypeVisitor<Derived, QualType> {
           == T->getEquivalentType().getAsOpaquePtr())
       return QualType(T, 0);
 
-    return Ctx.getAttributedType(T->getAttrKind(), modifiedType,
-                                 equivalentType);
+    return Ctx.getAttributedType(T->getAttrKind(), modifiedType, equivalentType,
+                                 T->getAttr());
   }
 
   QualType VisitSubstTemplateTypeParmType(const SubstTemplateTypeParmType *T) {
@@ -1545,7 +1545,8 @@ struct SubstObjCTypeArgsVisitor
 
     // Rebuild the attributed type.
     return Ctx.getAttributedType(newAttrType->getAttrKind(),
-                                 newAttrType->getModifiedType(), newEquivType);
+                                 newAttrType->getModifiedType(), newEquivType,
+                                 newAttrType->getAttr());
   }
 };
 
@@ -4115,6 +4116,19 @@ bool RecordType::hasConstFields() const {
   return false;
 }
 
+AttributedType::AttributedType(QualType canon, const Attr *attr,
+                               QualType modified, QualType equivalent)
+    : AttributedType(canon, attr->getKind(), attr, modified, equivalent) {}
+
+AttributedType::AttributedType(QualType canon, attr::Kind attrKind,
+                               const Attr *attr, QualType modified,
+                               QualType equivalent)
+    : Type(Attributed, canon, equivalent->getDependence()), Attribute(attr),
+      ModifiedType(modified), EquivalentType(equivalent) {
+  AttributedTypeBits.AttrKind = attrKind;
+  assert(!attr || attr->getKind() == attrKind);
+}
+
 bool AttributedType::isQualifier() const {
   // FIXME: Generate this with TableGen.
   switch (getAttrKind()) {
diff --git a/clang/lib/AST/TypePrinter.cpp b/clang/lib/AST/TypePrinter.cpp
index 008e87e7e5c14b..6d8db5cf4ffd22 100644
--- a/clang/lib/AST/TypePrinter.cpp
+++ b/clang/lib/AST/TypePrinter.cpp
@@ -1934,6 +1934,14 @@ void TypePrinter::printAttributedAfter(const AttributedType *T,
     return;
   }
 
+  if (T->getAttrKind() == attr::SwiftAttr) {
+    if (auto *swiftAttr = dyn_cast_or_null<SwiftAttrAttr>(T->getAttr())) {
+      OS << " __attribute__((swift_attr(\"" << swiftAttr->getAttribute()
+         << "\")))";
+    }
+    return;
+  }
+
   OS << " __attribute__((";
   switch (T->getAttrKind()) {
 #define TYPE_ATTR(NAME)
@@ -1994,6 +2002,7 @@ void TypePrinter::printAttributedAfter(const AttributedType *T,
   case attr::NonAllocating:
   case attr::Blocking:
   case attr::Allocating:
+  case attr::SwiftAttr:
     llvm_unreachable("This attribute should have been handled already");
 
   case attr::NSReturnsRetained:
diff --git a/clang/lib/Sema/SemaDecl.cpp b/clang/lib/Sema/SemaDecl.cpp
index f8e5f3c6d309d6..c56883a80c1c55 100644
--- a/clang/lib/Sema/SemaDecl.cpp
+++ b/clang/lib/Sema/SemaDecl.cpp
@@ -3332,9 +3332,7 @@ static void mergeParamDeclTypes(ParmVarDecl *NewParam,
       }
     } else {
       QualType NewT = NewParam->getType();
-      NewT = S.Context.getAttributedType(
-                         AttributedType::getNullabilityAttrKind(*Oldnullability),
-                         NewT, NewT);
+      NewT = S.Context.getAttributedType(*Oldnullability, NewT, NewT);
       NewParam->setType(NewT);
     }
   }
diff --git a/clang/lib/Sema/SemaDeclObjC.cpp b/clang/lib/Sema/SemaDeclObjC.cpp
index 78acfeddb78639..abda22c871b07a 100644
--- a/clang/lib/Sema/SemaDeclObjC.cpp
+++ b/clang/lib/Sema/SemaDeclObjC.cpp
@@ -4572,9 +4572,7 @@ static QualType mergeTypeNullabilityForRedecl(Sema &S, SourceLocation loc,
     return type;
 
   // Otherwise, provide the result with the same nullability.
-  return S.Context.getAttributedType(
-           AttributedType::getNullabilityAttrKind(*prevNullability),
-           type, type);
+  return S.Context.getAttributedType(*prevNullability, type, type);
 }
 
 /// Merge information from the declaration of a method in the \@interface
diff --git a/clang/lib/Sema/SemaExpr.cpp b/clang/lib/Sema/SemaExpr.cpp
index ff6616901016ab..7f3cff1054aeed 100644
--- a/clang/lib/Sema/SemaExpr.cpp
+++ b/clang/lib/Sema/SemaExpr.cpp
@@ -8757,8 +8757,7 @@ static QualType computeConditionalNullability(QualType ResTy, bool IsBin,
     ResTy = ResTy.getSingleStepDesugaredType(Ctx);
 
   // Create a new AttributedType with the new nullability kind.
-  auto NewAttr = AttributedType::getNullabilityAttrKind(MergedKind);
-  return Ctx.getAttributedType(NewAttr, ResTy, ResTy);
+  return Ctx.getAttributedType(MergedKind, ResTy, ResTy);
 }
 
 ExprResult Sema::ActOnConditionalOp(SourceLocation QuestionLoc,
diff --git a/clang/lib/Sema/SemaExprObjC.cpp b/clang/lib/Sema/SemaExprObjC.cpp
index 35fbc4e7c30ebb..3fcbbb417ff1fa 100644
--- a/clang/lib/Sema/SemaExprObjC.cpp
+++ b/clang/lib/Sema/SemaExprObjC.cpp
@@ -551,9 +551,7 @@ ExprResult SemaObjC::BuildObjCBoxedExpr(SourceRange SR, Expr *ValueExpr) {
             const llvm::UTF8 *StrEnd = Str.bytes_end();
             // Check that this is a valid UTF-8 string.
             if (llvm::isLegalUTF8String(&StrBegin, StrEnd)) {
-              BoxedType = Context.getAttributedType(
-                  AttributedType::getNullabilityAttrKind(
-                      NullabilityKind::NonNull),
+              BoxedType = Context.getAttributedType(NullabilityKind::NonNull,
                   NSStringPointer, NSStringPointer);
               return new (Context) ObjCBoxedExpr(CE, BoxedType, nullptr, SR);
             }
@@ -605,9 +603,8 @@ ExprResult SemaObjC::BuildObjCBoxedExpr(SourceRange SR, Expr *ValueExpr) {
       std::optional<NullabilityKind> Nullability =
           BoxingMethod->getReturnType()->getNullability();
       if (Nullability)
-        BoxedType = Context.getAttributedType(
-            AttributedType::getNullabilityAttrKind(*Nullability), BoxedType,
-            BoxedType);
+        BoxedType =
+            Context.getAttributedType(*Nullability, BoxedType, BoxedType);
     }
   } else if (ValueType->isBuiltinType()) {
     // The other types we support are numeric, char and BOOL/bool. We could also
@@ -1444,10 +1441,8 @@ static QualType stripObjCInstanceType(ASTContext &Context, QualType T) {
   QualType origType = T;
   if (auto nullability = AttributedType::stripOuterNullability(T)) {
     if (T == Context.getObjCInstanceType()) {
-      return Context.getAttributedType(
-               AttributedType::getNullabilityAttrKind(*nullability),
-               Context.getObjCIdType(),
-               Context.getObjCIdType());
+      return Context.getAttributedType(*nullability, Context.getObjCIdType(),
+                                       Context.getObjCIdType());
     }
 
     return origType;
@@ -1485,10 +1480,7 @@ static QualType getBaseMessageSendResultType(Sema &S,
       (void)AttributedType::stripOuterNullability(type);
 
       // Form a new attributed type using the method result type's nullability.
-      return Context.getAttributedType(
-               AttributedType::getNullabilityAttrKind(*nullability),
-               type,
-               type);
+      return Context.getAttributedType(*nullability, type, type);
     }
 
     return type;
@@ -1559,9 +1551,8 @@ QualType SemaObjC::getMessageSendResultType(const Expr *Receiver,
         QualType NewResultType = Context.getObjCObjectPointerType(
             Context.getObjCInterfaceType(MD->getClassInterface()));
         if (auto Nullability = resultType->getNullability())
-          NewResultType = Context.getAttributedType(
-              AttributedType::getNullabilityAttrKind(*Nullability),
-              NewResultType, NewResultType);
+          NewResultType = Context.getAttributedType(*Nullability, NewResultType,
+                                                    NewResultType);
         return NewResultType;
       }
     }
@@ -1623,9 +1614,7 @@ QualType SemaObjC::getMessageSendResultType(const Expr *Receiver,
   if (newResultNullabilityIdx > 0) {
     auto newNullability
       = static_cast<NullabilityKind>(newResultNullabilityIdx-1);
-    return Context.getAttributedType(
-             AttributedType::getNullabilityAttrKind(newNullability),
-             resultType, resultType);
+    return Context.getAttributedType(newNullability, resultType, resultType);
   }
 
   return resultType;
diff --git a/clang/lib/Sema/SemaObjCProperty.cpp b/clang/lib/Sema/SemaObjCProperty.cpp
index 5c4dc93ff16cdf..6cef76fbde1d25 100644
--- a/clang/lib/Sema/SemaObjCProperty.cpp
+++ b/clang/lib/Sema/SemaObjCProperty.cpp
@@ -2460,7 +2460,7 @@ void SemaObjC::ProcessPropertyDecl(ObjCPropertyDecl *property) {
       QualType modifiedTy = resultTy;
       if (auto nullability = AttributedType::stripOuterNullability(modifiedTy)) {
         if (*nullability == NullabilityKind::Unspecified)
-          resultTy = Context.getAttributedType(attr::TypeNonNull,
+          resultTy = Context.getAttributedType(NullabilityKind::NonNull,
                                                modifiedTy, modifiedTy);
       }
     }
@@ -2538,7 +2538,7 @@ void SemaObjC::ProcessPropertyDecl(ObjCPropertyDecl *property) {
         QualType modifiedTy = paramTy;
         if (auto nullability = AttributedType::stripOuterNullability(modifiedTy)){
           if (*nullability == NullabilityKind::Unspecified)
-            paramTy = Context.getAttributedType(attr::TypeNullable,
+            paramTy = Context.getAttributedType(NullabilityKind::Nullable,
                                                 modifiedTy, modifiedTy);
         }
       }
diff --git a/clang/lib/Sema/SemaSwift.cpp b/clang/lib/Sema/SemaSwift.cpp
index 2eebce74b5e2f8..24fdfb8e57dc34 100644
--- a/clang/lib/Sema/SemaSwift.cpp
+++ b/clang/lib/Sema/SemaSwift.cpp
@@ -73,11 +73,16 @@ static bool isValidSwiftErrorResultType(QualType Ty) {
 }
 
 void SemaSwift::handleAttrAttr(Decl *D, const ParsedAttr &AL) {
+  if (AL.isInvalid() || AL.isUsedAsTypeAttr())
+    return;
+
   // Make sure that there is a string literal as the annotation's single
   // argument.
   StringRef Str;
-  if (!SemaRef.checkStringLiteralArgumentAttr(AL, 0, Str))
+  if (!SemaRef.checkStringLiteralArgumentAttr(AL, 0, Str)) {
+    AL.setInvalid();
     return;
+  }
 
   D->addAttr(::new (getASTContext()) SwiftAttrAttr(getASTContext(), AL, Str));
 }
diff --git a/clang/lib/Sema/SemaType.cpp b/clang/lib/Sema/SemaType.cpp
index 5d043e6684573c..e526a11973975d 100644
--- a/clang/lib/Sema/SemaType.cpp
+++ b/clang/lib/Sema/SemaType.cpp
@@ -289,7 +289,7 @@ namespace {
     QualType getAttributedType(Attr *A, QualType ModifiedType,
                                QualType EquivType) {
       QualType T =
-          sema.Context.getAttributedType(A->getKind(), ModifiedType, EquivType);
+          sema.Context.getAttributedType(A, ModifiedType, EquivType);
       AttrsForTypes.push_back({cast<AttributedType>(T.getTypePtr()), A});
       AttrsForTypesSorted = false;
       return T;
@@ -7161,6 +7161,60 @@ static bool HandleWebAssemblyFuncrefAttr(TypeProcessingState &State,
   return false;
 }
 
+static void HandleSwiftAttr(TypeProcessingState &State, TypeAttrLocation TAL,
+                            QualType &QT, ParsedAttr &PAttr) {
+  if (TAL == TAL_DeclName)
+    return;
+
+  Sema &S = State.getSema();
+  auto &D = State.getDeclarator();
+
+  // If the attribute appears in declaration specifiers
+  // it should be handled as a declaration attribute,
+  // unless it's associated with a type or a function
+  // prototype (i.e. appears on a parameter or result type).
+  if (State.isProcessingDeclSpec()) {
+    if (!(D.isPrototypeContext() ||
+          D.getContext() == DeclaratorContext::TypeName))
+      return;
+
+    if (auto *chunk = D.getInnermostNonParenChunk()) {
+      moveAttrFromListToList(PAttr, State.getCurrentAttributes(),
+                             const_cast<DeclaratorChunk *>(chunk)->getAttrs());
+      return;
+    }
+  }
+
+  StringRef Str;
+  if (!S.checkStringLiteralArgumentAttr(PAttr, 0, Str)) {
+    PAttr.setInvalid();
+    return;
+  }
+
+  // If the attribute is attached to a paren, move it closer to
+  // the declarator. This can happen in block declarations when
+  // an attribute is placed before `^` i.e. `(__attribute__((...)) ^)`.
+  //
+  // Note that it's actually invalid to use GNU style attributes
+  // in a block, but such cases are currently handled gracefully
+  // by the parser, and behavior should be consistent between
+  // cases when the attribute appears before/after the block's result
+  // type and inside (^).
+  if (TAL == TAL_DeclChunk) {
+    auto chunkIdx = State.getCurrentChunkIndex();
+    if (chunkIdx >= 1 &&
+        D.getTypeObject(chunkIdx).Kind == DeclaratorChunk::Paren) {
+      moveAttrFromListToList(PAttr, State.getCurrentAttributes(),
+                             D.getTypeObject(chunkIdx - 1).getAttrs());
+      return;
+    }
+  }
+
+  auto *A = ::new (S.Context) SwiftAttrAttr(S.Context, PAttr, Str);
+  QT = State.getAttributedType(A, QT, QT);
+  PAttr.setUsedAsTypeAttr();
+}
+
 /// Rebuild an attributed type without the nullability attribute on it.
 static QualType rebuildAttributedTypeWithoutNullability(ASTContext &Ctx,
                                                         QualType Type) {
@@ -7177,7 +7231,8 @@ static QualType rebuildAttributedTypeWithoutNullability(ASTContext &Ctx,
       Ctx, Attributed->getModifiedType());
   assert(Modified.getTypePtr() != Attributed->getModifiedType().getTypePtr());
   return Ctx.getAttributedType(Attributed->getAttrKind(), Modified,
-                               Attributed->getEquivalentType());
+                               Attributed->getEquivalentType(),
+                               Attributed->getAttr());
 }
 
 /// Map a nullability attribute kind to a nullability kind.
@@ -7306,8 +7361,7 @@ static bool CheckNullabilityTypeSpecifier(
     Attr *A = createNullabilityAttr(S.Context, *PAttr, Nullability);
     QT = State->getAttributedType(A, QT, QT);
   } else {
-    attr::Kind attrKind = AttributedType::getNullabilityAttrKind(Nullability);
-    QT = S.Context.getAttributedType(attrKind, QT, QT);
+    QT = S.Context.getAttributedType(Nullability, QT, QT);
   }
   return false;
 }
@@ -8749,6 +8803,11 @@ static void processTypeAttrs(TypeProcessingState &state, QualType &type,
       break;
     }
 
+    case ParsedAttr::AT_SwiftAttr: {
+      HandleSwiftAttr(state, TAL, type, attr);
+      break;
+    }
+
     MS_TYPE_ATTRS_CASELIST:
       if (!handleMSPointerTypeQualifierAttr(state, attr, type))
         attr.setUsedAsTypeAttr();
diff --git a/clang/lib/Sema/TreeTransform.h b/clang/lib/Sema/TreeTransform.h
index d24d8d5335e28d..aa34300dfa58db 100644
--- a/clang/lib/Sema/TreeTransform.h
+++ b/clang/lib/Sema/TreeTransform.h
@@ -7435,7 +7435,8 @@ QualType TreeTransform<Derived>::TransformAttributedType(TypeLocBuilder &TLB,
 
     result = SemaRef.Context.getAttributedType(TL.getAttrKind(),
                                                modifiedType,
-                                               equivalentType);
+                                               equivalentType,
+                                               TL.getAttr());
   }
 
   AttributedTypeLoc newTL = TLB.push<AttributedTypeLoc>(result);
diff --git a/clang/test/AST/attr-swift_attr.m b/clang/test/AST/attr-swift_attr.m
index 6ea6775aa5a9a3..6888745fe95d41 100644
--- a/clang/test/AST/attr-swift_attr.m
+++ b/clang/test/AST/attr-swift_attr.m
@@ -1,4 +1,4 @@
-// RUN: %clang_cc1 -ast-dump %s | FileCheck %s
+// RUN: %clang_cc1 -fblocks -ast-dump %s | FileCheck %s
 
 __attribute__((swift_attr("@actor")))
 @interface View
@@ -14,3 +14,48 @@ @interface Contact
 
 // CHECK-LABEL: InterfaceDecl {{.*}} Contact
 // CHECK-NEXT: SwiftAttrAttr {{.*}} "@sendable"
+
+#define SWIFT_SENDABLE __attribute__((swift_attr("@Sendable")))
+
+@interface InTypeContext
+- (nullable id)test:(nullable SWIFT_SENDABLE id)obj SWIFT_SENDABLE;
+@end
+
+// CHECK-LABEL: InterfaceDecl {{.*}} InTypeContext
+// CHECK-NEXT: MethodDecl {{.*}} - test: 'id _Nullable':'id'
+// CHECK-NEXT: ParmVarDecl {{.*}} obj 'SWIFT_SENDABLE id _Nullable':'id'
+// CHECK-NEXT: SwiftAttrAttr {{.*}} "@Sendable"
+
+@interface Generic<T: SWIFT_SENDABLE id>
+@end
+
+// CHECK-LABEL: InterfaceDecl {{.*}} Generic
+// CHECK-NEXT: TypeParamDecl {{.*}} T bounded 'SWIFT_SENDABLE id':'id'
+
+typedef SWIFT_SENDABLE Generic<id> Alias;
+
+// CHECK-LABEL: TypedefDecl {{.*}} Alias 'Generic<id>'
+// CHECK-NEXT: ObjectType {{.*}} 'Generic<id>'
+// CHECK-NEXT: SwiftAttrAttr {{.*}} "@Sendable"
+
+SWIFT_SENDABLE
+typedef struct {
+  void *ptr;
+} SendableStruct;
+
+// CHECK-LABEL: TypedefDecl {{.*}} SendableStruct 'struct SendableStruct':'SendableStruct'
+// CHECK: SwiftAttrAttr {{.*}} "@Sendable"
+
+@interface TestAttrPlacementInBlock1
+-(void) withHandler: (void (SWIFT_SENDABLE ^)(id)) handler;
+@end
+
+// CHECK-LABEL: ObjCInterfaceDecl {{.*}} TestAttrPlacementInBlock1
+// CHECK: handler 'SWIFT_SENDABLE void (^)(id)':'void (^)(id)'
+
+@interface TestAttrPlacementInBlock2
+-(void) withHandler: (void (^ SWIFT_SENDABLE)(id)) handler;
+@end
+
+// CHECK-LABEL: ObjCInterfaceDecl {{.*}} TestAttrPlacementInBlock2
+// CHECK: handler 'SWIFT_SENDABLE void (^)(id)':'void (^)(id)'
diff --git a/clang/test/SemaObjC/validate-attr-swift_attr.m b/clang/test/SemaObjC/validate-attr-swift_attr.m
index 2c73b0a892722c..c7511098804a7d 100644
--- a/clang/test/SemaObjC/validate-attr-swift_attr.m
+++ b/clang/test/SemaObjC/validate-attr-swift_attr.m
@@ -9,3 +9,7 @@ @interface I
 __attribute__((swift_attr(1)))
 @interface J
 @end
+
+@interface Error<T: __attribute__((swift_attr(1))) id>
+// expected-error@-1 {{expected string literal as argument of 'swift_attr' attribute}}
+@end



More information about the cfe-commits mailing list