[clang] [clang-tools-extra] [clang] Extend diagnose_if to accept more detailed warning information (PR #70976)

Nikolas Klauser via cfe-commits cfe-commits at lists.llvm.org
Sat Jun 29 01:22:34 PDT 2024


https://github.com/philnik777 updated https://github.com/llvm/llvm-project/pull/70976

>From a91f499900d4cea4804833d004b6c4e54a7d8b15 Mon Sep 17 00:00:00 2001
From: Nikolas Klauser <nikolasklauser at berlin.de>
Date: Sun, 3 Sep 2023 17:26:28 -0700
Subject: [PATCH 1/2] [clang] Extend diagnose_if to accept more detailed
 warning information

---
 .../clang-tidy/ClangTidyDiagnosticConsumer.h  |   4 +
 clang-tools-extra/clangd/Diagnostics.cpp      |   6 +-
 clang-tools-extra/clangd/ParsedAST.cpp        |   2 +-
 clang/include/clang/Basic/Attr.td             |  13 +-
 clang/include/clang/Basic/Diagnostic.h        |   9 +-
 .../clang/Basic/DiagnosticCategories.h        |   1 +
 clang/include/clang/Basic/DiagnosticIDs.h     | 106 ++++++--
 .../clang/Basic/DiagnosticSemaKinds.td        |   6 +
 clang/lib/Basic/Diagnostic.cpp                |  15 +-
 clang/lib/Basic/DiagnosticIDs.cpp             | 232 ++++++++++--------
 clang/lib/Frontend/LogDiagnosticPrinter.cpp   |   4 +-
 .../Frontend/SerializedDiagnosticPrinter.cpp  |   3 +-
 clang/lib/Frontend/TextDiagnosticPrinter.cpp  |   8 +-
 clang/lib/Sema/Sema.cpp                       |   4 +-
 clang/lib/Sema/SemaCUDA.cpp                   |   4 +-
 clang/lib/Sema/SemaDeclAttr.cpp               |  22 +-
 clang/lib/Sema/SemaOverload.cpp               |  24 +-
 .../lib/Sema/SemaTemplateInstantiateDecl.cpp  |   3 +-
 clang/lib/Serialization/ASTReader.cpp         |   2 +-
 clang/lib/Serialization/ASTWriter.cpp         |   2 +-
 .../SemaCXX/diagnose_if-warning-group.cpp     |  35 +++
 clang/tools/diagtool/ListWarnings.cpp         |   7 +-
 clang/tools/diagtool/ShowEnabledWarnings.cpp  |   6 +-
 clang/tools/libclang/CXStoredDiagnostic.cpp   |  15 +-
 24 files changed, 353 insertions(+), 180 deletions(-)
 create mode 100644 clang/test/SemaCXX/diagnose_if-warning-group.cpp

diff --git a/clang-tools-extra/clang-tidy/ClangTidyDiagnosticConsumer.h b/clang-tools-extra/clang-tidy/ClangTidyDiagnosticConsumer.h
index 9280eb1e1f218..c7694ad05f03e 100644
--- a/clang-tools-extra/clang-tidy/ClangTidyDiagnosticConsumer.h
+++ b/clang-tools-extra/clang-tidy/ClangTidyDiagnosticConsumer.h
@@ -79,6 +79,10 @@ class ClangTidyContext {
     this->DiagEngine = DiagEngine;
   }
 
+  /// Returns the engine stored by the setter above. NOTE(review): presumably
+  /// null until that setter has run, so callers should check.
+  const DiagnosticsEngine *getDiagnosticsEngine() const { return DiagEngine; }
+
   ~ClangTidyContext();
 
   /// Report any errors detected using this method.
diff --git a/clang-tools-extra/clangd/Diagnostics.cpp b/clang-tools-extra/clangd/Diagnostics.cpp
index 704e61b1e4dd7..0962fd971342f 100644
--- a/clang-tools-extra/clangd/Diagnostics.cpp
+++ b/clang-tools-extra/clangd/Diagnostics.cpp
@@ -579,7 +579,9 @@ std::vector<Diag> StoreDiags::take(const clang::tidy::ClangTidyContext *Tidy) {
   for (auto &Diag : Output) {
     if (const char *ClangDiag = getDiagnosticCode(Diag.ID)) {
       // Warnings controlled by -Wfoo are better recognized by that name.
-      StringRef Warning = DiagnosticIDs::getWarningOptionForDiag(Diag.ID);
+      // Tidy and its engine may be null here, so don't dereference them.
+      // NOTE(review): assumes ClangDiag implies a builtin (table) diagnostic.
+      StringRef Warning = DiagnosticIDs{}.getWarningOptionForDiag(Diag.ID);
       if (!Warning.empty()) {
         Diag.Name = ("-W" + Warning).str();
       } else {
@@ -909,7 +911,7 @@ bool isBuiltinDiagnosticSuppressed(unsigned ID,
     if (Suppress.contains(normalizeSuppressedCode(CodePtr)))
       return true;
   }
-  StringRef Warning = DiagnosticIDs::getWarningOptionForDiag(ID);
+  StringRef Warning = DiagnosticIDs{}.getWarningOptionForDiag(ID);
   if (!Warning.empty() && Suppress.contains(Warning))
     return true;
   return false;
diff --git a/clang-tools-extra/clangd/ParsedAST.cpp b/clang-tools-extra/clangd/ParsedAST.cpp
index edd0f77b1031e..57d21fa271179 100644
--- a/clang-tools-extra/clangd/ParsedAST.cpp
+++ b/clang-tools-extra/clangd/ParsedAST.cpp
@@ -340,7 +340,7 @@ void applyWarningOptions(llvm::ArrayRef<std::string> ExtraArgs,
       if (Enable) {
         if (Diags.getDiagnosticLevel(ID, SourceLocation()) <
             DiagnosticsEngine::Warning) {
-          auto Group = DiagnosticIDs::getGroupForDiag(ID);
+          auto Group = Diags.getDiagnosticIDs()->getGroupForDiag(ID);
           if (!Group || !EnabledGroups(*Group))
             continue;
           Diags.setSeverity(ID, diag::Severity::Warning, SourceLocation());
diff --git a/clang/include/clang/Basic/Attr.td b/clang/include/clang/Basic/Attr.td
index 25231c5b82b90..e08b7720508d4 100644
--- a/clang/include/clang/Basic/Attr.td
+++ b/clang/include/clang/Basic/Attr.td
@@ -2959,18 +2959,15 @@ def DiagnoseIf : InheritableAttr {
   let Spellings = [GNU<"diagnose_if">];
   let Subjects = SubjectList<[Function, ObjCMethod, ObjCProperty]>;
   let Args = [ExprArgument<"Cond">, StringArgument<"Message">,
-              EnumArgument<"DiagnosticType",
-                           "DiagnosticType",
-                           ["error", "warning"],
-                           ["DT_Error", "DT_Warning"]>,
+              EnumArgument<"DefaultSeverity",
+                           "DefaultSeverity",
+                           ["error",    "warning"],
+                           ["DS_error", "DS_warning"]>,
+              StringArgument<"WarningGroup", /*optional*/ 1>,
               BoolArgument<"ArgDependent", 0, /*fake*/ 1>,
               DeclArgument<Named, "Parent", 0, /*fake*/ 1>];
   let InheritEvenIfAlreadyPresent = 1;
   let LateParsed = 1;
-  let AdditionalMembers = [{
-    bool isError() const { return diagnosticType == DT_Error; }
-    bool isWarning() const { return diagnosticType == DT_Warning; }
-  }];
   let TemplateDependent = 1;
   let Documentation = [DiagnoseIfDocs];
 }
diff --git a/clang/include/clang/Basic/Diagnostic.h b/clang/include/clang/Basic/Diagnostic.h
index 3df037b793b39..168e403a798a9 100644
--- a/clang/include/clang/Basic/Diagnostic.h
+++ b/clang/include/clang/Basic/Diagnostic.h
@@ -331,10 +331,12 @@ class DiagnosticsEngine : public RefCountedBase<DiagnosticsEngine> {
     // Map extensions to warnings or errors?
     diag::Severity ExtBehavior = diag::Severity::Ignored;
 
-    DiagState()
+    DiagnosticIDs &DiagIDs; ///< Supplies default mappings (getOrAddMapping).
+
+    explicit DiagState(DiagnosticIDs &DiagIDs)
         : IgnoreAllWarnings(false), EnableAllWarnings(false),
           WarningsAsErrors(false), ErrorsAsFatal(false),
-          SuppressSystemWarnings(false) {}
+          SuppressSystemWarnings(false), DiagIDs(DiagIDs) {}
 
     using iterator = llvm::DenseMap<unsigned, DiagnosticMapping>::iterator;
     using const_iterator =
@@ -865,7 +867,8 @@ class DiagnosticsEngine : public RefCountedBase<DiagnosticsEngine> {
   /// \param FormatString A fixed diagnostic format string that will be hashed
   /// and mapped to a unique DiagID.
   template <unsigned N>
-  unsigned getCustomDiagID(Level L, const char (&FormatString)[N]) {
+  [[deprecated("Use a CustomDiagDesc instead of a Level")]] unsigned
+  getCustomDiagID(Level L, const char (&FormatString)[N]) {
     return Diags->getCustomDiagID((DiagnosticIDs::Level)L,
                                   StringRef(FormatString, N - 1));
   }
diff --git a/clang/include/clang/Basic/DiagnosticCategories.h b/clang/include/clang/Basic/DiagnosticCategories.h
index 14be326f7515f..e9a1fe87202d8 100644
--- a/clang/include/clang/Basic/DiagnosticCategories.h
+++ b/clang/include/clang/Basic/DiagnosticCategories.h
@@ -26,6 +26,7 @@ namespace clang {
 #include "clang/Basic/DiagnosticGroups.inc"
 #undef CATEGORY
 #undef DIAG_ENTRY
+      NUM_GROUPS
     };
   }  // end namespace diag
 }  // end namespace clang
diff --git a/clang/include/clang/Basic/DiagnosticIDs.h b/clang/include/clang/Basic/DiagnosticIDs.h
index 06ef1c6904c31..632d8411affdf 100644
--- a/clang/include/clang/Basic/DiagnosticIDs.h
+++ b/clang/include/clang/Basic/DiagnosticIDs.h
@@ -14,6 +14,7 @@
 #ifndef LLVM_CLANG_BASIC_DIAGNOSTICIDS_H
 #define LLVM_CLANG_BASIC_DIAGNOSTICIDS_H
 
+#include "clang/Basic/DiagnosticCategories.h"
 #include "clang/Basic/LLVM.h"
 #include "llvm/ADT/IntrusiveRefCntPtr.h"
 #include "llvm/ADT/StringRef.h"
@@ -80,7 +81,7 @@ namespace clang {
     /// to either Ignore (nothing), Remark (emit a remark), Warning
     /// (emit a warning) or Error (emit as an error).  It allows clients to
     /// map ERRORs to Error or Fatal (stop emitting diagnostics after this one).
-    enum class Severity {
+    enum class Severity : uint8_t {
       // NOTE: 0 means "uncomputed".
       Ignored = 1, ///< Do not present this diagnostic, ignore it.
       Remark = 2,  ///< Present this diagnostic as a remark.
@@ -171,13 +172,61 @@ class DiagnosticMapping {
 class DiagnosticIDs : public RefCountedBase<DiagnosticIDs> {
 public:
   /// The level of the diagnostic, after it has been through mapping.
-  enum Level {
+  enum Level : uint8_t {
     Ignored, Note, Remark, Warning, Error, Fatal
   };
 
+  // Diagnostic classes.
+  enum Class {
+    CLASS_NOTE       = 0x01,
+    CLASS_REMARK     = 0x02,
+    CLASS_WARNING    = 0x03,
+    CLASS_EXTENSION  = 0x04,
+    CLASS_ERROR      = 0x05
+  };
+
+  /// A custom (non-builtin) diagnostic; DiagnosticIDs reads its fields.
+  struct CustomDiagDesc {
+    diag::Severity DefaultSeverity : 3;
+    unsigned Class : 3;
+    unsigned ShowInSystemHeader : 1;
+    unsigned ShowInSystemMacro : 1;
+    unsigned HasGroup : 1;
+    diag::Group Group;
+    std::string Description;
+
+    CustomDiagDesc(diag::Severity DefaultSeverity, std::string Description,
+                   unsigned Class = CLASS_WARNING,
+                   bool ShowInSystemHeader = false,
+                   bool ShowInSystemMacro = false,
+                   std::optional<diag::Group> Group = std::nullopt)
+        : DefaultSeverity(DefaultSeverity), Class(Class),
+          ShowInSystemHeader(ShowInSystemHeader),
+          ShowInSystemMacro(ShowInSystemMacro), HasGroup(Group.has_value()),
+          Group(Group.value_or(diag::Group{})),
+          Description(std::move(Description)) {}
+
+    // Every field, DefaultSeverity included, takes part in uniquing.
+    auto getAsTuple() const {
+      return std::tuple(DefaultSeverity, Class, ShowInSystemHeader,
+                        ShowInSystemMacro, HasGroup, Group,
+                        std::string_view{Description});
+    }
+
+    friend bool operator==(const CustomDiagDesc &L, const CustomDiagDesc &R) {
+      return L.getAsTuple() == R.getAsTuple();
+    }
+    friend bool operator<(const CustomDiagDesc &L, const CustomDiagDesc &R) {
+      return L.getAsTuple() < R.getAsTuple();
+    }
+  };
+
 private:
   /// Information for uniquing and looking up custom diags.
   std::unique_ptr<diag::CustomDiagInfo> CustomDiagInfo;
+  std::unique_ptr<diag::Severity[]> GroupSeverity =
+      std::make_unique<diag::Severity[]>(
+          static_cast<size_t>(diag::Group::NUM_GROUPS));
 
 public:
   DiagnosticIDs();
@@ -192,7 +241,27 @@ class DiagnosticIDs : public RefCountedBase<DiagnosticIDs> {
   // FIXME: Replace this function with a create-only facilty like
   // createCustomDiagIDFromFormatString() to enforce safe usage. At the time of
   // writing, nearly all callers of this function were invalid.
-  unsigned getCustomDiagID(Level L, StringRef FormatString);
+  unsigned getCustomDiagID(CustomDiagDesc Diag);
+
+  [[deprecated("Use a CustomDiagDesc instead of a Level")]] unsigned
+  getCustomDiagID(Level L, StringRef FormatString) {
+    std::string S(FormatString); // `true`: ShowInSystemHeader, legacy default.
+    switch (L) {
+    case Level::Ignored:
+      return getCustomDiagID({diag::Severity::Ignored, S, CLASS_WARNING, true});
+    case Level::Note:
+      return getCustomDiagID({diag::Severity::Fatal, S, CLASS_NOTE, true});
+    case Level::Remark:
+      return getCustomDiagID({diag::Severity::Remark, S, CLASS_REMARK, true});
+    case Level::Warning:
+      return getCustomDiagID({diag::Severity::Warning, S, CLASS_WARNING, true});
+    case Level::Error:
+      return getCustomDiagID({diag::Severity::Error, S, CLASS_ERROR, true});
+    case Level::Fatal:
+      break; // Handled after the switch so that every path returns.
+    }
+    return getCustomDiagID({diag::Severity::Fatal, S, CLASS_ERROR, true});
+  }
 
   //===--------------------------------------------------------------------===//
   // Diagnostic classification and reporting interfaces.
@@ -204,35 +273,34 @@ class DiagnosticIDs : public RefCountedBase<DiagnosticIDs> {
   /// Return true if the unmapped diagnostic levelof the specified
   /// diagnostic ID is a Warning or Extension.
   ///
-  /// This only works on builtin diagnostics, not custom ones, and is not
-  /// legal to call on NOTEs.
-  static bool isBuiltinWarningOrExtension(unsigned DiagID);
+  /// This is not legal to call on NOTEs.
+  bool isWarningOrExtension(unsigned DiagID) const;
 
   /// Return true if the specified diagnostic is mapped to errors by
   /// default.
-  static bool isDefaultMappingAsError(unsigned DiagID);
+  bool isDefaultMappingAsError(unsigned DiagID) const;
 
   /// Get the default mapping for this diagnostic.
-  static DiagnosticMapping getDefaultMapping(unsigned DiagID);
+  DiagnosticMapping getDefaultMapping(unsigned DiagID) const;
 
-  /// Determine whether the given built-in diagnostic ID is a Note.
-  static bool isBuiltinNote(unsigned DiagID);
+  /// Determine whether the given diagnostic ID is a Note.
+  bool isNote(unsigned DiagID) const;
 
-  /// Determine whether the given built-in diagnostic ID is for an
+  /// Determine whether the given diagnostic ID is for an
   /// extension of some sort.
-  static bool isBuiltinExtensionDiag(unsigned DiagID) {
+  bool isExtensionDiag(unsigned DiagID) const {
     bool ignored;
-    return isBuiltinExtensionDiag(DiagID, ignored);
+    return isExtensionDiag(DiagID, ignored);
   }
 
-  /// Determine whether the given built-in diagnostic ID is for an
+  /// Determine whether the given diagnostic ID is for an
   /// extension of some sort, and whether it is enabled by default.
   ///
   /// This also returns EnabledByDefault, which is set to indicate whether the
   /// diagnostic is ignored by default (in which case -pedantic enables it) or
   /// treated as a warning/error by default.
   ///
-  static bool isBuiltinExtensionDiag(unsigned DiagID, bool &EnabledByDefault);
+  bool isExtensionDiag(unsigned DiagID, bool &EnabledByDefault) const;
 
   /// Given a group ID, returns the flag that toggles the group.
   /// For example, for Group::DeprecatedDeclarations, returns
@@ -242,19 +310,21 @@ class DiagnosticIDs : public RefCountedBase<DiagnosticIDs> {
   /// Given a diagnostic group ID, return its documentation.
   static StringRef getWarningOptionDocumentation(diag::Group GroupID);
 
+  void setGroupSeverity(StringRef Group, diag::Severity);
+
   /// Given a group ID, returns the flag that toggles the group.
   /// For example, for "deprecated-declarations", returns
   /// Group::DeprecatedDeclarations.
   static std::optional<diag::Group> getGroupForWarningOption(StringRef);
 
   /// Return the lowest-level group that contains the specified diagnostic.
-  static std::optional<diag::Group> getGroupForDiag(unsigned DiagID);
+  std::optional<diag::Group> getGroupForDiag(unsigned DiagID) const;
 
   /// Return the lowest-level warning option that enables the specified
   /// diagnostic.
   ///
   /// If there is no -Wfoo flag that controls the diagnostic, this returns null.
-  static StringRef getWarningOptionForDiag(unsigned DiagID);
+  StringRef getWarningOptionForDiag(unsigned DiagID);
 
   /// Return the category number that a specified \p DiagID belongs to,
   /// or 0 if no category.
@@ -352,6 +422,8 @@ class DiagnosticIDs : public RefCountedBase<DiagnosticIDs> {
   getDiagnosticSeverity(unsigned DiagID, SourceLocation Loc,
                         const DiagnosticsEngine &Diag) const LLVM_READONLY;
 
+  Class getDiagClass(unsigned DiagID) const;
+
   /// Used to report a diagnostic that is finally fully formed.
   ///
   /// \returns \c true if the diagnostic was emitted, \c false if it was
diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td
index 453bd8a9a3404..f8d00ff802f1e 100644
--- a/clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -2865,9 +2865,15 @@ def ext_constexpr_function_never_constant_expr : ExtWarn<
   "constant expression">, InGroup<DiagGroup<"invalid-constexpr">>, DefaultError;
 def err_attr_cond_never_constant_expr : Error<
   "%0 attribute expression never produces a constant expression">;
+def err_diagnose_if_unknown_warning : Error<"unknown warning group">;
 def err_diagnose_if_invalid_diagnostic_type : Error<
   "invalid diagnostic type for 'diagnose_if'; use \"error\" or \"warning\" "
   "instead">;
+def err_diagnose_if_unknown_option : Error<"unknown diagnostic option">;
+def err_diagnose_if_expected_equals : Error<
+  "expected '=' after diagnostic option">;
+def err_diagnose_if_unexpected_value : Error<
+  "unexpected value; use 'true' or 'false'">;
 def err_constexpr_body_no_return : Error<
   "no return statement in %select{constexpr|consteval}0 function">;
 def err_constexpr_return_missing_expr : Error<
diff --git a/clang/lib/Basic/Diagnostic.cpp b/clang/lib/Basic/Diagnostic.cpp
index 0208ccc31bd7f..f67cf9c507fdc 100644
--- a/clang/lib/Basic/Diagnostic.cpp
+++ b/clang/lib/Basic/Diagnostic.cpp
@@ -138,7 +138,7 @@ void DiagnosticsEngine::Reset(bool soft /*=false*/) {
 
     // Create a DiagState and DiagStatePoint representing diagnostic changes
     // through command-line.
-    DiagStates.emplace_back();
+    DiagStates.emplace_back(*Diags);
     DiagStatesByLoc.appendFirst(&DiagStates.back());
   }
 }
@@ -167,7 +167,7 @@ DiagnosticsEngine::DiagState::getOrAddMapping(diag::kind Diag) {
 
   // Initialize the entry if we added it.
   if (Result.second)
-    Result.first->second = DiagnosticIDs::getDefaultMapping(Diag);
+    Result.first->second = DiagIDs.getDefaultMapping(Diag);
 
   return Result.first->second;
 }
@@ -309,7 +309,8 @@ void DiagnosticsEngine::DiagStateMap::dump(SourceManager &SrcMgr,
 
       for (auto &Mapping : *Transition.State) {
         StringRef Option =
-            DiagnosticIDs::getWarningOptionForDiag(Mapping.first);
+            SrcMgr.getDiagnostics().Diags->getWarningOptionForDiag(
+                Mapping.first);
         if (!DiagName.empty() && DiagName != Option)
           continue;
 
@@ -353,9 +354,7 @@ void DiagnosticsEngine::PushDiagStatePoint(DiagState *State,
 
 void DiagnosticsEngine::setSeverity(diag::kind Diag, diag::Severity Map,
                                     SourceLocation L) {
-  assert(Diag < diag::DIAG_UPPER_LIMIT &&
-         "Can only map builtin diagnostics");
-  assert((Diags->isBuiltinWarningOrExtension(Diag) ||
+  assert((Diags->isWarningOrExtension(Diag) ||
           (Map == diag::Severity::Fatal || Map == diag::Severity::Error)) &&
          "Cannot map errors into warnings!");
   assert((L.isInvalid() || SourceMgr) && "No SourceMgr for valid location");
@@ -406,6 +405,8 @@ bool DiagnosticsEngine::setSeverityForGroup(diag::Flavor Flavor,
   if (Diags->getDiagnosticsInGroup(Flavor, Group, GroupDiags))
     return true;
 
+  Diags->setGroupSeverity(Group, Map);
+
   // Set the mapping.
   for (diag::kind Diag : GroupDiags)
     setSeverity(Diag, Map, Loc);
@@ -491,7 +492,7 @@ void DiagnosticsEngine::setSeverityForAll(diag::Flavor Flavor,
 
   // Set the mapping.
   for (diag::kind Diag : AllDiags)
-    if (Diags->isBuiltinWarningOrExtension(Diag))
+    if (Diags->isWarningOrExtension(Diag))
       setSeverity(Diag, Map, Loc);
 }
 
diff --git a/clang/lib/Basic/DiagnosticIDs.cpp b/clang/lib/Basic/DiagnosticIDs.cpp
index e5667d57f8cff..6b2ba485533f9 100644
--- a/clang/lib/Basic/DiagnosticIDs.cpp
+++ b/clang/lib/Basic/DiagnosticIDs.cpp
@@ -99,13 +99,12 @@ const uint32_t StaticDiagInfoDescriptionOffsets[] = {
 #undef DIAG
 };
 
-// Diagnostic classes.
 enum {
-  CLASS_NOTE       = 0x01,
-  CLASS_REMARK     = 0x02,
-  CLASS_WARNING    = 0x03,
-  CLASS_EXTENSION  = 0x04,
-  CLASS_ERROR      = 0x05
+  CLASS_NOTE = DiagnosticIDs::CLASS_NOTE,
+  CLASS_REMARK = DiagnosticIDs::CLASS_REMARK,
+  CLASS_WARNING = DiagnosticIDs::CLASS_WARNING,
+  CLASS_EXTENSION = DiagnosticIDs::CLASS_EXTENSION,
+  CLASS_ERROR = DiagnosticIDs::CLASS_ERROR,
 };
 
 struct StaticDiagInfoRec {
@@ -256,11 +255,66 @@ CATEGORY(REFACTORING, ANALYSIS)
   return Found;
 }
 
-DiagnosticMapping DiagnosticIDs::getDefaultMapping(unsigned DiagID) {
+
+//===----------------------------------------------------------------------===//
+// Custom Diagnostic information
+//===----------------------------------------------------------------------===//
+
+namespace clang {
+  namespace diag {
+    using CustomDiagDesc = DiagnosticIDs::CustomDiagDesc;
+    class CustomDiagInfo {
+      std::vector<CustomDiagDesc> DiagInfo;
+      std::map<CustomDiagDesc, unsigned> DiagIDs;
+      std::map<diag::Group, std::vector<unsigned>> GroupToDiags;
+    public:
+
+      /// getDescription - Return the description of the specified custom
+      /// diagnostic.
+      const CustomDiagDesc& getDescription(unsigned DiagID) const {
+        assert(DiagID - DIAG_UPPER_LIMIT < DiagInfo.size() &&
+               "Invalid diagnostic ID");
+        return DiagInfo[DiagID-DIAG_UPPER_LIMIT];
+      }
+
+      unsigned getOrCreateDiagID(DiagnosticIDs::CustomDiagDesc D) {
+        // Check to see if it already exists.
+        std::map<CustomDiagDesc, unsigned>::iterator I = DiagIDs.lower_bound(D);
+        if (I != DiagIDs.end() && I->first == D)
+          return I->second;
+
+        // If not, assign a new ID.
+        unsigned ID = DiagInfo.size()+DIAG_UPPER_LIMIT;
+        DiagIDs.insert(std::make_pair(D, ID));
+        DiagInfo.push_back(D);
+        if (D.HasGroup)
+          GroupToDiags[D.Group].emplace_back(ID);
+        return ID;
+      }
+
+      ArrayRef<unsigned> getDiagsInGroup(diag::Group G) const {
+        if (auto Diags = GroupToDiags.find(G); Diags != GroupToDiags.end())
+          return Diags->second;
+        return {};
+      }
+    };
+
+  } // end diag namespace
+} // end clang namespace
+
+DiagnosticMapping DiagnosticIDs::getDefaultMapping(unsigned DiagID) const {
   DiagnosticMapping Info = DiagnosticMapping::Make(
       diag::Severity::Fatal, /*IsUser=*/false, /*IsPragma=*/false);
 
-  if (const StaticDiagInfoRec *StaticInfo = GetDiagInfo(DiagID)) {
+  if (DiagID >= diag::DIAG_UPPER_LIMIT) {
+    const auto &Diag = CustomDiagInfo->getDescription(DiagID);
+    // Ungrouped diags must not read group 0's slot; Severity() is "unset".
+    diag::Severity GroupSev = diag::Severity();
+    if (Diag.HasGroup)
+      GroupSev = GroupSeverity[static_cast<size_t>(Diag.Group)];
+    Info.setSeverity(GroupSev != diag::Severity() ? GroupSev
+                                                  : Diag.DefaultSeverity);
+  } else if (const StaticDiagInfoRec *StaticInfo = GetDiagInfo(DiagID)) {
     Info.setSeverity((diag::Severity)StaticInfo->DefaultSeverity);
 
     if (StaticInfo->WarnNoWerror) {
@@ -330,61 +384,6 @@ bool DiagnosticIDs::isDeferrable(unsigned DiagID) {
   return false;
 }
 
-/// getBuiltinDiagClass - Return the class field of the diagnostic.
-///
-static unsigned getBuiltinDiagClass(unsigned DiagID) {
-  if (const StaticDiagInfoRec *Info = GetDiagInfo(DiagID))
-    return Info->Class;
-  return ~0U;
-}
-
-//===----------------------------------------------------------------------===//
-// Custom Diagnostic information
-//===----------------------------------------------------------------------===//
-
-namespace clang {
-  namespace diag {
-    class CustomDiagInfo {
-      typedef std::pair<DiagnosticIDs::Level, std::string> DiagDesc;
-      std::vector<DiagDesc> DiagInfo;
-      std::map<DiagDesc, unsigned> DiagIDs;
-    public:
-
-      /// getDescription - Return the description of the specified custom
-      /// diagnostic.
-      StringRef getDescription(unsigned DiagID) const {
-        assert(DiagID - DIAG_UPPER_LIMIT < DiagInfo.size() &&
-               "Invalid diagnostic ID");
-        return DiagInfo[DiagID-DIAG_UPPER_LIMIT].second;
-      }
-
-      /// getLevel - Return the level of the specified custom diagnostic.
-      DiagnosticIDs::Level getLevel(unsigned DiagID) const {
-        assert(DiagID - DIAG_UPPER_LIMIT < DiagInfo.size() &&
-               "Invalid diagnostic ID");
-        return DiagInfo[DiagID-DIAG_UPPER_LIMIT].first;
-      }
-
-      unsigned getOrCreateDiagID(DiagnosticIDs::Level L, StringRef Message,
-                                 DiagnosticIDs &Diags) {
-        DiagDesc D(L, std::string(Message));
-        // Check to see if it already exists.
-        std::map<DiagDesc, unsigned>::iterator I = DiagIDs.lower_bound(D);
-        if (I != DiagIDs.end() && I->first == D)
-          return I->second;
-
-        // If not, assign a new ID.
-        unsigned ID = DiagInfo.size()+DIAG_UPPER_LIMIT;
-        DiagIDs.insert(std::make_pair(D, ID));
-        DiagInfo.push_back(D);
-        return ID;
-      }
-    };
-
-  } // end diag namespace
-} // end clang namespace
-
-
 //===----------------------------------------------------------------------===//
 // Common Diagnostic implementation
 //===----------------------------------------------------------------------===//
@@ -399,38 +398,34 @@ DiagnosticIDs::~DiagnosticIDs() {}
 ///
 /// \param FormatString A fixed diagnostic format string that will be hashed and
 /// mapped to a unique DiagID.
-unsigned DiagnosticIDs::getCustomDiagID(Level L, StringRef FormatString) {
+unsigned DiagnosticIDs::getCustomDiagID(
+    CustomDiagDesc Diag) {
   if (!CustomDiagInfo)
     CustomDiagInfo.reset(new diag::CustomDiagInfo());
-  return CustomDiagInfo->getOrCreateDiagID(L, FormatString, *this);
+  return CustomDiagInfo->getOrCreateDiagID(Diag);
 }
 
-
-/// isBuiltinWarningOrExtension - Return true if the unmapped diagnostic
-/// level of the specified diagnostic ID is a Warning or Extension.
-/// This only works on builtin diagnostics, not custom ones, and is not legal to
-/// call on NOTEs.
-bool DiagnosticIDs::isBuiltinWarningOrExtension(unsigned DiagID) {
-  return DiagID < diag::DIAG_UPPER_LIMIT &&
-         getBuiltinDiagClass(DiagID) != CLASS_ERROR;
+bool DiagnosticIDs::isWarningOrExtension(unsigned DiagID) const {
+  return DiagID < diag::DIAG_UPPER_LIMIT
+             ? getDiagClass(DiagID) != CLASS_ERROR
+             : CustomDiagInfo->getDescription(DiagID).Class != CLASS_ERROR;
 }
 
 /// Determine whether the given built-in diagnostic ID is a
 /// Note.
-bool DiagnosticIDs::isBuiltinNote(unsigned DiagID) {
+bool DiagnosticIDs::isNote(unsigned DiagID) const {
   return DiagID < diag::DIAG_UPPER_LIMIT &&
-    getBuiltinDiagClass(DiagID) == CLASS_NOTE;
+    getDiagClass(DiagID) == CLASS_NOTE;
 }
 
-/// isBuiltinExtensionDiag - Determine whether the given built-in diagnostic
+/// isExtensionDiag - Determine whether the given built-in diagnostic
 /// ID is for an extension of some sort.  This also returns EnabledByDefault,
 /// which is set to indicate whether the diagnostic is ignored by default (in
 /// which case -pedantic enables it) or treated as a warning/error by default.
 ///
-bool DiagnosticIDs::isBuiltinExtensionDiag(unsigned DiagID,
-                                        bool &EnabledByDefault) {
+bool DiagnosticIDs::isExtensionDiag(unsigned DiagID, bool &EnabledByDefault) const {
   if (DiagID >= diag::DIAG_UPPER_LIMIT ||
-      getBuiltinDiagClass(DiagID) != CLASS_EXTENSION)
+      getDiagClass(DiagID) != CLASS_EXTENSION)
     return false;
 
   EnabledByDefault =
@@ -438,7 +433,7 @@ bool DiagnosticIDs::isBuiltinExtensionDiag(unsigned DiagID,
   return true;
 }
 
-bool DiagnosticIDs::isDefaultMappingAsError(unsigned DiagID) {
+bool DiagnosticIDs::isDefaultMappingAsError(unsigned DiagID) const {
   if (DiagID >= diag::DIAG_UPPER_LIMIT)
     return false;
 
@@ -451,7 +446,7 @@ StringRef DiagnosticIDs::getDescription(unsigned DiagID) const {
   if (const StaticDiagInfoRec *Info = GetDiagInfo(DiagID))
     return Info->getDescription();
   assert(CustomDiagInfo && "Invalid CustomDiagInfo");
-  return CustomDiagInfo->getDescription(DiagID);
+  return CustomDiagInfo->getDescription(DiagID).Description;
 }
 
 static DiagnosticIDs::Level toLevel(diag::Severity SV) {
@@ -476,13 +471,7 @@ static DiagnosticIDs::Level toLevel(diag::Severity SV) {
 DiagnosticIDs::Level
 DiagnosticIDs::getDiagnosticLevel(unsigned DiagID, SourceLocation Loc,
                                   const DiagnosticsEngine &Diag) const {
-  // Handle custom diagnostics, which cannot be mapped.
-  if (DiagID >= diag::DIAG_UPPER_LIMIT) {
-    assert(CustomDiagInfo && "Invalid CustomDiagInfo");
-    return CustomDiagInfo->getLevel(DiagID);
-  }
-
-  unsigned DiagClass = getBuiltinDiagClass(DiagID);
+  unsigned DiagClass = getDiagClass(DiagID);
   if (DiagClass == CLASS_NOTE) return DiagnosticIDs::Note;
   return toLevel(getDiagnosticSeverity(DiagID, Loc, Diag));
 }
@@ -496,7 +485,8 @@ DiagnosticIDs::getDiagnosticLevel(unsigned DiagID, SourceLocation Loc,
 diag::Severity
 DiagnosticIDs::getDiagnosticSeverity(unsigned DiagID, SourceLocation Loc,
                                      const DiagnosticsEngine &Diag) const {
-  assert(getBuiltinDiagClass(DiagID) != CLASS_NOTE);
+  const bool IsCustomDiag = DiagID >= diag::DIAG_UPPER_LIMIT;
+  assert(getDiagClass(DiagID) != CLASS_NOTE);
 
   // Specific non-error diagnostics may be mapped to various levels from ignored
   // to error.  Errors can only be mapped to fatal.
@@ -504,7 +494,7 @@ DiagnosticIDs::getDiagnosticSeverity(unsigned DiagID, SourceLocation Loc,
 
   // Get the mapping information, or compute it lazily.
   DiagnosticsEngine::DiagState *State = Diag.GetDiagStateForLoc(Loc);
-  DiagnosticMapping &Mapping = State->getOrAddMapping((diag::kind)DiagID);
+  const DiagnosticMapping Mapping = State->getOrAddMapping((diag::kind)DiagID);
 
   // TODO: Can a null severity really get here?
   if (Mapping.getSeverity() != diag::Severity())
@@ -512,14 +502,15 @@ DiagnosticIDs::getDiagnosticSeverity(unsigned DiagID, SourceLocation Loc,
 
   // Upgrade ignored diagnostics if -Weverything is enabled.
   if (State->EnableAllWarnings && Result == diag::Severity::Ignored &&
-      !Mapping.isUser() && getBuiltinDiagClass(DiagID) != CLASS_REMARK)
+      !Mapping.isUser() &&
+      (IsCustomDiag || getDiagClass(DiagID) != CLASS_REMARK))
     Result = diag::Severity::Warning;
 
   // Ignore -pedantic diagnostics inside __extension__ blocks.
   // (The diagnostics controlled by -pedantic are the extension diagnostics
   // that are not enabled by default.)
   bool EnabledByDefault = false;
-  bool IsExtensionDiag = isBuiltinExtensionDiag(DiagID, EnabledByDefault);
+  bool IsExtensionDiag = isExtensionDiag(DiagID, EnabledByDefault);
   if (Diag.AllExtensionsSilenced && IsExtensionDiag && !EnabledByDefault)
     return diag::Severity::Ignored;
 
@@ -540,7 +531,7 @@ DiagnosticIDs::getDiagnosticSeverity(unsigned DiagID, SourceLocation Loc,
   if (State->IgnoreAllWarnings) {
     if (Result == diag::Severity::Warning ||
         (Result >= diag::Severity::Error &&
-         !isDefaultMappingAsError((diag::kind)DiagID)))
+         (IsCustomDiag || !isDefaultMappingAsError((diag::kind)DiagID))))
       return diag::Severity::Ignored;
   }
 
@@ -562,9 +553,10 @@ DiagnosticIDs::getDiagnosticSeverity(unsigned DiagID, SourceLocation Loc,
       Diag.CurDiagID != diag::fatal_too_many_errors && Diag.FatalsAsError)
     Result = diag::Severity::Error;
 
-  // Custom diagnostics always are emitted in system headers.
   bool ShowInSystemHeader =
-      !GetDiagInfo(DiagID) || GetDiagInfo(DiagID)->WarnShowInSystemHeader;
+      DiagID >= diag::DIAG_UPPER_LIMIT
+          ? CustomDiagInfo->getDescription(DiagID).ShowInSystemHeader
+          : !GetDiagInfo(DiagID) || GetDiagInfo(DiagID)->WarnShowInSystemHeader;
 
   // If we are in a system header, we ignore it. We look at the diagnostic class
   // because we also want to ignore extensions and warnings in -Werror and
@@ -584,6 +576,15 @@ DiagnosticIDs::getDiagnosticSeverity(unsigned DiagID, SourceLocation Loc,
   return Result;
 }
 
+DiagnosticIDs::Class DiagnosticIDs::getDiagClass(unsigned DiagID) const {
+  if (DiagID >= diag::DIAG_UPPER_LIMIT)
+    return Class(CustomDiagInfo->getDescription(DiagID).Class);
+
+  if (const StaticDiagInfoRec *Info = GetDiagInfo(DiagID))
+    return Class(Info->Class);
+  return Class(~0U);
+}
+
 #define GET_DIAG_ARRAYS
 #include "clang/Basic/DiagnosticGroups.inc"
 #undef GET_DIAG_ARRAYS
@@ -629,7 +630,15 @@ DiagnosticIDs::getGroupForWarningOption(StringRef Name) {
   return static_cast<diag::Group>(Found - OptionTable);
 }
 
-std::optional<diag::Group> DiagnosticIDs::getGroupForDiag(unsigned DiagID) {
+std::optional<diag::Group>
+DiagnosticIDs::getGroupForDiag(unsigned DiagID) const {
+  if (DiagID >= diag::DIAG_UPPER_LIMIT) {
+    assert(CustomDiagInfo);
+    auto Diag = CustomDiagInfo->getDescription(DiagID);
+    if (!Diag.HasGroup)
+      return std::nullopt;
+    return Diag.Group;
+  }
   if (const StaticDiagInfoRec *Info = GetDiagInfo(DiagID))
     return static_cast<diag::Group>(Info->getOptionGroupIndex());
   return std::nullopt;
@@ -689,12 +698,33 @@ static bool getDiagnosticsInGroup(diag::Flavor Flavor,
 bool
 DiagnosticIDs::getDiagnosticsInGroup(diag::Flavor Flavor, StringRef Group,
                                      SmallVectorImpl<diag::kind> &Diags) const {
-  if (std::optional<diag::Group> G = getGroupForWarningOption(Group))
+  if (std::optional<diag::Group> G = getGroupForWarningOption(Group)) {
+    if (CustomDiagInfo)
+      llvm::copy(CustomDiagInfo->getDiagsInGroup(*G), std::back_inserter(Diags));
     return ::getDiagnosticsInGroup(
         Flavor, &OptionTable[static_cast<unsigned>(*G)], Diags);
+  }
   return true;
 }
 
+static void setGroupSeverity(const WarningOption *Group,
+                             diag::Severity *GroupSeverity,
+                             diag::Severity Sev) {
+  for (const int16_t *SubGroups = DiagSubGroups + Group->SubGroups;
+       *SubGroups != -1; ++SubGroups) {
+    GroupSeverity[static_cast<size_t>(*SubGroups)] = Sev;
+    setGroupSeverity(&OptionTable[*SubGroups], GroupSeverity, Sev);
+  }
+}
+
+void DiagnosticIDs::setGroupSeverity(StringRef Group, diag::Severity Sev) {
+  if (std::optional<diag::Group> G = getGroupForWarningOption(Group)) {
+    GroupSeverity[static_cast<size_t>(*G)] = Sev;
+    ::setGroupSeverity(&OptionTable[static_cast<size_t>(*G)],
+                       GroupSeverity.get(), Sev);
+  }
+}
+
 void DiagnosticIDs::getAllDiagnostics(diag::Flavor Flavor,
                                       std::vector<diag::kind> &Diags) {
   for (unsigned i = 0; i != StaticDiagInfoSize; ++i)
@@ -830,14 +860,8 @@ void DiagnosticIDs::EmitDiag(DiagnosticsEngine &Diag, Level DiagLevel) const {
 }
 
 bool DiagnosticIDs::isUnrecoverable(unsigned DiagID) const {
-  if (DiagID >= diag::DIAG_UPPER_LIMIT) {
-    assert(CustomDiagInfo && "Invalid CustomDiagInfo");
-    // Custom diagnostics.
-    return CustomDiagInfo->getLevel(DiagID) >= DiagnosticIDs::Error;
-  }
-
   // Only errors may be unrecoverable.
-  if (getBuiltinDiagClass(DiagID) < CLASS_ERROR)
+  if (getDiagClass(DiagID) < CLASS_ERROR)
     return false;
 
   if (DiagID == diag::err_unavailable ||
diff --git a/clang/lib/Frontend/LogDiagnosticPrinter.cpp b/clang/lib/Frontend/LogDiagnosticPrinter.cpp
index 32fc6cb2acd87..926160b1df5bf 100644
--- a/clang/lib/Frontend/LogDiagnosticPrinter.cpp
+++ b/clang/lib/Frontend/LogDiagnosticPrinter.cpp
@@ -129,7 +129,8 @@ void LogDiagnosticPrinter::HandleDiagnostic(DiagnosticsEngine::Level Level,
   DE.DiagnosticLevel = Level;
 
   DE.WarningOption =
-      std::string(DiagnosticIDs::getWarningOptionForDiag(DE.DiagnosticID));
+      std::string(Info.getDiags()->getDiagnosticIDs()->getWarningOptionForDiag(
+          DE.DiagnosticID));
 
   // Format the message.
   SmallString<100> MessageStr;
@@ -160,4 +161,3 @@ void LogDiagnosticPrinter::HandleDiagnostic(DiagnosticsEngine::Level Level,
   // Record the diagnostic entry.
   Entries.push_back(DE);
 }
-
diff --git a/clang/lib/Frontend/SerializedDiagnosticPrinter.cpp b/clang/lib/Frontend/SerializedDiagnosticPrinter.cpp
index b76728acb9077..6ffdc53e1aa69 100644
--- a/clang/lib/Frontend/SerializedDiagnosticPrinter.cpp
+++ b/clang/lib/Frontend/SerializedDiagnosticPrinter.cpp
@@ -540,7 +540,8 @@ unsigned SDiagsWriter::getEmitDiagnosticFlag(DiagnosticsEngine::Level DiagLevel,
   if (DiagLevel == DiagnosticsEngine::Note)
     return 0; // No flag for notes.
 
-  StringRef FlagName = DiagnosticIDs::getWarningOptionForDiag(DiagID);
+  StringRef FlagName =
+      getMetaDiags()->getDiagnosticIDs()->getWarningOptionForDiag(DiagID);
   return getEmitDiagnosticFlag(FlagName);
 }
 
diff --git a/clang/lib/Frontend/TextDiagnosticPrinter.cpp b/clang/lib/Frontend/TextDiagnosticPrinter.cpp
index 0ff5376098ffe..5e748c620161f 100644
--- a/clang/lib/Frontend/TextDiagnosticPrinter.cpp
+++ b/clang/lib/Frontend/TextDiagnosticPrinter.cpp
@@ -70,13 +70,15 @@ static void printDiagnosticOptions(raw_ostream &OS,
     // flag it as such. Note that diagnostics could also have been mapped by a
     // pragma, but we don't currently have a way to distinguish this.
     if (Level == DiagnosticsEngine::Error &&
-        DiagnosticIDs::isBuiltinWarningOrExtension(Info.getID()) &&
-        !DiagnosticIDs::isDefaultMappingAsError(Info.getID())) {
+        Info.getDiags()->getDiagnosticIDs()->isWarningOrExtension(Info.getID()) &&
+        !Info.getDiags()->getDiagnosticIDs()->isDefaultMappingAsError(Info.getID())) {
       OS << " [-Werror";
       Started = true;
     }
 
-    StringRef Opt = DiagnosticIDs::getWarningOptionForDiag(Info.getID());
+    StringRef Opt =
+        Info.getDiags()->getDiagnosticIDs()->getWarningOptionForDiag(
+            Info.getID());
     if (!Opt.empty()) {
       OS << (Started ? "," : " [")
          << (Level == DiagnosticsEngine::Remark ? "-R" : "-W") << Opt;
diff --git a/clang/lib/Sema/Sema.cpp b/clang/lib/Sema/Sema.cpp
index acb765559e6a8..6365faff00ddb 100644
--- a/clang/lib/Sema/Sema.cpp
+++ b/clang/lib/Sema/Sema.cpp
@@ -1618,7 +1618,7 @@ void Sema::EmitCurrentDiagnostic(unsigned DiagID) {
   // that is different from the last template instantiation where
   // we emitted an error, print a template instantiation
   // backtrace.
-  if (!DiagnosticIDs::isBuiltinNote(DiagID))
+  if (!Diags.getDiagnosticIDs()->isNote(DiagID))
     PrintContextStack();
 }
 
@@ -1637,7 +1637,7 @@ bool Sema::hasUncompilableErrorOccurred() const {
   if (Loc == DeviceDeferredDiags.end())
     return false;
   for (auto PDAt : Loc->second) {
-    if (DiagnosticIDs::isDefaultMappingAsError(PDAt.second.getDiagID()))
+    if (Diags.getDiagnosticIDs()->isDefaultMappingAsError(PDAt.second.getDiagID()))
       return true;
   }
   return false;
diff --git a/clang/lib/Sema/SemaCUDA.cpp b/clang/lib/Sema/SemaCUDA.cpp
index d993499cf4a6e..7b1f9f74e3492 100644
--- a/clang/lib/Sema/SemaCUDA.cpp
+++ b/clang/lib/Sema/SemaCUDA.cpp
@@ -773,7 +773,7 @@ Sema::SemaDiagnosticBuilder Sema::CUDADiagIfDeviceCode(SourceLocation Loc,
       // mode until the function is known-emitted.
       if (!getLangOpts().CUDAIsDevice)
         return SemaDiagnosticBuilder::K_Nop;
-      if (IsLastErrorImmediate && Diags.getDiagnosticIDs()->isBuiltinNote(DiagID))
+      if (IsLastErrorImmediate && Diags.getDiagnosticIDs()->isNote(DiagID))
         return SemaDiagnosticBuilder::K_Immediate;
       return (getEmissionStatus(CurFunContext) ==
               FunctionEmissionStatus::Emitted)
@@ -802,7 +802,7 @@ Sema::SemaDiagnosticBuilder Sema::CUDADiagIfHostCode(SourceLocation Loc,
       // mode until the function is known-emitted.
       if (getLangOpts().CUDAIsDevice)
         return SemaDiagnosticBuilder::K_Nop;
-      if (IsLastErrorImmediate && Diags.getDiagnosticIDs()->isBuiltinNote(DiagID))
+      if (IsLastErrorImmediate && Diags.getDiagnosticIDs()->isNote(DiagID))
         return SemaDiagnosticBuilder::K_Immediate;
       return (getEmissionStatus(CurFunContext) ==
               FunctionEmissionStatus::Emitted)
diff --git a/clang/lib/Sema/SemaDeclAttr.cpp b/clang/lib/Sema/SemaDeclAttr.cpp
index fc4e3ccf29a60..ea7dacd2b05a2 100644
--- a/clang/lib/Sema/SemaDeclAttr.cpp
+++ b/clang/lib/Sema/SemaDeclAttr.cpp
@@ -1112,22 +1112,34 @@ static void handleDiagnoseIfAttr(Sema &S, Decl *D, const ParsedAttr &AL) {
   if (!checkFunctionConditionAttr(S, D, AL, Cond, Msg))
     return;
 
-  StringRef DiagTypeStr;
-  if (!S.checkStringLiteralArgumentAttr(AL, 2, DiagTypeStr))
+  StringRef DefaultSevStr;
+  if (!S.checkStringLiteralArgumentAttr(AL, 2, DefaultSevStr))
     return;
 
-  DiagnoseIfAttr::DiagnosticType DiagType;
-  if (!DiagnoseIfAttr::ConvertStrToDiagnosticType(DiagTypeStr, DiagType)) {
+  DiagnoseIfAttr::DefaultSeverity DefaultSev;
+  if (!DiagnoseIfAttr::ConvertStrToDefaultSeverity(DefaultSevStr, DefaultSev)) {
     S.Diag(AL.getArgAsExpr(2)->getBeginLoc(),
            diag::err_diagnose_if_invalid_diagnostic_type);
     return;
   }
 
+  StringRef WarningGroup;
+  SmallVector<StringRef, 2> Options;
+  if (AL.getNumArgs() > 3) {
+    if (!S.checkStringLiteralArgumentAttr(AL, 3, WarningGroup))
+      return;
+    if (!S.getDiagnostics().getDiagnosticIDs()->getGroupForWarningOption(WarningGroup)) {
+      S.Diag(AL.getArgAsExpr(3)->getBeginLoc(), diag::err_diagnose_if_unknown_warning);
+      return;
+    }
+  }
+
   bool ArgDependent = false;
   if (const auto *FD = dyn_cast<FunctionDecl>(D))
     ArgDependent = ArgumentDependenceChecker(FD).referencesArgs(Cond);
   D->addAttr(::new (S.Context) DiagnoseIfAttr(
-      S.Context, AL, Cond, Msg, DiagType, ArgDependent, cast<NamedDecl>(D)));
+      S.Context, AL, Cond, Msg, DefaultSev, WarningGroup, ArgDependent,
+      cast<NamedDecl>(D)));
 }
 
 static void handleNoBuiltinAttr(Sema &S, Decl *D, const ParsedAttr &AL) {
diff --git a/clang/lib/Sema/SemaOverload.cpp b/clang/lib/Sema/SemaOverload.cpp
index db386fef0661c..e79ec8a657f8c 100644
--- a/clang/lib/Sema/SemaOverload.cpp
+++ b/clang/lib/Sema/SemaOverload.cpp
@@ -7153,8 +7153,9 @@ static bool diagnoseDiagnoseIfAttrsWith(Sema &S, const NamedDecl *ND,
     return false;
 
   auto WarningBegin = std::stable_partition(
-      Attrs.begin(), Attrs.end(),
-      [](const DiagnoseIfAttr *DIA) { return DIA->isError(); });
+      Attrs.begin(), Attrs.end(), [](const DiagnoseIfAttr *DIA) {
+        return DIA->getDefaultSeverity() == DiagnoseIfAttr::DS_error;
+      });
 
   // Note that diagnose_if attributes are late-parsed, so they appear in the
   // correct order (unlike enable_if attributes).
@@ -7170,9 +7171,22 @@ static bool diagnoseDiagnoseIfAttrsWith(Sema &S, const NamedDecl *ND,
 
   for (const auto *DIA : llvm::make_range(WarningBegin, Attrs.end()))
     if (IsSuccessful(DIA)) {
-      S.Diag(Loc, diag::warn_diagnose_if_succeeded) << DIA->getMessage();
-      S.Diag(DIA->getLocation(), diag::note_from_diagnose_if)
-          << DIA->getParent() << DIA->getCond()->getSourceRange();
+      if (DIA->getWarningGroup().empty() &&
+          DIA->getDefaultSeverity() == DiagnoseIfAttr::DS_warning) {
+        S.Diag(Loc, diag::warn_diagnose_if_succeeded) << DIA->getMessage();
+        S.Diag(DIA->getLocation(), diag::note_from_diagnose_if)
+            << DIA->getParent() << DIA->getCond()->getSourceRange();
+      } else {
+        DiagnosticIDs::CustomDiagDesc Diag;
+        auto DiagGroup = S.Diags.getDiagnosticIDs()->getGroupForWarningOption(
+                      DIA->getWarningGroup());
+        assert(DiagGroup);
+        Diag.HasGroup = true;
+        Diag.Group = *DiagGroup;
+        Diag.Description = "%0";
+        auto DiagID = S.Diags.getDiagnosticIDs()->getCustomDiagID(Diag);
+        S.Diag(Loc, DiagID) << DIA->getMessage();
+      }
     }
 
   return false;
diff --git a/clang/lib/Sema/SemaTemplateInstantiateDecl.cpp b/clang/lib/Sema/SemaTemplateInstantiateDecl.cpp
index 78a7892a35a32..bb48059c4acad 100644
--- a/clang/lib/Sema/SemaTemplateInstantiateDecl.cpp
+++ b/clang/lib/Sema/SemaTemplateInstantiateDecl.cpp
@@ -277,7 +277,8 @@ static void instantiateDependentDiagnoseIfAttr(
   if (Cond)
     New->addAttr(new (S.getASTContext()) DiagnoseIfAttr(
         S.getASTContext(), *DIA, Cond, DIA->getMessage(),
-        DIA->getDiagnosticType(), DIA->getArgDependent(), New));
+        DIA->getDefaultSeverity(), DIA->getWarningGroup(),
+        DIA->getArgDependent(), New));
 }
 
 // Constructs and adds to New a new instance of CUDALaunchBoundsAttr using
diff --git a/clang/lib/Serialization/ASTReader.cpp b/clang/lib/Serialization/ASTReader.cpp
index 42b48d230af7a..acf7d9de2caf7 100644
--- a/clang/lib/Serialization/ASTReader.cpp
+++ b/clang/lib/Serialization/ASTReader.cpp
@@ -6570,7 +6570,7 @@ void ASTReader::ReadPragmaDiagnosticMappings(DiagnosticsEngine &Diag) {
       // command line (-w, -Weverything, -Werror, ...) along with any explicit
       // -Wblah flags.
       unsigned Flags = Record[Idx++];
-      DiagState Initial;
+      DiagState Initial(*Diag.getDiagnosticIDs());
       Initial.SuppressSystemWarnings = Flags & 1; Flags >>= 1;
       Initial.ErrorsAsFatal = Flags & 1; Flags >>= 1;
       Initial.WarningsAsErrors = Flags & 1; Flags >>= 1;
diff --git a/clang/lib/Serialization/ASTWriter.cpp b/clang/lib/Serialization/ASTWriter.cpp
index 739344b9a128d..d390c70361967 100644
--- a/clang/lib/Serialization/ASTWriter.cpp
+++ b/clang/lib/Serialization/ASTWriter.cpp
@@ -3065,7 +3065,7 @@ void ASTWriter::WritePragmaDiagnosticMappings(const DiagnosticsEngine &Diag,
         // Skip default mappings. We have a mapping for every diagnostic ever
         // emitted, regardless of whether it was customized.
         if (!I.second.isPragma() &&
-            I.second == DiagnosticIDs::getDefaultMapping(I.first))
+            I.second == Diag.getDiagnosticIDs()->getDefaultMapping(I.first))
           continue;
         Mappings.push_back(I);
       }
diff --git a/clang/test/SemaCXX/diagnose_if-warning-group.cpp b/clang/test/SemaCXX/diagnose_if-warning-group.cpp
new file mode 100644
index 0000000000000..98bed892ba7d5
--- /dev/null
+++ b/clang/test/SemaCXX/diagnose_if-warning-group.cpp
@@ -0,0 +1,35 @@
+// RUN: %clang_cc1 %s -verify -fno-builtin -Werror=comment
+
+#define _diagnose_if(...) __attribute__((diagnose_if(__VA_ARGS__)))
+
+template <bool b>
+void diagnose_if_wcomma() _diagnose_if(b, "oh no", "warning", "comma") {}
+
+template <bool b>
+void diagnose_if_wcomment() _diagnose_if(b, "oh no", "warning", "comment") {}
+
+void bougus_warning() _diagnose_if(true, "oh no", "warning", "bougus warning") {} // expected-error {{unknown warning group}}
+
+void show_in_system_header() _diagnose_if(true, "oh no", "warning", "assume", "Banane") {} // expected-error {{'diagnose_if' attribute takes no more than 4 arguments}}
+
+void call() {
+  diagnose_if_wcomma<true>(); // expected-warning {{oh no}}
+  diagnose_if_wcomma<false>();
+  diagnose_if_wcomment<true>(); // expected-error {{oh no}}
+  diagnose_if_wcomment<false>();
+
+#pragma clang diagnostic push
+#pragma clang diagnostic ignored "-Wcomma"
+  diagnose_if_wcomma<true>();
+  diagnose_if_wcomment<true>(); // expected-error {{oh no}}
+#pragma clang diagnostic pop
+
+#pragma clang diagnostic push
+#pragma clang diagnostic ignored "-Wcomment"
+  diagnose_if_wcomma<true>(); // expected-warning {{oh no}}
+  diagnose_if_wcomment<true>();
+#pragma clang diagnostic pop
+
+  diagnose_if_wcomma<true>(); // expected-warning {{oh no}}
+  diagnose_if_wcomment<true>(); // expected-error {{oh no}}
+}
diff --git a/clang/tools/diagtool/ListWarnings.cpp b/clang/tools/diagtool/ListWarnings.cpp
index a71f6e3a66c8e..9f9647126dd8a 100644
--- a/clang/tools/diagtool/ListWarnings.cpp
+++ b/clang/tools/diagtool/ListWarnings.cpp
@@ -53,13 +53,13 @@ int ListWarnings::run(unsigned int argc, char **argv, llvm::raw_ostream &out) {
   for (const DiagnosticRecord &DR : getBuiltinDiagnosticsByName()) {
     const unsigned diagID = DR.DiagID;
 
-    if (DiagnosticIDs::isBuiltinNote(diagID))
+    if (DiagnosticIDs{}.isNote(diagID))
       continue;
 
-    if (!DiagnosticIDs::isBuiltinWarningOrExtension(diagID))
+    if (!DiagnosticIDs{}.isWarningOrExtension(diagID))
       continue;
 
-    Entry entry(DR.getName(), DiagnosticIDs::getWarningOptionForDiag(diagID));
+    Entry entry(DR.getName(), DiagnosticIDs{}.getWarningOptionForDiag(diagID));
 
     if (entry.Flag.empty())
       Unflagged.push_back(entry);
@@ -97,4 +97,3 @@ int ListWarnings::run(unsigned int argc, char **argv, llvm::raw_ostream &out) {
 
   return 0;
 }
-
diff --git a/clang/tools/diagtool/ShowEnabledWarnings.cpp b/clang/tools/diagtool/ShowEnabledWarnings.cpp
index 285efe6ae05b3..3d6195c2b0ab7 100644
--- a/clang/tools/diagtool/ShowEnabledWarnings.cpp
+++ b/clang/tools/diagtool/ShowEnabledWarnings.cpp
@@ -117,10 +117,10 @@ int ShowEnabledWarnings::run(unsigned int argc, char **argv, raw_ostream &Out) {
   for (const DiagnosticRecord &DR : getBuiltinDiagnosticsByName()) {
     unsigned DiagID = DR.DiagID;
 
-    if (DiagnosticIDs::isBuiltinNote(DiagID))
+    if (DiagnosticIDs{}.isNote(DiagID))
       continue;
 
-    if (!DiagnosticIDs::isBuiltinWarningOrExtension(DiagID))
+    if (!DiagnosticIDs{}.isWarningOrExtension(DiagID))
       continue;
 
     DiagnosticsEngine::Level DiagLevel =
@@ -128,7 +128,7 @@ int ShowEnabledWarnings::run(unsigned int argc, char **argv, raw_ostream &Out) {
     if (DiagLevel == DiagnosticsEngine::Ignored)
       continue;
 
-    StringRef WarningOpt = DiagnosticIDs::getWarningOptionForDiag(DiagID);
+    StringRef WarningOpt = DiagnosticIDs{}.getWarningOptionForDiag(DiagID);
     Active.push_back(PrettyDiag(DR.getName(), WarningOpt, DiagLevel));
   }
 
diff --git a/clang/tools/libclang/CXStoredDiagnostic.cpp b/clang/tools/libclang/CXStoredDiagnostic.cpp
index c4c24876e70de..f159a3a1a0a8b 100644
--- a/clang/tools/libclang/CXStoredDiagnostic.cpp
+++ b/clang/tools/libclang/CXStoredDiagnostic.cpp
@@ -33,14 +33,14 @@ CXDiagnosticSeverity CXStoredDiagnostic::getSeverity() const {
     case DiagnosticsEngine::Error:   return CXDiagnostic_Error;
     case DiagnosticsEngine::Fatal:   return CXDiagnostic_Fatal;
   }
-  
+
   llvm_unreachable("Invalid diagnostic level");
 }
 
 CXSourceLocation CXStoredDiagnostic::getLocation() const {
   if (Diag.getLocation().isInvalid())
     return clang_getNullLocation();
-  
+
   return translateSourceLocation(Diag.getLocation().getManager(),
                                  LangOpts, Diag.getLocation());
 }
@@ -51,13 +51,13 @@ CXString CXStoredDiagnostic::getSpelling() const {
 
 CXString CXStoredDiagnostic::getDiagnosticOption(CXString *Disable) const {
   unsigned ID = Diag.getID();
-  StringRef Option = DiagnosticIDs::getWarningOptionForDiag(ID);
+  StringRef Option = DiagnosticIDs{}.getWarningOptionForDiag(ID);
   if (!Option.empty()) {
     if (Disable)
       *Disable = cxstring::createDup((Twine("-Wno-") + Option).str());
     return cxstring::createDup((Twine("-W") + Option).str());
   }
-  
+
   if (ID == diag::fatal_too_many_errors) {
     if (Disable)
       *Disable = cxstring::createRef("-ferror-limit=0");
@@ -79,7 +79,7 @@ CXString CXStoredDiagnostic::getCategoryText() const {
 unsigned CXStoredDiagnostic::getNumRanges() const {
   if (Diag.getLocation().isInvalid())
     return 0;
-  
+
   return Diag.range_size();
 }
 
@@ -92,12 +92,12 @@ CXSourceRange CXStoredDiagnostic::getRange(unsigned int Range) const {
 
 unsigned CXStoredDiagnostic::getNumFixIts() const {
   if (Diag.getLocation().isInvalid())
-    return 0;    
+    return 0;
   return Diag.fixit_size();
 }
 
 CXString CXStoredDiagnostic::getFixIt(unsigned FixIt,
-                                      CXSourceRange *ReplacementRange) const {  
+                                      CXSourceRange *ReplacementRange) const {
   const FixItHint &Hint = Diag.fixit_begin()[FixIt];
   if (ReplacementRange) {
     // Create a range that covers the entire replacement (or
@@ -108,4 +108,3 @@ CXString CXStoredDiagnostic::getFixIt(unsigned FixIt,
   }
   return cxstring::createDup(Hint.CodeToInsert);
 }
-

>From 9d134737fa4c4de8c9c026090de0c360fe791b86 Mon Sep 17 00:00:00 2001
From: Nikolas Klauser <nikolasklauser at berlin.de>
Date: Sat, 29 Jun 2024 10:22:06 +0200
Subject: [PATCH 2/2] Fix downgrading diagnostics

---
 clang/lib/Basic/Diagnostic.cpp                |  2 ++
 .../SemaCXX/diagnose_if-warning-group.cpp     | 19 ++++++++++++-------
 2 files changed, 14 insertions(+), 7 deletions(-)

diff --git a/clang/lib/Basic/Diagnostic.cpp b/clang/lib/Basic/Diagnostic.cpp
index 585d23aa36c49..4935c1f1f91ab 100644
--- a/clang/lib/Basic/Diagnostic.cpp
+++ b/clang/lib/Basic/Diagnostic.cpp
@@ -434,6 +434,7 @@ bool DiagnosticsEngine::setDiagnosticGroupWarningAsError(StringRef Group,
   if (Enabled)
     return setSeverityForGroup(diag::Flavor::WarningOrError, Group,
                                diag::Severity::Error);
+  Diags->setGroupSeverity(Group, diag::Severity::Warning);
 
   // Otherwise, we want to set the diagnostic mapping's "no Werror" bit, and
   // potentially downgrade anything already mapped to be a warning.
@@ -465,6 +466,7 @@ bool DiagnosticsEngine::setDiagnosticGroupErrorAsFatal(StringRef Group,
   if (Enabled)
     return setSeverityForGroup(diag::Flavor::WarningOrError, Group,
                                diag::Severity::Fatal);
+  Diags->setGroupSeverity(Group, diag::Severity::Error);
 
   // Otherwise, we want to set the diagnostic mapping's "no Wfatal-errors" bit,
   // and potentially downgrade anything already mapped to be a fatal error.
diff --git a/clang/test/SemaCXX/diagnose_if-warning-group.cpp b/clang/test/SemaCXX/diagnose_if-warning-group.cpp
index 35ee7c955eb9c..a01600b6971ed 100644
--- a/clang/test/SemaCXX/diagnose_if-warning-group.cpp
+++ b/clang/test/SemaCXX/diagnose_if-warning-group.cpp
@@ -1,19 +1,20 @@
-// RUN: %clang_cc1 %s -verify -fno-builtin -Werror=comment -Wno-error=abi
+// RUN: %clang_cc1 %s -verify -fno-builtin -Werror=comment -Wno-error=abi -Wfatal-errors=assume -Wno-fatal-errors=assume
 
-#define _diagnose_if(...) __attribute__((diagnose_if(__VA_ARGS__)))
+#define diagnose_if(...) __attribute__((diagnose_if(__VA_ARGS__)))
 
 template <bool b>
-void diagnose_if_wcomma() _diagnose_if(b, "oh no", "warning", "comma") {}
+void diagnose_if_wcomma() diagnose_if(b, "oh no", "warning", "comma") {}
 
 template <bool b>
-void diagnose_if_wcomment() _diagnose_if(b, "oh no", "warning", "comment") {}
+void diagnose_if_wcomment() diagnose_if(b, "oh no", "warning", "comment") {}
 
-void bougus_warning() _diagnose_if(true, "oh no", "warning", "bougus warning") {} // expected-error {{unknown warning group}}
+void bougus_warning() diagnose_if(true, "oh no", "warning", "bougus warning") {} // expected-error {{unknown warning group}}
 
-void show_in_system_header() _diagnose_if(true, "oh no", "warning", "assume", "Banane") {} // expected-error {{'diagnose_if' attribute takes no more than 4 arguments}}
+void show_in_system_header() diagnose_if(true, "oh no", "warning", "assume", "Banane") {} // expected-error {{'diagnose_if' attribute takes no more than 4 arguments}}
 
 
-void diagnose_if_wabi_default_error() _diagnose_if(true, "ABI stuff", "error", "abi") {}
+void diagnose_if_wabi_default_error() diagnose_if(true, "ABI stuff", "error", "abi") {}
+void diagnose_assume() diagnose_if(true, "Assume diagnostic", "warning", "assume") {}
 
 void call() {
   diagnose_if_wcomma<true>(); // expected-warning {{oh no}}
@@ -37,4 +38,8 @@ void call() {
   diagnose_if_wcomment<true>(); // expected-error {{oh no}}
 
   diagnose_if_wabi_default_error(); // expected-warning {{ABI stuff}}
+  diagnose_assume(); // expected-error {{Assume diagnostic}}
+
+  // Make sure that the -Wassume diagnostic isn't fatal
+  diagnose_if_wabi_default_error(); // expected-warning {{ABI stuff}}
 }



More information about the cfe-commits mailing list