[clang] [APINotes] Upstream Sema logic to apply API Notes to decls (PR #73017)

Egor Zhdan via cfe-commits cfe-commits at lists.llvm.org
Tue Jan 16 08:44:04 PST 2024


https://github.com/egorzhdan updated https://github.com/llvm/llvm-project/pull/73017

>From 447aabfbf20c8d286caf6ad3f2e30da116592860 Mon Sep 17 00:00:00 2001
From: Egor Zhdan <e_zhdan at apple.com>
Date: Mon, 20 Nov 2023 18:03:18 +0000
Subject: [PATCH 1/3] [APINotes] Upstream attributes that are created
 implicitly from APINotes

---
 clang/include/clang/Basic/Attr.td         | 46 ++++++++++++++++++++++-
 clang/lib/Sema/SemaDeclAttr.cpp           | 34 ++++++++++-------
 clang/lib/Serialization/ASTReaderDecl.cpp |  2 +
 clang/test/Sema/ns_error_enum.m           |  3 +-
 clang/utils/TableGen/ClangAttrEmitter.cpp | 26 ++++++++++++-
 5 files changed, 93 insertions(+), 18 deletions(-)

diff --git a/clang/include/clang/Basic/Attr.td b/clang/include/clang/Basic/Attr.td
index b9ec720dd9e199..78a9229aeaf081 100644
--- a/clang/include/clang/Basic/Attr.td
+++ b/clang/include/clang/Basic/Attr.td
@@ -306,6 +306,9 @@ class VariadicEnumArgument<string name, string type, list<string> values,
   bit IsExternalType = isExternalType;
 }
 
