[clang] [clang/AST] Make it possible to use SwiftAttr in type context (PR #108631)

Pavel Yaskevich via cfe-commits cfe-commits at lists.llvm.org
Fri Sep 20 16:55:40 PDT 2024


https://github.com/xedin updated https://github.com/llvm/llvm-project/pull/108631

>From bd447428181ec9ab38679625fd3b3b422eb18446 Mon Sep 17 00:00:00 2001
From: Pavel Yaskevich <xedin at apache.org>
Date: Wed, 20 Dec 2023 14:04:22 -0800
Subject: [PATCH] [clang/AST] Make it possible to use SwiftAttr in type context

Swift ClangImporter now supports concurrency annotations on imported declarations
and their parameters/results. To make it possible to use imported APIs in Swift
safely, there has to be a way to annotate individual parameters and result types
with relevant attributes that indicate that e.g. a block is called on a particular
actor or that it accepts a `Sendable` parameter.

To facilitate that, `SwiftAttr` is switched from `InheritableAttr`, which is a
declaration attribute, to `DeclOrTypeAttr`. To support this attribute in type context
we need access to its "Attribute" argument, which requires `AttributedType` to be
extended to include `Attr *` when available instead of just `attr::Kind`; otherwise it
won't be possible to determine what attribute should be imported.
---
 clang/include/clang/AST/ASTContext.h          |  7 ++
 clang/include/clang/AST/PropertiesBase.td     |  1 +
 clang/include/clang/AST/Type.h                | 42 +++++-------
 clang/include/clang/AST/TypeProperties.td     |  8 ++-
 clang/include/clang/Basic/Attr.td             |  2 +-
 clang/include/clang/Basic/AttrDocs.td         |  4 +-
 .../clang/Serialization/ASTRecordWriter.h     |  2 +
 clang/lib/AST/ASTContext.cpp                  | 51 +++++++++++---
 clang/lib/AST/ASTDiagnostic.cpp               |  6 +-
 clang/lib/AST/ASTImporter.cpp                 |  5 +-
 clang/lib/AST/Type.cpp                        | 20 +++++-
 clang/lib/AST/TypePrinter.cpp                 |  9 +++
 clang/lib/Sema/SemaDecl.cpp                   |  4 +-
 clang/lib/Sema/SemaDeclObjC.cpp               |  4 +-
 clang/lib/Sema/SemaExpr.cpp                   |  3 +-
 clang/lib/Sema/SemaExprObjC.cpp               | 29 +++-----
 clang/lib/Sema/SemaObjCProperty.cpp           |  4 +-
 clang/lib/Sema/SemaSwift.cpp                  |  7 +-
 clang/lib/Sema/SemaType.cpp                   | 67 +++++++++++++++++--
 clang/lib/Sema/TreeTransform.h                |  3 +-
 clang/test/AST/attr-swift_attr.m              | 31 +++++++++
 .../test/SemaObjC/validate-attr-swift_attr.m  |  4 ++
 22 files changed, 229 insertions(+), 84 deletions(-)

diff --git a/clang/include/clang/AST/ASTContext.h b/clang/include/clang/AST/ASTContext.h
index 168bdca3c880b2..f49bb7afce2336 100644
--- a/clang/include/clang/AST/ASTContext.h
+++ b/clang/include/clang/AST/ASTContext.h
@@ -1677,8 +1677,15 @@ class ASTContext : public RefCountedBase<ASTContext> {
   QualType getInjectedClassNameType(CXXRecordDecl *Decl, QualType TST) const;
 
   QualType getAttributedType(attr::Kind attrKind, QualType modifiedType,
+                             QualType equivalentType,
+                             const Attr *attr = nullptr) const;
+
+  QualType getAttributedType(const Attr *attr, QualType modifiedType,
                              QualType equivalentType) const;
 
+  QualType getAttributedType(NullabilityKind nullability, QualType modifiedType,
+                             QualType equivalentType);
+
   QualType getBTFTagAttributedType(const BTFTypeTagAttr *BTFAttr,
                                    QualType Wrapped);
 
diff --git a/clang/include/clang/AST/PropertiesBase.td b/clang/include/clang/AST/PropertiesBase.td
index 9b934b20cf2559..0d1712a1fe15fd 100644
--- a/clang/include/clang/AST/PropertiesBase.td
+++ b/clang/include/clang/AST/PropertiesBase.td
@@ -76,6 +76,7 @@ def APValue : PropertyType { let PassByReference = 1; }
 def APValueKind : EnumPropertyType<"APValue::ValueKind">;
 def ArraySizeModifier : EnumPropertyType<"ArraySizeModifier">;
 def AttrKind : EnumPropertyType<"attr::Kind">;
+def Attr : PropertyType<"const Attr *">;
 def AutoTypeKeyword : EnumPropertyType;
 def Bool : PropertyType<"bool">;
 def BuiltinTypeKind : EnumPropertyType<"BuiltinType::Kind">;
diff --git a/clang/include/clang/AST/Type.h b/clang/include/clang/AST/Type.h
index ef36a73716454f..860f5fce150449 100644
--- a/clang/include/clang/AST/Type.h
+++ b/clang/include/clang/AST/Type.h
@@ -68,6 +68,7 @@ class ValueDecl;
 class TagDecl;
 class TemplateParameterList;
 class Type;
+class Attr;
 
 enum {
   TypeAlignmentInBits = 4,
@@ -6037,21 +6038,29 @@ class AttributedType : public Type, public llvm::FoldingSetNode {
 private:
   friend class ASTContext; // ASTContext creates these
 
+  const Attr *Attribute;
+
   QualType ModifiedType;
   QualType EquivalentType;
 
   AttributedType(QualType canon, attr::Kind attrKind, QualType modified,
                  QualType equivalent)
-      : Type(Attributed, canon, equivalent->getDependence()),
-        ModifiedType(modified), EquivalentType(equivalent) {
-    AttributedTypeBits.AttrKind = attrKind;
-  }
+      : AttributedType(canon, attrKind, nullptr, modified, equivalent) {}
+
+  AttributedType(QualType canon, const Attr *attr, QualType modified,
+                 QualType equivalent);
+
+private:
+  AttributedType(QualType canon, attr::Kind attrKind, const Attr *attr,
+                 QualType modified, QualType equivalent);
 
 public:
   Kind getAttrKind() const {
     return static_cast<Kind>(AttributedTypeBits.AttrKind);
   }
 
+  const Attr *getAttr() const { return Attribute; }
+
   QualType getModifiedType() const { return ModifiedType; }
   QualType getEquivalentType() const { return EquivalentType; }
 
@@ -6083,25 +6092,6 @@ class AttributedType : public Type, public llvm::FoldingSetNode {
 
   std::optional<NullabilityKind> getImmediateNullability() const;
 
-  /// Retrieve the attribute kind corresponding to the given
-  /// nullability kind.
-  static Kind getNullabilityAttrKind(NullabilityKind kind) {
-    switch (kind) {
-    case NullabilityKind::NonNull:
-      return attr::TypeNonNull;
-
-    case NullabilityKind::Nullable:
-      return attr::TypeNullable;
-
-    case NullabilityKind::NullableResult:
-      return attr::TypeNullableResult;
-
-    case NullabilityKind::Unspecified:
-      return attr::TypeNullUnspecified;
-    }
-    llvm_unreachable("Unknown nullability kind.");
-  }
-
   /// Strip off the top-level nullability annotation on the given
   /// type, if it's there.
   ///
@@ -6114,14 +6104,16 @@ class AttributedType : public Type, public llvm::FoldingSetNode {
   static std::optional<NullabilityKind> stripOuterNullability(QualType &T);
 
   void Profile(llvm::FoldingSetNodeID &ID) {
-    Profile(ID, getAttrKind(), ModifiedType, EquivalentType);
+    Profile(ID, getAttrKind(), ModifiedType, EquivalentType, Attribute);
   }
 
   static void Profile(llvm::FoldingSetNodeID &ID, Kind attrKind,
-                      QualType modified, QualType equivalent) {
+                      QualType modified, QualType equivalent,
+                      const Attr *attr) {
     ID.AddInteger(attrKind);
     ID.AddPointer(modified.getAsOpaquePtr());
     ID.AddPointer(equivalent.getAsOpaquePtr());
+    ID.AddPointer(attr);
   }
 
   static bool classof(const Type *T) {
diff --git a/clang/include/clang/AST/TypeProperties.td b/clang/include/clang/AST/TypeProperties.td
index 539a344cb0b690..82605b18fbde6c 100644
--- a/clang/include/clang/AST/TypeProperties.td
+++ b/clang/include/clang/AST/TypeProperties.td
@@ -668,12 +668,16 @@ let Class = AttributedType in {
   def : Property<"equivalentType", QualType> {
     let Read = [{ node->getEquivalentType() }];
   }
-  def : Property<"attribute", AttrKind> {
+  def : Property<"attrKind", AttrKind> {
     let Read = [{ node->getAttrKind() }];
   }
+  def : Property<"attribute", Attr> {
+    let Read = [{ node->getAttr() }];
+  }
 
   def : Creator<[{
-    return ctx.getAttributedType(attribute, modifiedType, equivalentType);
+    return ctx.getAttributedType(attrKind, modifiedType,
+                                 equivalentType, attribute);
   }]>;
 }
 
diff --git a/clang/include/clang/Basic/Attr.td b/clang/include/clang/Basic/Attr.td
index 70fad60d4edbb5..71f626845b0108 100644
--- a/clang/include/clang/Basic/Attr.td
+++ b/clang/include/clang/Basic/Attr.td
@@ -2834,7 +2834,7 @@ def SwiftAsyncName : InheritableAttr {
   let Documentation = [SwiftAsyncNameDocs];
 }
 
-def SwiftAttr : InheritableAttr {
+def SwiftAttr : DeclOrTypeAttr {
   let Spellings = [GNU<"swift_attr">];
   let Args = [StringArgument<"Attribute">];
   let Documentation = [SwiftAttrDocs];
diff --git a/clang/include/clang/Basic/AttrDocs.td b/clang/include/clang/Basic/AttrDocs.td
index 546e5100b79dd9..0bb51585d407bf 100644
--- a/clang/include/clang/Basic/AttrDocs.td
+++ b/clang/include/clang/Basic/AttrDocs.td
@@ -4489,8 +4489,8 @@ def SwiftAttrDocs : Documentation {
   let Heading = "swift_attr";
   let Content = [{
 The ``swift_attr`` provides a Swift-specific annotation for the declaration
-to which the attribute appertains to. It can be used on any declaration
-in Clang. This kind of annotation is ignored by Clang as it doesn't have any
+or type to which the attribute appertains. It can be used on any declaration
+or type in Clang. This kind of annotation is ignored by Clang as it doesn't have any
 semantic meaning in languages supported by Clang. The Swift compiler can
 interpret these annotations according to its own rules when importing C or
 Objective-C declarations.
diff --git a/clang/include/clang/Serialization/ASTRecordWriter.h b/clang/include/clang/Serialization/ASTRecordWriter.h
index 0c8ac75fc40f40..d6090ba1a6c690 100644
--- a/clang/include/clang/Serialization/ASTRecordWriter.h
+++ b/clang/include/clang/Serialization/ASTRecordWriter.h
@@ -128,6 +128,8 @@ class ASTRecordWriter
     AddStmt(const_cast<Stmt*>(S));
   }
 
+  void writeAttr(const Attr *A) { AddAttr(A); }
+
   /// Write an BTFTypeTagAttr object.
   void writeBTFTypeTagAttr(const BTFTypeTagAttr *A) { AddAttr(A); }
 
diff --git a/clang/lib/AST/ASTContext.cpp b/clang/lib/AST/ASTContext.cpp
index 7cc69ca4a8a814..107a3cf5b6ab39 100644
--- a/clang/lib/AST/ASTContext.cpp
+++ b/clang/lib/AST/ASTContext.cpp
@@ -3566,11 +3566,13 @@ QualType ASTContext::getFunctionTypeWithExceptionSpec(
         MQT->getMacroIdentifier());
 
   // Might have a calling-convention attribute.
-  if (const auto *AT = dyn_cast<AttributedType>(Orig))
+  if (const auto *AT = dyn_cast<AttributedType>(Orig)) {
     return getAttributedType(
         AT->getAttrKind(),
         getFunctionTypeWithExceptionSpec(AT->getModifiedType(), ESI),
-        getFunctionTypeWithExceptionSpec(AT->getEquivalentType(), ESI));
+        getFunctionTypeWithExceptionSpec(AT->getEquivalentType(), ESI),
+        AT->getAttr());
+  }
 
   // Anything else must be a function type. Rebuild it with the new exception
   // specification.
@@ -5125,17 +5127,20 @@ QualType ASTContext::getUnresolvedUsingType(
 
 QualType ASTContext::getAttributedType(attr::Kind attrKind,
                                        QualType modifiedType,
-                                       QualType equivalentType) const {
+                                       QualType equivalentType,
+                                       const Attr *attr) const {
   llvm::FoldingSetNodeID id;
-  AttributedType::Profile(id, attrKind, modifiedType, equivalentType);
+  AttributedType::Profile(id, attrKind, modifiedType, equivalentType, attr);
 
   void *insertPos = nullptr;
   AttributedType *type = AttributedTypes.FindNodeOrInsertPos(id, insertPos);
   if (type) return QualType(type, 0);
 
+  assert(!attr || attr->getKind() == attrKind);
+
   QualType canon = getCanonicalType(equivalentType);
-  type = new (*this, alignof(AttributedType))
-      AttributedType(canon, attrKind, modifiedType, equivalentType);
+  type = new (*this, alignof(AttributedType))
+      AttributedType(canon, attrKind, attr, modifiedType, equivalentType);
 
   Types.push_back(type);
   AttributedTypes.InsertNode(type, insertPos);
@@ -5143,6 +5148,33 @@ QualType ASTContext::getAttributedType(attr::Kind attrKind,
   return QualType(type, 0);
 }
 
+QualType ASTContext::getAttributedType(const Attr *attr, QualType modifiedType,
+                                       QualType equivalentType) const {
+  return getAttributedType(attr->getKind(), modifiedType, equivalentType, attr);
+}
+
+QualType ASTContext::getAttributedType(NullabilityKind nullability,
+                                       QualType modifiedType,
+                                       QualType equivalentType) {
+  switch (nullability) {
+  case NullabilityKind::NonNull:
+    return getAttributedType(attr::TypeNonNull, modifiedType, equivalentType);
+
+  case NullabilityKind::Nullable:
+    return getAttributedType(attr::TypeNullable, modifiedType, equivalentType);
+
+  case NullabilityKind::NullableResult:
+    return getAttributedType(attr::TypeNullableResult, modifiedType,
+                             equivalentType);
+
+  case NullabilityKind::Unspecified:
+    return getAttributedType(attr::TypeNullUnspecified, modifiedType,
+                             equivalentType);
+  }
+
+  llvm_unreachable("Unknown nullability kind");
+}
+
 QualType ASTContext::getBTFTagAttributedType(const BTFTypeTagAttr *BTFAttr,
                                              QualType Wrapped) {
   llvm::FoldingSetNodeID ID;
@@ -7474,8 +7506,8 @@ QualType ASTContext::getArrayDecayedType(QualType Ty) const {
 
   // int x[_Nullable] -> int * _Nullable
   if (auto Nullability = Ty->getNullability()) {
-    Result = const_cast<ASTContext *>(this)->getAttributedType(
-        AttributedType::getNullabilityAttrKind(*Nullability), Result, Result);
+    Result = const_cast<ASTContext *>(this)->getAttributedType(*Nullability,
+                                                               Result, Result);
   }
   return Result;
 }
@@ -13693,7 +13725,8 @@ static QualType getCommonSugarTypeNode(ASTContext &Ctx, const Type *X,
       return QualType();
     // FIXME: It's inefficient to have to unify the modified types.
     return Ctx.getAttributedType(Kind, Ctx.getCommonSugaredType(MX, MY),
-                                 Ctx.getQualifiedType(Underlying));
+                                 Ctx.getQualifiedType(Underlying),
+                                 AX->getAttr());
   }
   case Type::BTFTagAttributed: {
     const auto *BX = cast<BTFTagAttributedType>(X);
diff --git a/clang/lib/AST/ASTDiagnostic.cpp b/clang/lib/AST/ASTDiagnostic.cpp
index 15c3efe4212719..4f677b60e60dae 100644
--- a/clang/lib/AST/ASTDiagnostic.cpp
+++ b/clang/lib/AST/ASTDiagnostic.cpp
@@ -85,8 +85,7 @@ QualType clang::desugarForDiagnostic(ASTContext &Context, QualType QT,
       QualType SugarRT = FT->getReturnType();
       QualType RT = desugarForDiagnostic(Context, SugarRT, DesugarReturn);
       if (auto nullability = AttributedType::stripOuterNullability(SugarRT)) {
-        RT = Context.getAttributedType(
-            AttributedType::getNullabilityAttrKind(*nullability), RT, RT);
+        RT = Context.getAttributedType(*nullability, RT, RT);
       }
 
       bool DesugarArgument = false;
@@ -97,8 +96,7 @@ QualType clang::desugarForDiagnostic(ASTContext &Context, QualType QT,
           QualType PT = desugarForDiagnostic(Context, SugarPT, DesugarArgument);
           if (auto nullability =
                   AttributedType::stripOuterNullability(SugarPT)) {
-            PT = Context.getAttributedType(
-                AttributedType::getNullabilityAttrKind(*nullability), PT, PT);
+            PT = Context.getAttributedType(*nullability, PT, PT);
           }
           Args.push_back(PT);
         }
diff --git a/clang/lib/AST/ASTImporter.cpp b/clang/lib/AST/ASTImporter.cpp
index c2fb7dddcfc637..1be755e97f716d 100644
--- a/clang/lib/AST/ASTImporter.cpp
+++ b/clang/lib/AST/ASTImporter.cpp
@@ -1580,8 +1580,9 @@ ExpectedType ASTNodeImporter::VisitAttributedType(const AttributedType *T) {
   if (!ToEquivalentTypeOrErr)
     return ToEquivalentTypeOrErr.takeError();
 
-  return Importer.getToContext().getAttributedType(T->getAttrKind(),
-      *ToModifiedTypeOrErr, *ToEquivalentTypeOrErr);
+  return Importer.getToContext().getAttributedType(
+      T->getAttrKind(), *ToModifiedTypeOrErr, *ToEquivalentTypeOrErr,
+      T->getAttr());
 }
 
 ExpectedType
diff --git a/clang/lib/AST/Type.cpp b/clang/lib/AST/Type.cpp
index a55e6c8bf02611..54dd6654f5a668 100644
--- a/clang/lib/AST/Type.cpp
+++ b/clang/lib/AST/Type.cpp
@@ -1241,8 +1241,8 @@ struct SimpleTransformVisitor : public TypeVisitor<Derived, QualType> {
           == T->getEquivalentType().getAsOpaquePtr())
       return QualType(T, 0);
 
-    return Ctx.getAttributedType(T->getAttrKind(), modifiedType,
-                                 equivalentType);
+    return Ctx.getAttributedType(T->getAttrKind(), modifiedType, equivalentType,
+                                 T->getAttr());
   }
 
   QualType VisitSubstTemplateTypeParmType(const SubstTemplateTypeParmType *T) {
@@ -1545,7 +1545,8 @@ struct SubstObjCTypeArgsVisitor
 
     // Rebuild the attributed type.
     return Ctx.getAttributedType(newAttrType->getAttrKind(),
-                                 newAttrType->getModifiedType(), newEquivType);
+                                 newAttrType->getModifiedType(), newEquivType,
+                                 newAttrType->getAttr());
   }
 };
 
@@ -4104,6 +4105,19 @@ bool RecordType::hasConstFields() const {
   return false;
 }
 
+AttributedType::AttributedType(QualType canon, const Attr *attr,
+                               QualType modified, QualType equivalent)
+    : AttributedType(canon, attr->getKind(), attr, modified, equivalent) {}
+
+AttributedType::AttributedType(QualType canon, attr::Kind attrKind,
+                               const Attr *attr, QualType modified,
+                               QualType equivalent)
+    : Type(Attributed, canon, equivalent->getDependence()), Attribute(attr),
+      ModifiedType(modified), EquivalentType(equivalent) {
+  AttributedTypeBits.AttrKind = attrKind;
+  assert(!attr || attr->getKind() == attrKind);
+}
+
 bool AttributedType::isQualifier() const {
   // FIXME: Generate this with TableGen.
   switch (getAttrKind()) {
diff --git a/clang/lib/AST/TypePrinter.cpp b/clang/lib/AST/TypePrinter.cpp
index be627a6242eb40..8cc4c39f7b80e2 100644
--- a/clang/lib/AST/TypePrinter.cpp
+++ b/clang/lib/AST/TypePrinter.cpp
@@ -1932,6 +1932,14 @@ void TypePrinter::printAttributedAfter(const AttributedType *T,
     return;
   }
 
+  if (T->getAttrKind() == attr::SwiftAttr) {
+    if (auto *swiftAttr = dyn_cast_or_null<SwiftAttrAttr>(T->getAttr())) {
+      OS << " __attribute__((swift_attr(\"" << swiftAttr->getAttribute()
+         << "\")))";
+    }
+    return;
+  }
+
   OS << " __attribute__((";
   switch (T->getAttrKind()) {
 #define TYPE_ATTR(NAME)
@@ -1990,6 +1998,7 @@ void TypePrinter::printAttributedAfter(const AttributedType *T,
   case attr::NonAllocating:
   case attr::Blocking:
   case attr::Allocating:
+  case attr::SwiftAttr:
     llvm_unreachable("This attribute should have been handled already");
 
   case attr::NSReturnsRetained:
diff --git a/clang/lib/Sema/SemaDecl.cpp b/clang/lib/Sema/SemaDecl.cpp
index 8557c25b93a8da..4bb3f2a0f302cb 100644
--- a/clang/lib/Sema/SemaDecl.cpp
+++ b/clang/lib/Sema/SemaDecl.cpp
@@ -3329,9 +3329,7 @@ static void mergeParamDeclTypes(ParmVarDecl *NewParam,
       }
     } else {
       QualType NewT = NewParam->getType();
-      NewT = S.Context.getAttributedType(
-                         AttributedType::getNullabilityAttrKind(*Oldnullability),
-                         NewT, NewT);
+      NewT = S.Context.getAttributedType(*Oldnullability, NewT, NewT);
       NewParam->setType(NewT);
     }
   }
diff --git a/clang/lib/Sema/SemaDeclObjC.cpp b/clang/lib/Sema/SemaDeclObjC.cpp
index 807453400abdd0..229cf87de52f12 100644
--- a/clang/lib/Sema/SemaDeclObjC.cpp
+++ b/clang/lib/Sema/SemaDeclObjC.cpp
@@ -4577,9 +4577,7 @@ static QualType mergeTypeNullabilityForRedecl(Sema &S, SourceLocation loc,
     return type;
 
   // Otherwise, provide the result with the same nullability.
-  return S.Context.getAttributedType(
-           AttributedType::getNullabilityAttrKind(*prevNullability),
-           type, type);
+  return S.Context.getAttributedType(*prevNullability, type, type);
 }
 
 /// Merge information from the declaration of a method in the \@interface
diff --git a/clang/lib/Sema/SemaExpr.cpp b/clang/lib/Sema/SemaExpr.cpp
index 8f3e15cc9a9bb7..958ef034d598c0 100644
--- a/clang/lib/Sema/SemaExpr.cpp
+++ b/clang/lib/Sema/SemaExpr.cpp
@@ -8738,8 +8738,7 @@ static QualType computeConditionalNullability(QualType ResTy, bool IsBin,
     ResTy = ResTy.getSingleStepDesugaredType(Ctx);
 
   // Create a new AttributedType with the new nullability kind.
-  auto NewAttr = AttributedType::getNullabilityAttrKind(MergedKind);
-  return Ctx.getAttributedType(NewAttr, ResTy, ResTy);
+  return Ctx.getAttributedType(MergedKind, ResTy, ResTy);
 }
 
 ExprResult Sema::ActOnConditionalOp(SourceLocation QuestionLoc,
diff --git a/clang/lib/Sema/SemaExprObjC.cpp b/clang/lib/Sema/SemaExprObjC.cpp
index 2f914ddc22a38b..5a2e1fe4f16f6e 100644
--- a/clang/lib/Sema/SemaExprObjC.cpp
+++ b/clang/lib/Sema/SemaExprObjC.cpp
@@ -551,9 +551,7 @@ ExprResult SemaObjC::BuildObjCBoxedExpr(SourceRange SR, Expr *ValueExpr) {
             const llvm::UTF8 *StrEnd = Str.bytes_end();
             // Check that this is a valid UTF-8 string.
             if (llvm::isLegalUTF8String(&StrBegin, StrEnd)) {
-              BoxedType = Context.getAttributedType(
-                  AttributedType::getNullabilityAttrKind(
-                      NullabilityKind::NonNull),
+              BoxedType = Context.getAttributedType(NullabilityKind::NonNull,
                   NSStringPointer, NSStringPointer);
               return new (Context) ObjCBoxedExpr(CE, BoxedType, nullptr, SR);
             }
@@ -605,9 +603,8 @@ ExprResult SemaObjC::BuildObjCBoxedExpr(SourceRange SR, Expr *ValueExpr) {
       std::optional<NullabilityKind> Nullability =
           BoxingMethod->getReturnType()->getNullability();
       if (Nullability)
-        BoxedType = Context.getAttributedType(
-            AttributedType::getNullabilityAttrKind(*Nullability), BoxedType,
-            BoxedType);
+        BoxedType =
+            Context.getAttributedType(*Nullability, BoxedType, BoxedType);
     }
   } else if (ValueType->isBuiltinType()) {
     // The other types we support are numeric, char and BOOL/bool. We could also
@@ -1444,10 +1441,8 @@ static QualType stripObjCInstanceType(ASTContext &Context, QualType T) {
   QualType origType = T;
   if (auto nullability = AttributedType::stripOuterNullability(T)) {
     if (T == Context.getObjCInstanceType()) {
-      return Context.getAttributedType(
-               AttributedType::getNullabilityAttrKind(*nullability),
-               Context.getObjCIdType(),
-               Context.getObjCIdType());
+      return Context.getAttributedType(*nullability, Context.getObjCIdType(),
+                                       Context.getObjCIdType());
     }
 
     return origType;
@@ -1485,10 +1480,7 @@ static QualType getBaseMessageSendResultType(Sema &S,
       (void)AttributedType::stripOuterNullability(type);
 
       // Form a new attributed type using the method result type's nullability.
-      return Context.getAttributedType(
-               AttributedType::getNullabilityAttrKind(*nullability),
-               type,
-               type);
+      return Context.getAttributedType(*nullability, type, type);
     }
 
     return type;
@@ -1559,9 +1551,8 @@ QualType SemaObjC::getMessageSendResultType(const Expr *Receiver,
         QualType NewResultType = Context.getObjCObjectPointerType(
             Context.getObjCInterfaceType(MD->getClassInterface()));
         if (auto Nullability = resultType->getNullability())
-          NewResultType = Context.getAttributedType(
-              AttributedType::getNullabilityAttrKind(*Nullability),
-              NewResultType, NewResultType);
+          NewResultType = Context.getAttributedType(*Nullability, NewResultType,
+                                                    NewResultType);
         return NewResultType;
       }
     }
@@ -1623,9 +1614,7 @@ QualType SemaObjC::getMessageSendResultType(const Expr *Receiver,
   if (newResultNullabilityIdx > 0) {
     auto newNullability
       = static_cast<NullabilityKind>(newResultNullabilityIdx-1);
-    return Context.getAttributedType(
-             AttributedType::getNullabilityAttrKind(newNullability),
-             resultType, resultType);
+    return Context.getAttributedType(newNullability, resultType, resultType);
   }
 
   return resultType;
diff --git a/clang/lib/Sema/SemaObjCProperty.cpp b/clang/lib/Sema/SemaObjCProperty.cpp
index 031f2a6af87744..f1495125b19798 100644
--- a/clang/lib/Sema/SemaObjCProperty.cpp
+++ b/clang/lib/Sema/SemaObjCProperty.cpp
@@ -2460,7 +2460,7 @@ void SemaObjC::ProcessPropertyDecl(ObjCPropertyDecl *property) {
       QualType modifiedTy = resultTy;
       if (auto nullability = AttributedType::stripOuterNullability(modifiedTy)) {
         if (*nullability == NullabilityKind::Unspecified)
-          resultTy = Context.getAttributedType(attr::TypeNonNull,
+          resultTy = Context.getAttributedType(NullabilityKind::NonNull,
                                                modifiedTy, modifiedTy);
       }
     }
@@ -2538,7 +2538,7 @@ void SemaObjC::ProcessPropertyDecl(ObjCPropertyDecl *property) {
         QualType modifiedTy = paramTy;
         if (auto nullability = AttributedType::stripOuterNullability(modifiedTy)){
           if (*nullability == NullabilityKind::Unspecified)
-            paramTy = Context.getAttributedType(attr::TypeNullable,
+            paramTy = Context.getAttributedType(NullabilityKind::Nullable,
                                                 modifiedTy, modifiedTy);
         }
       }
diff --git a/clang/lib/Sema/SemaSwift.cpp b/clang/lib/Sema/SemaSwift.cpp
index 2eebce74b5e2f8..24fdfb8e57dc34 100644
--- a/clang/lib/Sema/SemaSwift.cpp
+++ b/clang/lib/Sema/SemaSwift.cpp
@@ -73,11 +73,16 @@ static bool isValidSwiftErrorResultType(QualType Ty) {
 }
 
 void SemaSwift::handleAttrAttr(Decl *D, const ParsedAttr &AL) {
+  if (AL.isInvalid() || AL.isUsedAsTypeAttr())
+    return;
+
   // Make sure that there is a string literal as the annotation's single
   // argument.
   StringRef Str;
-  if (!SemaRef.checkStringLiteralArgumentAttr(AL, 0, Str))
+  if (!SemaRef.checkStringLiteralArgumentAttr(AL, 0, Str)) {
+    AL.setInvalid();
     return;
+  }
 
   D->addAttr(::new (getASTContext()) SwiftAttrAttr(getASTContext(), AL, Str));
 }
diff --git a/clang/lib/Sema/SemaType.cpp b/clang/lib/Sema/SemaType.cpp
index e627fee51b66b8..d2aaa3b138bfeb 100644
--- a/clang/lib/Sema/SemaType.cpp
+++ b/clang/lib/Sema/SemaType.cpp
@@ -289,7 +289,7 @@ namespace {
     QualType getAttributedType(Attr *A, QualType ModifiedType,
                                QualType EquivType) {
       QualType T =
-          sema.Context.getAttributedType(A->getKind(), ModifiedType, EquivType);
+          sema.Context.getAttributedType(A, ModifiedType, EquivType);
       AttrsForTypes.push_back({cast<AttributedType>(T.getTypePtr()), A});
       AttrsForTypesSorted = false;
       return T;
@@ -7147,6 +7147,60 @@ static bool HandleWebAssemblyFuncrefAttr(TypeProcessingState &State,
   return false;
 }
 
+static void HandleSwiftAttr(TypeProcessingState &State, TypeAttrLocation TAL,
+                            QualType &QT, ParsedAttr &PAttr) {
+  if (TAL == TAL_DeclName)
+    return;
+
+  Sema &S = State.getSema();
+  auto &D = State.getDeclarator();
+
+  // If the attribute appears in declaration specifiers
+  // it should be handled as a declaration attribute,
+  // unless it's associated with a type or a function
+  // prototype (i.e. appears on a parameter or result type).
+  if (State.isProcessingDeclSpec()) {
+    if (!(D.isPrototypeContext() ||
+          D.getContext() == DeclaratorContext::TypeName))
+      return;
+
+    if (auto *chunk = D.getInnermostNonParenChunk()) {
+      moveAttrFromListToList(PAttr, State.getCurrentAttributes(),
+                             const_cast<DeclaratorChunk *>(chunk)->getAttrs());
+      return;
+    }
+  }
+
+  StringRef Str;
+  if (!S.checkStringLiteralArgumentAttr(PAttr, 0, Str)) {
+    PAttr.setInvalid();
+    return;
+  }
+
+  // If the attribute is attached to a paren, move it closer to
+  // the declarator. This can happen in block declarations when
+  // an attribute is placed before `^` i.e. `(__attribute__((...)) ^)`.
+  //
+  // Note that it's actually invalid to use GNU style attributes
+  // in a block, but such cases are currently handled gracefully
+  // by the parser, and behavior should be consistent between
+  // cases when the attribute appears before/after the block's
+  // result type and inside (^).
+  if (TAL == TAL_DeclChunk) {
+    auto chunkIdx = State.getCurrentChunkIndex();
+    if (chunkIdx >= 1 &&
+        D.getTypeObject(chunkIdx).Kind == DeclaratorChunk::Paren) {
+      moveAttrFromListToList(PAttr, State.getCurrentAttributes(),
+                             D.getTypeObject(chunkIdx - 1).getAttrs());
+      return;
+    }
+  }
+
+  auto *A = ::new (S.Context) SwiftAttrAttr(S.Context, PAttr, Str);
+  QT = State.getAttributedType(A, QT, QT);
+  PAttr.setUsedAsTypeAttr();
+}
+
 /// Rebuild an attributed type without the nullability attribute on it.
 static QualType rebuildAttributedTypeWithoutNullability(ASTContext &Ctx,
                                                         QualType Type) {
@@ -7163,7 +7217,8 @@ static QualType rebuildAttributedTypeWithoutNullability(ASTContext &Ctx,
       Ctx, Attributed->getModifiedType());
   assert(Modified.getTypePtr() != Attributed->getModifiedType().getTypePtr());
   return Ctx.getAttributedType(Attributed->getAttrKind(), Modified,
-                               Attributed->getEquivalentType());
+                               Attributed->getEquivalentType(),
+                               Attributed->getAttr());
 }
 
 /// Map a nullability attribute kind to a nullability kind.
@@ -7292,8 +7347,7 @@ static bool CheckNullabilityTypeSpecifier(
     Attr *A = createNullabilityAttr(S.Context, *PAttr, Nullability);
     QT = State->getAttributedType(A, QT, QT);
   } else {
-    attr::Kind attrKind = AttributedType::getNullabilityAttrKind(Nullability);
-    QT = S.Context.getAttributedType(attrKind, QT, QT);
+    QT = S.Context.getAttributedType(Nullability, QT, QT);
   }
   return false;
 }
@@ -8735,6 +8789,11 @@ static void processTypeAttrs(TypeProcessingState &state, QualType &type,
       break;
     }
 
+    case ParsedAttr::AT_SwiftAttr: {
+      HandleSwiftAttr(state, TAL, type, attr);
+      break;
+    }
+
     MS_TYPE_ATTRS_CASELIST:
       if (!handleMSPointerTypeQualifierAttr(state, attr, type))
         attr.setUsedAsTypeAttr();
diff --git a/clang/lib/Sema/TreeTransform.h b/clang/lib/Sema/TreeTransform.h
index ff745b3385fcd9..22e6573bbf1020 100644
--- a/clang/lib/Sema/TreeTransform.h
+++ b/clang/lib/Sema/TreeTransform.h
@@ -7413,7 +7413,8 @@ QualType TreeTransform<Derived>::TransformAttributedType(
 
     result = SemaRef.Context.getAttributedType(TL.getAttrKind(),
                                                modifiedType,
-                                               equivalentType);
+                                               equivalentType,
+                                               TL.getAttr());
   }
 
   AttributedTypeLoc newTL = TLB.push<AttributedTypeLoc>(result);
diff --git a/clang/test/AST/attr-swift_attr.m b/clang/test/AST/attr-swift_attr.m
index 6ea6775aa5a9a3..4406112928c7d8 100644
--- a/clang/test/AST/attr-swift_attr.m
+++ b/clang/test/AST/attr-swift_attr.m
@@ -14,3 +14,34 @@ @interface Contact
 
 // CHECK-LABEL: InterfaceDecl {{.*}} Contact
 // CHECK-NEXT: SwiftAttrAttr {{.*}} "@sendable"
+
+#define SWIFT_SENDABLE __attribute__((swift_attr("@Sendable")))
+
+@interface InTypeContext
+- (nullable id)test:(nullable SWIFT_SENDABLE id)obj SWIFT_SENDABLE;
+@end
+
+// CHECK-LABEL: InterfaceDecl {{.*}} InTypeContext
+// CHECK-NEXT: MethodDecl {{.*}} - test: 'id _Nullable':'id'
+// CHECK-NEXT: ParmVarDecl {{.*}} obj 'SWIFT_SENDABLE id _Nullable':'id'
+// CHECK-NEXT: SwiftAttrAttr {{.*}} "@Sendable"
+
+@interface Generic<T: SWIFT_SENDABLE id>
+@end
+
+// CHECK-LABEL: InterfaceDecl {{.*}} Generic
+// CHECK-NEXT: TypeParamDecl {{.*}} T bounded 'SWIFT_SENDABLE id':'id'
+
+typedef SWIFT_SENDABLE Generic<id> Alias;
+
+// CHECK-LABEL: TypedefDecl {{.*}} Alias 'Generic<id>'
+// CHECK-NEXT: ObjectType {{.*}} 'Generic<id>'
+// CHECK-NEXT: SwiftAttrAttr {{.*}} "@Sendable"
+
+SWIFT_SENDABLE
+typedef struct {
+  void *ptr;
+} SendableStruct;
+
+// CHECK-LABEL: TypedefDecl {{.*}} SendableStruct 'struct SendableStruct':'SendableStruct'
+// CHECK: SwiftAttrAttr {{.*}} "@Sendable"
diff --git a/clang/test/SemaObjC/validate-attr-swift_attr.m b/clang/test/SemaObjC/validate-attr-swift_attr.m
index 2c73b0a892722c..c7511098804a7d 100644
--- a/clang/test/SemaObjC/validate-attr-swift_attr.m
+++ b/clang/test/SemaObjC/validate-attr-swift_attr.m
@@ -9,3 +9,7 @@ @interface I
 __attribute__((swift_attr(1)))
 @interface J
 @end
+
+@interface Error<T: __attribute__((swift_attr(1))) id>
+// expected-error@-1 {{expected string literal as argument of 'swift_attr' attribute}}
+@end



More information about the cfe-commits mailing list