+// An argument whose value is another attribute, stored as an 'Attr *'.
+class WrappedAttr<string name, bit opt = 0> : Argument<name, opt>;
+
 // This handles one spelling of an attribute.
 class Spelling<string name, string variety, int version = 1> {
   string Name = name;
@@ -2291,7 +2294,7 @@ def ObjCBridgeRelated : InheritableAttr {
 def NSErrorDomain : InheritableAttr {
   let Spellings = [GNU<"ns_error_domain">];
   let Subjects = SubjectList<[Enum], ErrorDiag>;
-  let Args = [DeclArgument<Var, "ErrorDomain">];
+  let Args = [IdentifierArgument<"ErrorDomain">];
   let Documentation = [NSErrorDomainDocs];
 }
 
@@ -2648,6 +2651,22 @@ def SwiftError : InheritableAttr {
   let Documentation = [SwiftErrorDocs];
 }
 
+def SwiftImportAsNonGeneric : InheritableAttr {
+  // No spellings and no Sema handler: this attribute is only ever created
+  // implicitly from API notes, never written in source.
+  let Spellings = [];
+  let SemaHandler = 0;
+  let Documentation = [InternalOnly];
+}
+
+def SwiftImportPropertyAsAccessors : InheritableAttr {
+  // No spellings and no Sema handler: this attribute is only ever created
+  // implicitly from API notes, never written in source.
+  let Spellings = [];
+  let SemaHandler = 0;
+  let Documentation = [InternalOnly];
+}
+
 def SwiftName : InheritableAttr {
   let Spellings = [GNU<"swift_name">];
   let Args = [StringArgument<"Name">];
@@ -2669,6 +2688,31 @@ def SwiftPrivate : InheritableAttr {
   let SimpleHandler = 1;
 }
 
+def SwiftVersionedAddition : Attr {
+  // No spellings: only ever created implicitly from API notes. Wraps another
+  // attribute ('AdditionalAttr') that is tied to 'Version'.
+  let Spellings = [];
+  let Args = [VersionArgument<"Version">, WrappedAttr<"AdditionalAttr">,
+              BoolArgument<"IsReplacedByActive">];
+  let SemaHandler = 0;
+  let Documentation = [InternalOnly];
+}
+
+def SwiftVersionedRemoval : Attr {
+  // No spellings: only ever created implicitly from API notes. Records the
+  // kind of attribute to remove ('RawKind', see getAttrKindToRemove()).
+  let Spellings = [];
+  let Args = [VersionArgument<"Version">, UnsignedArgument<"RawKind">,
+              BoolArgument<"IsReplacedByActive">];
+  let SemaHandler = 0;
+  let Documentation = [InternalOnly];
+  let AdditionalMembers = [{
+    attr::Kind getAttrKindToRemove() const {
+      return static_cast<attr::Kind>(getRawKind());
+    }
+  }];
+}
+
 def NoDeref : TypeAttr {
   let Spellings = [Clang<"noderef">];
   let Documentation = [NoDerefDocs];
diff --git a/clang/lib/Sema/SemaDeclAttr.cpp b/clang/lib/Sema/SemaDeclAttr.cpp
index 7e6881049d8d95..a482919356e1bc 100644
--- a/clang/lib/Sema/SemaDeclAttr.cpp
+++ b/clang/lib/Sema/SemaDeclAttr.cpp
@@ -6236,29 +6236,35 @@ static void handleObjCRequiresSuperAttr(Sema &S, Decl *D,
   D->addAttr(::new (S.Context) ObjCRequiresSuperAttr(S.Context, Attrs));
 }
 
-static void handleNSErrorDomain(Sema &S, Decl *D, const ParsedAttr &AL) {
-  auto *E = AL.getArgAsExpr(0);
-  auto Loc = E ? E->getBeginLoc() : AL.getLoc();
-
-  auto *DRE = dyn_cast<DeclRefExpr>(AL.getArgAsExpr(0));
-  if (!DRE) {
-    S.Diag(Loc, diag::err_nserrordomain_invalid_decl) << 0;
+static void handleNSErrorDomain(Sema &S, Decl *D, const ParsedAttr &AL) {
+  if (!isa<TagDecl>(D)) {
+    S.Diag(D->getBeginLoc(), diag::err_nserrordomain_invalid_decl) << 0;
     return;
   }

-  auto *VD = dyn_cast<VarDecl>(DRE->getDecl());
-  if (!VD) {
-    S.Diag(Loc, diag::err_nserrordomain_invalid_decl) << 1 << DRE->getDecl();
+  // The argument must be an identifier naming a variable at TU scope.
+  IdentifierLoc *IdentLoc = AL.isArgIdent(0) ? AL.getArgAsIdent(0) : nullptr;
+  if (!IdentLoc || !IdentLoc->Ident) {
+    // Not an identifier: point the diagnostic at the argument if possible.
+    SourceLocation Loc = AL.getLoc();
+    if (AL.isArgExpr(0) && AL.getArgAsExpr(0))
+      Loc = AL.getArgAsExpr(0)->getBeginLoc();
+
+    S.Diag(Loc, diag::err_nserrordomain_invalid_decl) << 0;
     return;
   }

-  if (!isNSStringType(VD->getType(), S.Context) &&
-      !isCFStringType(VD->getType(), S.Context)) {
-    S.Diag(Loc, diag::err_nserrordomain_wrong_type) << VD;
+  // Look the identifier up as an ordinary name at translation-unit scope.
+  LookupResult Result(S, DeclarationName(IdentLoc->Ident), SourceLocation(),
+                      Sema::LookupNameKind::LookupOrdinaryName);
+  if (!S.LookupName(Result, S.TUScope) || !Result.getAsSingle<VarDecl>()) {
+    S.Diag(IdentLoc->Loc, diag::err_nserrordomain_invalid_decl)
+        << 1 << IdentLoc->Ident;
     return;
   }

-  D->addAttr(::new (S.Context) NSErrorDomainAttr(S.Context, AL, VD));
+  D->addAttr(::new (S.Context)
+                 NSErrorDomainAttr(S.Context, AL, IdentLoc->Ident));
 }
 
 static void handleObjCBridgeAttr(Sema &S, Decl *D, const ParsedAttr &AL) {
diff --git a/clang/lib/Serialization/ASTReaderDecl.cpp b/clang/lib/Serialization/ASTReaderDecl.cpp
index 547eb77930b4ee..991d703978495c 100644
--- a/clang/lib/Serialization/ASTReaderDecl.cpp
+++ b/clang/lib/Serialization/ASTReaderDecl.cpp
@@ -3095,6 +3095,8 @@ class AttrReader {
 
   Expr *readExpr() { return Reader.readExpr(); }
 
+  Attr *readAttr() { return Reader.readAttr(); } // Reads a nested attribute.
+
   std::string readString() {
     return Reader.readString();
   }
diff --git a/clang/test/Sema/ns_error_enum.m b/clang/test/Sema/ns_error_enum.m
index 0b8f8f7cced003..29027f1a54f03f 100644
--- a/clang/test/Sema/ns_error_enum.m
+++ b/clang/test/Sema/ns_error_enum.m
@@ -53,7 +53,6 @@ typedef NS_ERROR_ENUM(unsigned char, MyCFTypedefErrorEnum, MyCFTypedefErrorDomai
 
 extern char *const WrongErrorDomainType;
 enum __attribute__((ns_error_domain(WrongErrorDomainType))) MyWrongErrorDomainType { MyWrongErrorDomain };
-// expected-error at -1{{domain argument 'WrongErrorDomainType' does not point to an NSString or CFString constant}}
 
 struct __attribute__((ns_error_domain(MyErrorDomain))) MyStructWithErrorDomain {};
 // expected-error at -1{{'ns_error_domain' attribute only applies to enums}}
@@ -68,7 +67,7 @@ typedef NS_ERROR_ENUM(unsigned char, MyCFTypedefErrorEnum, MyCFTypedefErrorDomai
 // expected-error at -1{{'ns_error_domain' attribute takes one argument}}
 
 typedef NS_ERROR_ENUM(unsigned char, MyErrorEnumInvalid, InvalidDomain) {
-	// expected-error at -1{{use of undeclared identifier 'InvalidDomain'}}
+	// expected-error at -1{{domain argument 'InvalidDomain' does not refer to global constant}}
 	MyErrFirstInvalid,
 	MyErrSecondInvalid,
 };
diff --git a/clang/utils/TableGen/ClangAttrEmitter.cpp b/clang/utils/TableGen/ClangAttrEmitter.cpp
index c35490cf2e5aeb..acb5b29217b4f2 100644
--- a/clang/utils/TableGen/ClangAttrEmitter.cpp
+++ b/clang/utils/TableGen/ClangAttrEmitter.cpp
@@ -1414,7 +1414,29 @@ namespace {
     }
   };
 
-} // end anonymous namespace
+  class WrappedAttr : public SimpleArgument {
+  public:
+    WrappedAttr(const Record &Arg, StringRef Attr)
+        : SimpleArgument(Arg, Attr, "Attr *") {}
+    // The wrapped attribute is written via AddAttr() and read via readAttr().
+    void writePCHReadDecls(raw_ostream &OS) const override {
+      OS << "    Attr *" << getLowerName() << " = Record.readAttr();\n";
+    }
+
+    void writePCHWrite(raw_ostream &OS) const override {
+      OS << "    AddAttr(SA->get" << getUpperName() << "());\n";
+    }
+
+    void writeDump(raw_ostream &OS) const override {} // Dumped as a child.
+
+    void writeDumpChildren(raw_ostream &OS) const override {
+      OS << "    Visit(SA->get" << getUpperName() << "());\n";
+    }
+
+    void writeHasChildren(raw_ostream &OS) const override { OS << "true"; }
+  };
+
+} // end anonymous namespace
 
 static std::unique_ptr<Argument>
 createArgument(const Record &Arg, StringRef Attr,
@@ -1470,6 +1492,8 @@ createArgument(const Record &Arg, StringRef Attr,
     Ptr = std::make_unique<VariadicIdentifierArgument>(Arg, Attr);
   else if (ArgName == "VersionArgument")
     Ptr = std::make_unique<VersionArgument>(Arg, Attr);
+  else if (ArgName == "WrappedAttr")
+    Ptr = std::make_unique<WrappedAttr>(Arg, Attr);
   else if (ArgName == "OMPTraitInfoArgument")
     Ptr = std::make_unique<SimpleArgument>(Arg, Attr, "OMPTraitInfo *");
   else if (ArgName == "VariadicOMPInteropInfoArgument")

>From 3a5c9b72627c8a741f057f9bfdca88d272d58759 Mon Sep 17 00:00:00 2001
From: Egor Zhdan <e_zhdan at apple.com>
Date: Tue, 21 Nov 2023 16:34:52 +0000
Subject: [PATCH 2/3] [APINotes] Upstream Parser support for API Notes

---
 .../clang/Basic/DiagnosticParseKinds.td       |  3 +
 clang/include/clang/Lex/Lexer.h               |  2 +-
 clang/include/clang/Parse/Parser.h            | 12 ++++
 clang/include/clang/Sema/Sema.h               |  4 ++
 clang/lib/Parse/ParseDecl.cpp                 | 65 +++++++++++++++++++
 clang/lib/Parse/Parser.cpp                    |  5 ++
 6 files changed, 90 insertions(+), 1 deletion(-)

diff --git a/clang/include/clang/Basic/DiagnosticParseKinds.td b/clang/include/clang/Basic/DiagnosticParseKinds.td
index 088f8b74983c86..5c9de6f590becc 100644
--- a/clang/include/clang/Basic/DiagnosticParseKinds.td
+++ b/clang/include/clang/Basic/DiagnosticParseKinds.td
@@ -1609,6 +1609,9 @@ def err_pragma_invalid_keyword : Error<
 def err_pragma_pipeline_invalid_keyword : Error<
     "invalid argument; expected 'disable'">;
 
+// API notes: emitted by Parser::ParseTypeFromString.
+def err_type_unparsed : Error<"unparsed tokens following type">;
+
 // Pragma unroll support.
 def warn_pragma_unroll_cuda_value_in_parens : Warning<
   "argument to '#pragma unroll' should not be in parentheses in CUDA C/C++">,
diff --git a/clang/include/clang/Lex/Lexer.h b/clang/include/clang/Lex/Lexer.h
index 899e665e745465..b6ecc7e5ded9e2 100644
--- a/clang/include/clang/Lex/Lexer.h
+++ b/clang/include/clang/Lex/Lexer.h
@@ -198,11 +198,11 @@ class Lexer : public PreprocessorLexer {
   /// from.  Currently this is only used by _Pragma handling.
   SourceLocation getFileLoc() const { return FileLoc; }
 
-private:
   /// Lex - Return the next token in the file.  If this is the end of file, it
   /// return the tok::eof token.  This implicitly involves the preprocessor.
   bool Lex(Token &Result);
 
+private:
   /// Called when the preprocessor is in 'dependency scanning lexing mode'.
   bool LexDependencyDirectiveToken(Token &Result);
 
diff --git a/clang/include/clang/Parse/Parser.h b/clang/include/clang/Parse/Parser.h
index e50a4d05b45991..08f44af3ad0d58 100644
--- a/clang/include/clang/Parse/Parser.h
+++ b/clang/include/clang/Parse/Parser.h
@@ -3655,6 +3655,18 @@ class Parser : public CodeCompletionHandler {
   ParseConceptDefinition(const ParsedTemplateInfo &TemplateInfo,
                          SourceLocation &DeclEnd);
 
+  /// Parse the given string as a type, diagnosing any trailing tokens.
+  ///
+  /// This is a dangerous utility function currently employed only by API notes.
+  /// It is not a general entry-point for safely parsing types from strings.
+  ///
+  /// \param TypeStr The string to be parsed as a type.
+  /// \param Context The name of the context in which this string is being
+  /// parsed; it names the buffer created for the string in diagnostics.
+  /// \param IncludeLoc Where the parse was triggered; used as the include loc.
+  TypeResult ParseTypeFromString(StringRef TypeStr, StringRef Context,
+                                 SourceLocation IncludeLoc);
+
   //===--------------------------------------------------------------------===//
   // Modules
   DeclGroupPtrTy ParseModuleDecl(Sema::ModuleImportState &ImportState);
diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h
index 4ef1fe542ea54f..fd220a840f8145 100644
--- a/clang/include/clang/Sema/Sema.h
+++ b/clang/include/clang/Sema/Sema.h
@@ -955,6 +955,10 @@ class Sema final {
     OpaqueParser = P;
   }
 
+  /// Parser-installed callback that parses a type expressed as a string.
+  std::function<TypeResult(StringRef, StringRef, SourceLocation)>
+      ParseTypeFromStringCallback;
+
   class DelayedDiagnostics;
 
   class DelayedDiagnosticsState {
diff --git a/clang/lib/Parse/ParseDecl.cpp b/clang/lib/Parse/ParseDecl.cpp
index 8d856cc2cf8313..cb71635c5bd80c 100644
--- a/clang/lib/Parse/ParseDecl.cpp
+++ b/clang/lib/Parse/ParseDecl.cpp
@@ -8051,6 +8051,71 @@ bool Parser::TryAltiVecTokenOutOfLine(DeclSpec &DS, SourceLocation Loc,
   return false;
 }
 
+TypeResult Parser::ParseTypeFromString(StringRef TypeStr, StringRef Context,
+                                       SourceLocation IncludeLoc) {
+  // Lex the whole string, up to and including the end-of-directive token.
+  SmallVector<Token, 4> Tokens;
+  {
+    // Create a new buffer from which we will parse the type.
+    auto &SourceMgr = PP.getSourceManager();
+    FileID FID = SourceMgr.createFileID(
+        llvm::MemoryBuffer::getMemBufferCopy(TypeStr, Context), SrcMgr::C_User,
+        0, 0, IncludeLoc);
+
+    // Form a new lexer that references the buffer.
+    Lexer L(FID, SourceMgr.getBufferOrFake(FID), PP);
+    L.setParsingPreprocessorDirective(true);
+
+    // Lex the tokens; 'LexedTok' avoids shadowing the member Parser::Tok.
+    Token LexedTok;
+    do {
+      L.Lex(LexedTok);
+      Tokens.push_back(LexedTok);
+    } while (LexedTok.isNot(tok::eod));
+  }
+
+  // Replace the "eod" token with an "eof" token, tagged with TypeStr.data(),
+  // that marks the end of the string; it borrows Parser::Tok's location.
+  Token &EndToken = Tokens.back();
+  EndToken.startToken();
+  EndToken.setKind(tok::eof);
+  EndToken.setLocation(Tok.getLocation());
+  EndToken.setEofData(TypeStr.data());
+
+  // Re-append the parser's current token so parsing resumes with it.
+  Tokens.push_back(Tok);
+
+  // Enter the tokens into the token stream.
+  PP.EnterTokenStream(Tokens, /*DisableMacroExpansion=*/false,
+                      /*IsReinject=*/false);
+
+  // Consume the current token so that we'll start parsing the tokens we
+  // added to the stream.
+  ConsumeAnyToken();
+
+  // Enter a new scope.
+  ParseScope LocalScope(this, 0);
+
+  // Parse the type.
+  TypeResult Result = ParseTypeName(nullptr);
+
+  // Check if we parsed the whole thing.
+  if (Result.isUsable() &&
+      (Tok.isNot(tok::eof) || Tok.getEofData() != TypeStr.data())) {
+    Diag(Tok.getLocation(), diag::err_type_unparsed);
+  }
+
+  // There could be leftover tokens (e.g. because of an error).
+  // Skip ahead to the "eof" token that terminates the string.
+  while (Tok.isNot(tok::eof))
+    ConsumeAnyToken();
+
+  // Consume the end token.
+  if (Tok.is(tok::eof) && Tok.getEofData() == TypeStr.data())
+    ConsumeAnyToken();
+  return Result;
+}
+
 void Parser::DiagnoseBitIntUse(const Token &Tok) {
   // If the token is for _ExtInt, diagnose it as being deprecated. Otherwise,
   // the token is about _BitInt and gets (potentially) diagnosed as use of an
diff --git a/clang/lib/Parse/Parser.cpp b/clang/lib/Parse/Parser.cpp
index b703c2d9b8e04d..0b092181bca7bb 100644
--- a/clang/lib/Parse/Parser.cpp
+++ b/clang/lib/Parse/Parser.cpp
@@ -70,6 +70,11 @@ Parser::Parser(Preprocessor &pp, Sema &actions, bool skipFunctionBodies)
   PP.addCommentHandler(CommentSemaHandler.get());
 
   PP.setCodeCompletionHandler(*this);
+
+  Actions.ParseTypeFromStringCallback =
+      [this](StringRef TypeStr, StringRef Context, SourceLocation IncludeLoc) {
+        return this->ParseTypeFromString(TypeStr, Context, IncludeLoc);
+      };
 }
 
 DiagnosticBuilder Parser::Diag(SourceLocation Loc, unsigned DiagID) {

>From 19634d4542ed52e4f56c576a96f13d8fa416c813 Mon Sep 17 00:00:00 2001
From: Egor Zhdan <e_zhdan at apple.com>
Date: Tue, 21 Nov 2023 16:42:29 +0000
Subject: [PATCH 3/3] [APINotes] Upstream Sema support for API Notes

---
 clang/include/clang/Sema/Sema.h |  26 +++++
 clang/lib/Sema/SemaDecl.cpp     |  31 ++++++
 clang/lib/Sema/SemaType.cpp     | 180 +++++++++++++++++++-------------
 3 files changed, 165 insertions(+), 72 deletions(-)

diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h
index fd220a840f8145..b5eabed05c3b9e 100644
--- a/clang/include/clang/Sema/Sema.h
+++ b/clang/include/clang/Sema/Sema.h
@@ -3021,6 +3021,9 @@ class Sema final {
   ParmVarDecl *BuildParmVarDeclForTypedef(DeclContext *DC,
                                           SourceLocation Loc,
                                           QualType T);
+  QualType AdjustParameterTypeForObjCAutoRefCount(QualType T,
+                                                  SourceLocation NameLoc,
+                                                  TypeSourceInfo *TSInfo);
   ParmVarDecl *CheckParameter(DeclContext *DC, SourceLocation StartLoc,
                               SourceLocation NameLoc, IdentifierInfo *Name,
                               QualType T, TypeSourceInfo *TSInfo,
@@ -4820,6 +4823,29 @@ class Sema final {
   /// Valid types should not have multiple attributes with different CCs.
   const AttributedType *getCallingConvAttributedType(QualType T) const;
 
+  /// Check whether a nullability type specifier can be added to the given
+  /// type through some means not written in source (e.g. API notes).
+  ///
+  /// \param Type The type to which the nullability specifier will be
+  /// added. On success, this type will be updated appropriately.
+  ///
+  /// \param Nullability The nullability specifier to add.
+  ///
+  /// \param DiagLoc The location to use for diagnostics.
+  ///
+  /// \param AllowArrayTypes Whether to accept nullability specifiers on an
+  /// array type (e.g., because it will decay to a pointer).
+  ///
+  /// \param OverrideExisting Whether to override an existing, locally-specified
+  /// nullability specifier rather than complaining about the conflict.
+  ///
+  /// \returns true if nullability cannot be applied, false otherwise.
+  bool CheckImplicitNullabilityTypeSpecifier(QualType &Type,
+                                             NullabilityKind Nullability,
+                                             SourceLocation DiagLoc,
+                                             bool AllowArrayTypes,
+                                             bool OverrideExisting);
+
   /// Process the attributes before creating an attributed statement. Returns
   /// the semantic attributes that have been processed.
   void ProcessStmtAttributes(Stmt *Stmt, const ParsedAttributes &InAttrs,
diff --git a/clang/lib/Sema/SemaDecl.cpp b/clang/lib/Sema/SemaDecl.cpp
index 4e7049571eeb7a..e7d56bfde7eb2c 100644
--- a/clang/lib/Sema/SemaDecl.cpp
+++ b/clang/lib/Sema/SemaDecl.cpp
@@ -15171,6 +15171,37 @@ void Sema::DiagnoseSizeOfParametersAndReturnValue(
   }
 }
 
+QualType Sema::AdjustParameterTypeForObjCAutoRefCount(QualType T,
+                                                      SourceLocation NameLoc,
+                                                      TypeSourceInfo *TSInfo) {
+  // In ARC, infer a lifetime qualifier for appropriate parameter types.
+  // Types that already have a lifetime, or cannot carry one, are unchanged.
+  if (!getLangOpts().ObjCAutoRefCount ||
+      T.getObjCLifetime() != Qualifiers::OCL_None || !T->isObjCLifetimeType())
+    return T;
+
+  Qualifiers::ObjCLifetime Lifetime;
+
+  // Special cases for arrays:
+  //   - if it's const, use __unsafe_unretained
+  //   - otherwise, it's an error (highlighting TSInfo's range, if any)
+  if (T->isArrayType()) {
+    if (!T.isConstQualified()) {
+      if (DelayedDiagnostics.shouldDelayDiagnostics())
+        DelayedDiagnostics.add(sema::DelayedDiagnostic::makeForbiddenType(
+            NameLoc, diag::err_arc_array_param_no_ownership, T, false));
+      else
+        Diag(NameLoc, diag::err_arc_array_param_no_ownership)
+            << (TSInfo ? TSInfo->getTypeLoc().getSourceRange() : SourceRange());
+    }
+    Lifetime = Qualifiers::OCL_ExplicitNone;
+  } else {
+    Lifetime = T->getObjCARCImplicitLifetime();
+  }
+  // Attach the inferred lifetime qualifier to the parameter type.
+  return Context.getLifetimeQualifiedType(T, Lifetime);
+}
+
 ParmVarDecl *Sema::CheckParameter(DeclContext *DC, SourceLocation StartLoc,
                                   SourceLocation NameLoc, IdentifierInfo *Name,
                                   QualType T, TypeSourceInfo *TSInfo,
diff --git a/clang/lib/Sema/SemaType.cpp b/clang/lib/Sema/SemaType.cpp
index 78702b41ab8200..ba96623fbc1e2a 100644
--- a/clang/lib/Sema/SemaType.cpp
+++ b/clang/lib/Sema/SemaType.cpp
@@ -7521,6 +7521,25 @@ static bool HandleWebAssemblyFuncrefAttr(TypeProcessingState &State,
   return false;
 }
 
+/// Rebuild an attributed type, dropping the outermost nullability attribute
+/// in its chain of attributed types; non-attributed types are returned as-is.
+static QualType rebuildAttributedTypeWithoutNullability(ASTContext &Ctx,
+                                                        QualType Type) {
+  auto *Attributed = dyn_cast<AttributedType>(Type.getTypePtr());
+  if (!Attributed)
+    return Type;
+  // Skip the nullability attribute; we're done.
+  if (Attributed->getImmediateNullability())
+    return Attributed->getModifiedType();
+
+  // Otherwise rebuild around the modified type, which must contain one.
+  QualType Modified = rebuildAttributedTypeWithoutNullability(
+      Ctx, Attributed->getModifiedType());
+  assert(Modified.getTypePtr() != Attributed->getModifiedType().getTypePtr());
+  return Ctx.getAttributedType(Attributed->getAttrKind(), Modified,
+                               Attributed->getEquivalentType());
+}
+
 /// Map a nullability attribute kind to a nullability kind.
 static NullabilityKind mapNullabilityAttrKind(ParsedAttr::Kind kind) {
   switch (kind) {
@@ -7541,74 +7560,65 @@ static NullabilityKind mapNullabilityAttrKind(ParsedAttr::Kind kind) {
   }
 }
 
-/// Applies a nullability type specifier to the given type, if possible.
-///
-/// \param state The type processing state.
-///
-/// \param type The type to which the nullability specifier will be
-/// added. On success, this type will be updated appropriately.
-///
-/// \param attr The attribute as written on the type.
-///
-/// \param allowOnArrayType Whether to accept nullability specifiers on an
-/// array type (e.g., because it will decay to a pointer).
-///
-/// \returns true if a problem has been diagnosed, false on success.
-static bool checkNullabilityTypeSpecifier(TypeProcessingState &state,
-                                          QualType &type,
-                                          ParsedAttr &attr,
-                                          bool allowOnArrayType) {
-  Sema &S = state.getSema();
-
-  NullabilityKind nullability = mapNullabilityAttrKind(attr.getKind());
-  SourceLocation nullabilityLoc = attr.getLoc();
-  bool isContextSensitive = attr.isContextSensitiveKeywordAttribute();
-
-  recordNullabilitySeen(S, nullabilityLoc);
+static bool CheckNullabilityTypeSpecifier(
+    Sema &S, TypeProcessingState *State, ParsedAttr *PAttr, QualType &QT,
+    NullabilityKind Nullability, SourceLocation NullabilityLoc,
+    bool IsContextSensitive, bool AllowOnArrayType, bool OverrideExisting) {
+  bool Implicit = (State == nullptr);
+  if (!Implicit)
+    recordNullabilitySeen(S, NullabilityLoc);
 
   // Check for existing nullability attributes on the type.
-  QualType desugared = type;
-  while (auto attributed = dyn_cast<AttributedType>(desugared.getTypePtr())) {
+  QualType Desugared = QT;
+  while (auto *Attributed = dyn_cast<AttributedType>(Desugared.getTypePtr())) {
     // Check whether there is already a null
-    if (auto existingNullability = attributed->getImmediateNullability()) {
+    if (auto ExistingNullability = Attributed->getImmediateNullability()) {
       // Duplicated nullability.
-      if (nullability == *existingNullability) {
-        S.Diag(nullabilityLoc, diag::warn_nullability_duplicate)
-          << DiagNullabilityKind(nullability, isContextSensitive)
-          << FixItHint::CreateRemoval(nullabilityLoc);
+      if (Nullability == *ExistingNullability) {
+        if (Implicit)
+          break;
+
+        S.Diag(NullabilityLoc, diag::warn_nullability_duplicate)
+            << DiagNullabilityKind(Nullability, IsContextSensitive)
+            << FixItHint::CreateRemoval(NullabilityLoc);
 
         break;
       }
 
-      // Conflicting nullability.
-      S.Diag(nullabilityLoc, diag::err_nullability_conflicting)
-        << DiagNullabilityKind(nullability, isContextSensitive)
-        << DiagNullabilityKind(*existingNullability, false);
-      return true;
+      if (!OverrideExisting) {
+        // Conflicting nullability.
+        S.Diag(NullabilityLoc, diag::err_nullability_conflicting)
+            << DiagNullabilityKind(Nullability, IsContextSensitive)
+            << DiagNullabilityKind(*ExistingNullability, false);
+        return true;
+      }
+
+      // Rebuild the attributed type, dropping the existing nullability.
+      QT = rebuildAttributedTypeWithoutNullability(S.Context, QT);
     }
 
-    desugared = attributed->getModifiedType();
+    Desugared = Attributed->getModifiedType();
   }
 
   // If there is already a different nullability specifier, complain.
   // This (unlike the code above) looks through typedefs that might
   // have nullability specifiers on them, which means we cannot
   // provide a useful Fix-It.
-  if (auto existingNullability = desugared->getNullability()) {
-    if (nullability != *existingNullability) {
-      S.Diag(nullabilityLoc, diag::err_nullability_conflicting)
-        << DiagNullabilityKind(nullability, isContextSensitive)
-        << DiagNullabilityKind(*existingNullability, false);
+  if (auto ExistingNullability = Desugared->getNullability()) {
+    if (Nullability != *ExistingNullability && !Implicit) {
+      S.Diag(NullabilityLoc, diag::err_nullability_conflicting)
+          << DiagNullabilityKind(Nullability, IsContextSensitive)
+          << DiagNullabilityKind(*ExistingNullability, false);
 
       // Try to find the typedef with the existing nullability specifier.
-      if (auto typedefType = desugared->getAs<TypedefType>()) {
-        TypedefNameDecl *typedefDecl = typedefType->getDecl();
+      if (auto TT = Desugared->getAs<TypedefType>()) {
+        TypedefNameDecl *typedefDecl = TT->getDecl();
         QualType underlyingType = typedefDecl->getUnderlyingType();
-        if (auto typedefNullability
-              = AttributedType::stripOuterNullability(underlyingType)) {
-          if (*typedefNullability == *existingNullability) {
+        if (auto typedefNullability =
+                AttributedType::stripOuterNullability(underlyingType)) {
+          if (*typedefNullability == *ExistingNullability) {
             S.Diag(typedefDecl->getLocation(), diag::note_nullability_here)
-              << DiagNullabilityKind(*existingNullability, false);
+                << DiagNullabilityKind(*ExistingNullability, false);
           }
         }
       }
@@ -7618,44 +7628,73 @@ static bool checkNullabilityTypeSpecifier(TypeProcessingState &state,
   }
 
   // If this definitely isn't a pointer type, reject the specifier.
-  if (!desugared->canHaveNullability() &&
-      !(allowOnArrayType && desugared->isArrayType())) {
-    S.Diag(nullabilityLoc, diag::err_nullability_nonpointer)
-      << DiagNullabilityKind(nullability, isContextSensitive) << type;
+  if (!Desugared->canHaveNullability() &&
+      !(AllowOnArrayType && Desugared->isArrayType())) {
+    if (!Implicit)
+      S.Diag(NullabilityLoc, diag::err_nullability_nonpointer)
+          << DiagNullabilityKind(Nullability, IsContextSensitive) << QT;
+
     return true;
   }
 
   // For the context-sensitive keywords/Objective-C property
   // attributes, require that the type be a single-level pointer.
-  if (isContextSensitive) {
+  if (IsContextSensitive) {
     // Make sure that the pointee isn't itself a pointer type.
     const Type *pointeeType = nullptr;
-    if (desugared->isArrayType())
-      pointeeType = desugared->getArrayElementTypeNoTypeQual();
-    else if (desugared->isAnyPointerType())
-      pointeeType = desugared->getPointeeType().getTypePtr();
+    if (Desugared->isArrayType())
+      pointeeType = Desugared->getArrayElementTypeNoTypeQual();
+    else if (Desugared->isAnyPointerType())
+      pointeeType = Desugared->getPointeeType().getTypePtr();
 
     if (pointeeType && (pointeeType->isAnyPointerType() ||
                         pointeeType->isObjCObjectPointerType() ||
                         pointeeType->isMemberPointerType())) {
-      S.Diag(nullabilityLoc, diag::err_nullability_cs_multilevel)
-        << DiagNullabilityKind(nullability, true)
-        << type;
-      S.Diag(nullabilityLoc, diag::note_nullability_type_specifier)
-        << DiagNullabilityKind(nullability, false)
-        << type
-        << FixItHint::CreateReplacement(nullabilityLoc,
-                                        getNullabilitySpelling(nullability));
+      S.Diag(NullabilityLoc, diag::err_nullability_cs_multilevel)
+          << DiagNullabilityKind(Nullability, true) << QT;
+      S.Diag(NullabilityLoc, diag::note_nullability_type_specifier)
+          << DiagNullabilityKind(Nullability, false) << QT
+          << FixItHint::CreateReplacement(NullabilityLoc,
+                                          getNullabilitySpelling(Nullability));
       return true;
     }
   }
 
   // Form the attributed type.
-  type = state.getAttributedType(
-      createNullabilityAttr(S.Context, attr, nullability), type, type);
+  if (State) {
+    assert(PAttr);
+    Attr *A = createNullabilityAttr(S.Context, *PAttr, Nullability);
+    QT = State->getAttributedType(A, QT, QT);
+  } else {
+    attr::Kind attrKind = AttributedType::getNullabilityAttrKind(Nullability);
+    QT = S.Context.getAttributedType(attrKind, QT, QT);
+  }
   return false;
 }
 
+static bool CheckNullabilityTypeSpecifier(TypeProcessingState &State,
+                                          QualType &QT, ParsedAttr &PAttr,
+                                          bool AllowOnArrayType) {
+  NullabilityKind Nullability = mapNullabilityAttrKind(PAttr.getKind());
+  SourceLocation NullabilityLoc = PAttr.getLoc();
+  bool IsContextSensitive = PAttr.isContextSensitiveKeywordAttribute();
+  // Nullability written in source never overrides an existing specifier.
+  return CheckNullabilityTypeSpecifier(State.getSema(), &State, &PAttr, QT,
+                                       Nullability, NullabilityLoc,
+                                       IsContextSensitive, AllowOnArrayType,
+                                       /*OverrideExisting=*/false);
+}
+
+bool Sema::CheckImplicitNullabilityTypeSpecifier(QualType &Type,
+                                                 NullabilityKind Nullability,
+                                                 SourceLocation DiagLoc,
+                                                 bool AllowArrayTypes,
+                                                 bool OverrideExisting) {
+  return CheckNullabilityTypeSpecifier(
+      *this, nullptr, nullptr, Type, Nullability, DiagLoc,
+      /*IsContextSensitive=*/false, AllowArrayTypes, OverrideExisting);
+}
+
 /// Check the application of the Objective-C '__kindof' qualifier to
 /// the given type.
 static bool checkObjCKindOfType(TypeProcessingState &state, QualType &type,
@@ -8950,11 +8989,8 @@ static void processTypeAttrs(TypeProcessingState &state, QualType &type,
         bool allowOnArrayType =
             state.getDeclarator().isPrototypeContext() &&
             !hasOuterPointerLikeChunk(state.getDeclarator(), endIndex);
-        if (checkNullabilityTypeSpecifier(
-              state,
-              type,
-              attr,
-              allowOnArrayType)) {
+        if (CheckNullabilityTypeSpecifier(state, type, attr,
+                                          allowOnArrayType)) {
           attr.setInvalid();
         }
 



More information about the cfe-commits mailing list