[Mlir-commits] [clang] [llvm] [mlir] [TableGen] Add const variants of accessors for backend (PR #106658)

Rahul Joshi llvmlistbot at llvm.org
Thu Aug 29 21:00:06 PDT 2024


https://github.com/jurahul created https://github.com/llvm/llvm-project/pull/106658

Split the RecordKeeper `getAllDerivedDefinitions` family of functions into two variants: (a) non-const ones that return vectors of `Record *`, and (b) const ones that return vectors of `const Record *`.

This will help the gradual migration of backends to use `const RecordKeeper` and, by implication, to change backend code to work with const pointers.

>From ed6066369b0b153f44cb089a207ae48f5efcc4f6 Mon Sep 17 00:00:00 2001
From: Rahul Joshi <rjoshi at nvidia.com>
Date: Thu, 29 Aug 2024 20:36:05 -0700
Subject: [PATCH] [TableGen] Add const variants of accessors for backend

Split the RecordKeeper `getAllDerivedDefinitions` family of functions
into two variants: (a) non-const ones that return vectors of `Record *`,
and (b) const ones that return vectors of `const Record *`.

This will help the gradual migration of backends to use
`const RecordKeeper` and, by implication, to change backend code to work
with const pointers.
---
 clang/utils/TableGen/ClangAttrEmitter.cpp     |  4 +-
 clang/utils/TableGen/ClangSyntaxEmitter.cpp   |  2 +-
 llvm/include/llvm/TableGen/DirectiveEmitter.h |  5 +-
 llvm/include/llvm/TableGen/Record.h           | 36 +++++++--
 llvm/lib/TableGen/Record.cpp                  | 74 ++++++++++++++-----
 .../TableGen/Basic/CodeGenIntrinsics.cpp      |  8 +-
 .../TableGen/Common/SubtargetFeatureInfo.cpp  |  2 +-
 .../TableGen/Common/SubtargetFeatureInfo.h    |  2 +-
 llvm/utils/TableGen/ExegesisEmitter.cpp       |  2 +-
 llvm/utils/TableGen/GlobalISelEmitter.cpp     |  2 +-
 llvm/utils/TableGen/SubtargetEmitter.cpp      |  2 +-
 llvm/utils/TableGen/TableGen.cpp              |  2 +-
 mlir/include/mlir/TableGen/GenInfo.h          |  6 +-
 mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp   | 28 +++----
 mlir/tools/mlir-tblgen/OmpOpGen.cpp           |  2 +-
 mlir/tools/mlir-tblgen/OpDocGen.cpp           | 20 ++---
 mlir/tools/mlir-tblgen/OpInterfacesGen.cpp    | 15 ++--
 mlir/tools/mlir-tblgen/RewriterGen.cpp        |  8 +-
 18 files changed, 138 insertions(+), 82 deletions(-)

diff --git a/clang/utils/TableGen/ClangAttrEmitter.cpp b/clang/utils/TableGen/ClangAttrEmitter.cpp
index adbe6af62d5cbe..d24215d10f17c7 100644
--- a/clang/utils/TableGen/ClangAttrEmitter.cpp
+++ b/clang/utils/TableGen/ClangAttrEmitter.cpp
@@ -189,7 +189,7 @@ static StringRef NormalizeGNUAttrSpelling(StringRef AttrSpelling) {
 
 typedef std::vector<std::pair<std::string, const Record *>> ParsedAttrMap;
 
-static ParsedAttrMap getParsedAttrList(const RecordKeeper &Records,
+static ParsedAttrMap getParsedAttrList(RecordKeeper &Records,
                                        ParsedAttrMap *Dupes = nullptr,
                                        bool SemaOnly = true) {
   std::vector<Record *> Attrs = Records.getAllDerivedDefinitions("Attr");
@@ -4344,7 +4344,7 @@ static void GenerateAppertainsTo(const Record &Attr, raw_ostream &OS) {
 // written into OS and the checks for merging declaration attributes are
 // written into MergeOS.
 static void GenerateMutualExclusionsChecks(const Record &Attr,
-                                           const RecordKeeper &Records,
+                                           RecordKeeper &Records,
                                            raw_ostream &OS,
                                            raw_ostream &MergeDeclOS,
                                            raw_ostream &MergeStmtOS) {
diff --git a/clang/utils/TableGen/ClangSyntaxEmitter.cpp b/clang/utils/TableGen/ClangSyntaxEmitter.cpp
index 9720d587318432..2a69e4c353b6b4 100644
--- a/clang/utils/TableGen/ClangSyntaxEmitter.cpp
+++ b/clang/utils/TableGen/ClangSyntaxEmitter.cpp
@@ -41,7 +41,7 @@ using llvm::formatv;
 // stable and useful way, where abstract Node subclasses correspond to ranges.
 class Hierarchy {
 public:
-  Hierarchy(const llvm::RecordKeeper &Records) {
+  Hierarchy(llvm::RecordKeeper &Records) {
     for (llvm::Record *T : Records.getAllDerivedDefinitions("NodeType"))
       add(T);
     for (llvm::Record *Derived : Records.getAllDerivedDefinitions("NodeType"))
diff --git a/llvm/include/llvm/TableGen/DirectiveEmitter.h b/llvm/include/llvm/TableGen/DirectiveEmitter.h
index 1121459be6ce7d..ca21c8fc101450 100644
--- a/llvm/include/llvm/TableGen/DirectiveEmitter.h
+++ b/llvm/include/llvm/TableGen/DirectiveEmitter.h
@@ -15,8 +15,7 @@ namespace llvm {
 // DirectiveBase.td and provides helper methods for accessing it.
 class DirectiveLanguage {
 public:
-  explicit DirectiveLanguage(const llvm::RecordKeeper &Records)
-      : Records(Records) {
+  explicit DirectiveLanguage(llvm::RecordKeeper &Records) : Records(Records) {
     const auto &DirectiveLanguages = getDirectiveLanguages();
     Def = DirectiveLanguages[0];
   }
@@ -71,7 +70,7 @@ class DirectiveLanguage {
 
 private:
   const llvm::Record *Def;
-  const llvm::RecordKeeper &Records;
+  llvm::RecordKeeper &Records;
 
   std::vector<Record *> getDirectiveLanguages() const {
     return Records.getAllDerivedDefinitions("DirectiveLanguage");
diff --git a/llvm/include/llvm/TableGen/Record.h b/llvm/include/llvm/TableGen/Record.h
index a339946e67cf2d..667a85b871f834 100644
--- a/llvm/include/llvm/TableGen/Record.h
+++ b/llvm/include/llvm/TableGen/Record.h
@@ -264,7 +264,7 @@ class RecordRecTy final : public RecTy, public FoldingSetNode,
 
   std::string getAsString() const override;
 
-  bool isSubClassOf(Record *Class) const;
+  bool isSubClassOf(const Record *Class) const;
   bool typeIsConvertibleTo(const RecTy *RHS) const override;
 
   bool typeIsA(const RecTy *RHS) const override;
@@ -2057,19 +2057,28 @@ class RecordKeeper {
   //===--------------------------------------------------------------------===//
   // High-level helper methods, useful for tablegen backends.
 
+  // Non-const methods return std::vector<Record *> by value or reference.
+  // Const methods return std::vector<const Record *> by value or reference.
+
   /// Get all the concrete records that inherit from the one specified
   /// class. The class must be defined.
-  std::vector<Record *> getAllDerivedDefinitions(StringRef ClassName) const;
+  const std::vector<const Record *> &
+  getAllDerivedDefinitions(StringRef ClassName) const;
+  const std::vector<Record *> &getAllDerivedDefinitions(StringRef ClassName);
 
   /// Get all the concrete records that inherit from all the specified
   /// classes. The classes must be defined.
-  std::vector<Record *> getAllDerivedDefinitions(
-      ArrayRef<StringRef> ClassNames) const;
+  std::vector<const Record *>
+  getAllDerivedDefinitions(ArrayRef<StringRef> ClassNames) const;
+  std::vector<Record *>
+  getAllDerivedDefinitions(ArrayRef<StringRef> ClassNames);
 
   /// Get all the concrete records that inherit from specified class, if the
   /// class is defined. Returns an empty vector if the class is not defined.
-  std::vector<Record *>
+  const std::vector<const Record *> &
   getAllDerivedDefinitionsIfDefined(StringRef ClassName) const;
+  const std::vector<Record *> &
+  getAllDerivedDefinitionsIfDefined(StringRef ClassName);
 
   void dump() const;
 
@@ -2079,9 +2088,24 @@ class RecordKeeper {
   RecordKeeper &operator=(RecordKeeper &&) = delete;
   RecordKeeper &operator=(const RecordKeeper &) = delete;
 
+  // Helper template functions for backend accessors.
+  template <typename VecTy>
+  const VecTy &
+  getAllDerivedDefinitionsImpl(StringRef ClassName,
+                               std::map<std::string, VecTy> &Cache) const;
+
+  template <typename VecTy>
+  VecTy getAllDerivedDefinitionsImpl(ArrayRef<StringRef> ClassNames) const;
+
+  template <typename VecTy>
+  const VecTy &getAllDerivedDefinitionsIfDefinedImpl(
+      StringRef ClassName, std::map<std::string, VecTy> &Cache) const;
+
   std::string InputFilename;
   RecordMap Classes, Defs;
-  mutable StringMap<std::vector<Record *>> ClassRecordsMap;
+  mutable std::map<std::string, std::vector<const Record *>>
+      ClassRecordsMapConst;
+  mutable std::map<std::string, std::vector<Record *>> ClassRecordsMap;
   GlobalMap ExtraGlobals;
 
   // These members are for the phase timing feature. We need a timer group,
diff --git a/llvm/lib/TableGen/Record.cpp b/llvm/lib/TableGen/Record.cpp
index bcecee8e550c88..22e2163e3d4b17 100644
--- a/llvm/lib/TableGen/Record.cpp
+++ b/llvm/lib/TableGen/Record.cpp
@@ -266,11 +266,10 @@ std::string RecordRecTy::getAsString() const {
   return Str;
 }
 
-bool RecordRecTy::isSubClassOf(Record *Class) const {
-  return llvm::any_of(getClasses(), [Class](Record *MySuperClass) {
-                                      return MySuperClass == Class ||
-                                             MySuperClass->isSubClassOf(Class);
-                                    });
+bool RecordRecTy::isSubClassOf(const Record *Class) const {
+  return llvm::any_of(getClasses(), [Class](const Record *MySuperClass) {
+    return MySuperClass == Class || MySuperClass->isSubClassOf(Class);
+  });
 }
 
 bool RecordRecTy::typeIsConvertibleTo(const RecTy *RHS) const {
@@ -3219,25 +3218,28 @@ void RecordKeeper::stopBackendTimer() {
   }
 }
 
-std::vector<Record *>
-RecordKeeper::getAllDerivedDefinitions(StringRef ClassName) const {
+template <typename VecTy>
+const VecTy &RecordKeeper::getAllDerivedDefinitionsImpl(
+    StringRef ClassName, std::map<std::string, VecTy> &Cache) const {
   // We cache the record vectors for single classes. Many backends request
   // the same vectors multiple times.
-  auto Pair = ClassRecordsMap.try_emplace(ClassName);
+  auto Pair = Cache.try_emplace(ClassName.str());
   if (Pair.second)
-    Pair.first->second = getAllDerivedDefinitions(ArrayRef(ClassName));
+    Pair.first->second =
+        getAllDerivedDefinitionsImpl<VecTy>(ArrayRef(ClassName));
 
   return Pair.first->second;
 }
 
-std::vector<Record *> RecordKeeper::getAllDerivedDefinitions(
+template <typename VecTy>
+VecTy RecordKeeper::getAllDerivedDefinitionsImpl(
     ArrayRef<StringRef> ClassNames) const {
-  SmallVector<Record *, 2> ClassRecs;
-  std::vector<Record *> Defs;
+  SmallVector<const Record *, 2> ClassRecs;
+  VecTy Defs;
 
   assert(ClassNames.size() > 0 && "At least one class must be passed.");
   for (const auto &ClassName : ClassNames) {
-    Record *Class = getClass(ClassName);
+    const Record *Class = getClass(ClassName);
     if (!Class)
       PrintFatalError("The class '" + ClassName + "' is not defined\n");
     ClassRecs.push_back(Class);
@@ -3245,20 +3247,54 @@ std::vector<Record *> RecordKeeper::getAllDerivedDefinitions(
 
   for (const auto &OneDef : getDefs()) {
     if (all_of(ClassRecs, [&OneDef](const Record *Class) {
-                            return OneDef.second->isSubClassOf(Class);
-                          }))
+          return OneDef.second->isSubClassOf(Class);
+        }))
       Defs.push_back(OneDef.second.get());
   }
-
   llvm::sort(Defs, LessRecord());
-
   return Defs;
 }
 
+template <typename VecTy>
+const VecTy &RecordKeeper::getAllDerivedDefinitionsIfDefinedImpl(
+    StringRef ClassName, std::map<std::string, VecTy> &Cache) const {
+  return getClass(ClassName)
+             ? getAllDerivedDefinitionsImpl<VecTy>(ClassName, Cache)
+             : Cache[""];
+}
+
+const std::vector<const Record *> &
+RecordKeeper::getAllDerivedDefinitions(StringRef ClassName) const {
+  return getAllDerivedDefinitionsImpl<std::vector<const Record *>>(
+      ClassName, ClassRecordsMapConst);
+}
+
+const std::vector<Record *> &
+RecordKeeper::getAllDerivedDefinitions(StringRef ClassName) {
+  return getAllDerivedDefinitionsImpl<std::vector<Record *>>(ClassName,
+                                                             ClassRecordsMap);
+}
+
+std::vector<const Record *>
+RecordKeeper::getAllDerivedDefinitions(ArrayRef<StringRef> ClassNames) const {
+  return getAllDerivedDefinitionsImpl<std::vector<const Record *>>(ClassNames);
+}
+
 std::vector<Record *>
+RecordKeeper::getAllDerivedDefinitions(ArrayRef<StringRef> ClassNames) {
+  return getAllDerivedDefinitionsImpl<std::vector<Record *>>(ClassNames);
+}
+
+const std::vector<const Record *> &
 RecordKeeper::getAllDerivedDefinitionsIfDefined(StringRef ClassName) const {
-  return getClass(ClassName) ? getAllDerivedDefinitions(ClassName)
-                             : std::vector<Record *>();
+  return getAllDerivedDefinitionsIfDefinedImpl<std::vector<const Record *>>(
+      ClassName, ClassRecordsMapConst);
+}
+
+const std::vector<Record *> &
+RecordKeeper::getAllDerivedDefinitionsIfDefined(StringRef ClassName) {
+  return getAllDerivedDefinitionsIfDefinedImpl<std::vector<Record *>>(
+      ClassName, ClassRecordsMap);
 }
 
 Init *MapResolver::resolve(Init *VarName) {
diff --git a/llvm/utils/TableGen/Basic/CodeGenIntrinsics.cpp b/llvm/utils/TableGen/Basic/CodeGenIntrinsics.cpp
index 4bca904e9f38bd..6ffe788bcebca2 100644
--- a/llvm/utils/TableGen/Basic/CodeGenIntrinsics.cpp
+++ b/llvm/utils/TableGen/Basic/CodeGenIntrinsics.cpp
@@ -26,15 +26,13 @@ using namespace llvm;
 //===----------------------------------------------------------------------===//
 
 CodeGenIntrinsicTable::CodeGenIntrinsicTable(const RecordKeeper &RC) {
-  std::vector<Record *> IntrProperties =
-      RC.getAllDerivedDefinitions("IntrinsicProperty");
-
   std::vector<const Record *> DefaultProperties;
-  for (const Record *Rec : IntrProperties)
+  for (const Record *Rec : RC.getAllDerivedDefinitions("IntrinsicProperty"))
     if (Rec->getValueAsBit("IsDefault"))
       DefaultProperties.push_back(Rec);
 
-  std::vector<Record *> Defs = RC.getAllDerivedDefinitions("Intrinsic");
+  const std::vector<const Record *> &Defs =
+      RC.getAllDerivedDefinitions("Intrinsic");
   Intrinsics.reserve(Defs.size());
 
   for (const Record *Def : Defs)
diff --git a/llvm/utils/TableGen/Common/SubtargetFeatureInfo.cpp b/llvm/utils/TableGen/Common/SubtargetFeatureInfo.cpp
index 4f57234d6fe275..a4d6d8d21b3562 100644
--- a/llvm/utils/TableGen/Common/SubtargetFeatureInfo.cpp
+++ b/llvm/utils/TableGen/Common/SubtargetFeatureInfo.cpp
@@ -21,7 +21,7 @@ LLVM_DUMP_METHOD void SubtargetFeatureInfo::dump() const {
 #endif
 
 std::vector<std::pair<Record *, SubtargetFeatureInfo>>
-SubtargetFeatureInfo::getAll(const RecordKeeper &Records) {
+SubtargetFeatureInfo::getAll(RecordKeeper &Records) {
   std::vector<std::pair<Record *, SubtargetFeatureInfo>> SubtargetFeatures;
   std::vector<Record *> AllPredicates =
       Records.getAllDerivedDefinitions("Predicate");
diff --git a/llvm/utils/TableGen/Common/SubtargetFeatureInfo.h b/llvm/utils/TableGen/Common/SubtargetFeatureInfo.h
index 2635e4b733e1a3..fee2c0263c4960 100644
--- a/llvm/utils/TableGen/Common/SubtargetFeatureInfo.h
+++ b/llvm/utils/TableGen/Common/SubtargetFeatureInfo.h
@@ -49,7 +49,7 @@ struct SubtargetFeatureInfo {
 
   void dump() const;
   static std::vector<std::pair<Record *, SubtargetFeatureInfo>>
-  getAll(const RecordKeeper &Records);
+  getAll(RecordKeeper &Records);
 
   /// Emit the subtarget feature flag definitions.
   ///
diff --git a/llvm/utils/TableGen/ExegesisEmitter.cpp b/llvm/utils/TableGen/ExegesisEmitter.cpp
index d48c7f3a480f24..0de7cb42337481 100644
--- a/llvm/utils/TableGen/ExegesisEmitter.cpp
+++ b/llvm/utils/TableGen/ExegesisEmitter.cpp
@@ -59,7 +59,7 @@ class ExegesisEmitter {
 };
 
 static std::map<llvm::StringRef, unsigned>
-collectPfmCounters(const RecordKeeper &Records) {
+collectPfmCounters(RecordKeeper &Records) {
   std::map<llvm::StringRef, unsigned> PfmCounterNameTable;
   const auto AddPfmCounterName = [&PfmCounterNameTable](
                                      const Record *PfmCounterDef) {
diff --git a/llvm/utils/TableGen/GlobalISelEmitter.cpp b/llvm/utils/TableGen/GlobalISelEmitter.cpp
index a491a049e7c812..2606768c0c582c 100644
--- a/llvm/utils/TableGen/GlobalISelEmitter.cpp
+++ b/llvm/utils/TableGen/GlobalISelEmitter.cpp
@@ -335,7 +335,7 @@ class GlobalISelEmitter final : public GlobalISelMatchTableExecutorEmitter {
 private:
   std::string ClassName;
 
-  const RecordKeeper &RK;
+  RecordKeeper &RK;
   const CodeGenDAGPatterns CGP;
   const CodeGenTarget &Target;
   CodeGenRegBank &CGRegs;
diff --git a/llvm/utils/TableGen/SubtargetEmitter.cpp b/llvm/utils/TableGen/SubtargetEmitter.cpp
index 66ca38ee5ae2f8..7ae61cb7c446b1 100644
--- a/llvm/utils/TableGen/SubtargetEmitter.cpp
+++ b/llvm/utils/TableGen/SubtargetEmitter.cpp
@@ -1545,7 +1545,7 @@ void SubtargetEmitter::EmitSchedModel(raw_ostream &OS) {
   EmitProcessorModels(OS);
 }
 
-static void emitPredicateProlog(const RecordKeeper &Records, raw_ostream &OS) {
+static void emitPredicateProlog(RecordKeeper &Records, raw_ostream &OS) {
   std::string Buffer;
   raw_string_ostream Stream(Buffer);
 
diff --git a/llvm/utils/TableGen/TableGen.cpp b/llvm/utils/TableGen/TableGen.cpp
index 7ee6fa5c832114..882410bac081b4 100644
--- a/llvm/utils/TableGen/TableGen.cpp
+++ b/llvm/utils/TableGen/TableGen.cpp
@@ -52,7 +52,7 @@ static void PrintEnums(RecordKeeper &Records, raw_ostream &OS) {
 static void PrintSets(const RecordKeeper &Records, raw_ostream &OS) {
   SetTheory Sets;
   Sets.addFieldExpander("Set", "Elements");
-  for (Record *Rec : Records.getAllDerivedDefinitions("Set")) {
+  for (const Record *Rec : Records.getAllDerivedDefinitions("Set")) {
     OS << Rec->getName() << " = [";
     const std::vector<Record *> *Elts = Sets.expand(Rec);
     assert(Elts && "Couldn't expand Set instance");
diff --git a/mlir/include/mlir/TableGen/GenInfo.h b/mlir/include/mlir/TableGen/GenInfo.h
index ef2e12f07df16d..d59d64223827bd 100644
--- a/mlir/include/mlir/TableGen/GenInfo.h
+++ b/mlir/include/mlir/TableGen/GenInfo.h
@@ -21,8 +21,8 @@ class RecordKeeper;
 namespace mlir {
 
 /// Generator function to invoke.
-using GenFunction = std::function<bool(const llvm::RecordKeeper &recordKeeper,
-                                       raw_ostream &os)>;
+using GenFunction =
+    std::function<bool(llvm::RecordKeeper &recordKeeper, raw_ostream &os)>;
 
 /// Structure to group information about a generator (argument to invoke via
 /// mlir-tblgen, description, and generator function).
@@ -34,7 +34,7 @@ class GenInfo {
       : arg(arg), description(description), generator(std::move(generator)) {}
 
   /// Invokes the generator and returns whether the generator failed.
-  bool invoke(const llvm::RecordKeeper &recordKeeper, raw_ostream &os) const {
+  bool invoke(llvm::RecordKeeper &recordKeeper, raw_ostream &os) const {
     assert(generator && "Cannot call generator with null generator");
     return generator(recordKeeper, os);
   }
diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
index eccd8029d950ff..feca04bff643d5 100644
--- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
@@ -690,10 +690,10 @@ class DefGenerator {
   bool emitDefs(StringRef selectedDialect);
 
 protected:
-  DefGenerator(std::vector<llvm::Record *> &&defs, raw_ostream &os,
+  DefGenerator(const std::vector<llvm::Record *> &defs, raw_ostream &os,
                StringRef defType, StringRef valueType, bool isAttrGenerator)
-      : defRecords(std::move(defs)), os(os), defType(defType),
-        valueType(valueType), isAttrGenerator(isAttrGenerator) {
+      : defRecords(defs), os(os), defType(defType), valueType(valueType),
+        isAttrGenerator(isAttrGenerator) {
     // Sort by occurrence in file.
     llvm::sort(defRecords, [](llvm::Record *lhs, llvm::Record *rhs) {
       return lhs->getID() < rhs->getID();
@@ -721,13 +721,13 @@ class DefGenerator {
 
 /// A specialized generator for AttrDefs.
 struct AttrDefGenerator : public DefGenerator {
-  AttrDefGenerator(const llvm::RecordKeeper &records, raw_ostream &os)
+  AttrDefGenerator(llvm::RecordKeeper &records, raw_ostream &os)
       : DefGenerator(records.getAllDerivedDefinitionsIfDefined("AttrDef"), os,
                      "Attr", "Attribute", /*isAttrGenerator=*/true) {}
 };
 /// A specialized generator for TypeDefs.
 struct TypeDefGenerator : public DefGenerator {
-  TypeDefGenerator(const llvm::RecordKeeper &records, raw_ostream &os)
+  TypeDefGenerator(llvm::RecordKeeper &records, raw_ostream &os)
       : DefGenerator(records.getAllDerivedDefinitionsIfDefined("TypeDef"), os,
                      "Type", "Type", /*isAttrGenerator=*/false) {}
 };
@@ -1029,7 +1029,7 @@ bool DefGenerator::emitDefs(StringRef selectedDialect) {
 
 /// Find all type constraints for which a C++ function should be generated.
 static std::vector<Constraint>
-getAllTypeConstraints(const llvm::RecordKeeper &records) {
+getAllTypeConstraints(llvm::RecordKeeper &records) {
   std::vector<Constraint> result;
   for (llvm::Record *def :
        records.getAllDerivedDefinitionsIfDefined("TypeConstraint")) {
@@ -1046,7 +1046,7 @@ getAllTypeConstraints(const llvm::RecordKeeper &records) {
   return result;
 }
 
-static void emitTypeConstraintDecls(const llvm::RecordKeeper &records,
+static void emitTypeConstraintDecls(llvm::RecordKeeper &records,
                                     raw_ostream &os) {
   static const char *const typeConstraintDecl = R"(
 bool {0}(::mlir::Type type);
@@ -1056,7 +1056,7 @@ bool {0}(::mlir::Type type);
     os << strfmt(typeConstraintDecl, *constr.getCppFunctionName());
 }
 
-static void emitTypeConstraintDefs(const llvm::RecordKeeper &records,
+static void emitTypeConstraintDefs(llvm::RecordKeeper &records,
                                    raw_ostream &os) {
   static const char *const typeConstraintDef = R"(
 bool {0}(::mlir::Type type) {
@@ -1087,13 +1087,13 @@ static llvm::cl::opt<std::string>
 
 static mlir::GenRegistration
     genAttrDefs("gen-attrdef-defs", "Generate AttrDef definitions",
-                [](const llvm::RecordKeeper &records, raw_ostream &os) {
+                [](llvm::RecordKeeper &records, raw_ostream &os) {
                   AttrDefGenerator generator(records, os);
                   return generator.emitDefs(attrDialect);
                 });
 static mlir::GenRegistration
     genAttrDecls("gen-attrdef-decls", "Generate AttrDef declarations",
-                 [](const llvm::RecordKeeper &records, raw_ostream &os) {
+                 [](llvm::RecordKeeper &records, raw_ostream &os) {
                    AttrDefGenerator generator(records, os);
                    return generator.emitDecls(attrDialect);
                  });
@@ -1109,13 +1109,13 @@ static llvm::cl::opt<std::string>
 
 static mlir::GenRegistration
     genTypeDefs("gen-typedef-defs", "Generate TypeDef definitions",
-                [](const llvm::RecordKeeper &records, raw_ostream &os) {
+                [](llvm::RecordKeeper &records, raw_ostream &os) {
                   TypeDefGenerator generator(records, os);
                   return generator.emitDefs(typeDialect);
                 });
 static mlir::GenRegistration
     genTypeDecls("gen-typedef-decls", "Generate TypeDef declarations",
-                 [](const llvm::RecordKeeper &records, raw_ostream &os) {
+                 [](llvm::RecordKeeper &records, raw_ostream &os) {
                    TypeDefGenerator generator(records, os);
                    return generator.emitDecls(typeDialect);
                  });
@@ -1123,14 +1123,14 @@ static mlir::GenRegistration
 static mlir::GenRegistration
     genTypeConstrDefs("gen-type-constraint-defs",
                       "Generate type constraint definitions",
-                      [](const llvm::RecordKeeper &records, raw_ostream &os) {
+                      [](llvm::RecordKeeper &records, raw_ostream &os) {
                         emitTypeConstraintDefs(records, os);
                         return false;
                       });
 static mlir::GenRegistration
     genTypeConstrDecls("gen-type-constraint-decls",
                        "Generate type constraint declarations",
-                       [](const llvm::RecordKeeper &records, raw_ostream &os) {
+                       [](llvm::RecordKeeper &records, raw_ostream &os) {
                          emitTypeConstraintDecls(records, os);
                          return false;
                        });
diff --git a/mlir/tools/mlir-tblgen/OmpOpGen.cpp b/mlir/tools/mlir-tblgen/OmpOpGen.cpp
index ffa2e17cc8f916..b7f6ca975a9a34 100644
--- a/mlir/tools/mlir-tblgen/OmpOpGen.cpp
+++ b/mlir/tools/mlir-tblgen/OmpOpGen.cpp
@@ -149,7 +149,7 @@ static void verifyClause(Record *op, Record *clause) {
 
 /// Verify that all properties of `OpenMP_Clause`s of records deriving from
 /// `OpenMP_Op`s have been inherited by the latter.
-static bool verifyDecls(const RecordKeeper &recordKeeper, raw_ostream &) {
+static bool verifyDecls(RecordKeeper &recordKeeper, raw_ostream &) {
   for (Record *op : recordKeeper.getAllDerivedDefinitions("OpenMP_Op")) {
     for (Record *clause : op->getValueAsListOfDefs("clauseList"))
       verifyClause(op, clause);
diff --git a/mlir/tools/mlir-tblgen/OpDocGen.cpp b/mlir/tools/mlir-tblgen/OpDocGen.cpp
index 71df80cd110f15..066e5b24f5a3c1 100644
--- a/mlir/tools/mlir-tblgen/OpDocGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDocGen.cpp
@@ -282,7 +282,7 @@ static void emitSourceLink(StringRef inputFilename, raw_ostream &os) {
      << inputFromMlirInclude << ")\n\n";
 }
 
-static void emitOpDoc(const RecordKeeper &recordKeeper, raw_ostream &os) {
+static void emitOpDoc(RecordKeeper &recordKeeper, raw_ostream &os) {
   auto opDefs = getRequestedOpDefinitions(recordKeeper);
 
   os << "<!-- Autogenerated by mlir-tblgen; don't manually edit -->\n";
@@ -371,8 +371,8 @@ static void emitAttrOrTypeDefDoc(const AttrOrTypeDef &def, raw_ostream &os) {
   os << "\n";
 }
 
-static void emitAttrOrTypeDefDoc(const RecordKeeper &recordKeeper,
-                                 raw_ostream &os, StringRef recordTypeName) {
+static void emitAttrOrTypeDefDoc(RecordKeeper &recordKeeper, raw_ostream &os,
+                                 StringRef recordTypeName) {
   std::vector<llvm::Record *> defs =
       recordKeeper.getAllDerivedDefinitions(recordTypeName);
 
@@ -405,7 +405,7 @@ static void emitEnumDoc(const EnumAttr &def, raw_ostream &os) {
   os << "\n";
 }
 
-static void emitEnumDoc(const RecordKeeper &recordKeeper, raw_ostream &os) {
+static void emitEnumDoc(RecordKeeper &recordKeeper, raw_ostream &os) {
   std::vector<llvm::Record *> defs =
       recordKeeper.getAllDerivedDefinitions("EnumAttr");
 
@@ -518,7 +518,7 @@ static void emitDialectDoc(const Dialect &dialect, StringRef inputFilename,
             os);
 }
 
-static bool emitDialectDoc(const RecordKeeper &recordKeeper, raw_ostream &os) {
+static bool emitDialectDoc(RecordKeeper &recordKeeper, raw_ostream &os) {
   std::vector<Record *> dialectDefs =
       recordKeeper.getAllDerivedDefinitionsIfDefined("Dialect");
   SmallVector<Dialect> dialects(dialectDefs.begin(), dialectDefs.end());
@@ -617,34 +617,34 @@ static bool emitDialectDoc(const RecordKeeper &recordKeeper, raw_ostream &os) {
 static mlir::GenRegistration
     genAttrRegister("gen-attrdef-doc",
                     "Generate dialect attribute documentation",
-                    [](const RecordKeeper &records, raw_ostream &os) {
+                    [](RecordKeeper &records, raw_ostream &os) {
                       emitAttrOrTypeDefDoc(records, os, "AttrDef");
                       return false;
                     });
 
 static mlir::GenRegistration
     genOpRegister("gen-op-doc", "Generate dialect documentation",
-                  [](const RecordKeeper &records, raw_ostream &os) {
+                  [](RecordKeeper &records, raw_ostream &os) {
                     emitOpDoc(records, os);
                     return false;
                   });
 
 static mlir::GenRegistration
     genTypeRegister("gen-typedef-doc", "Generate dialect type documentation",
-                    [](const RecordKeeper &records, raw_ostream &os) {
+                    [](RecordKeeper &records, raw_ostream &os) {
                       emitAttrOrTypeDefDoc(records, os, "TypeDef");
                       return false;
                     });
 
 static mlir::GenRegistration
     genEnumRegister("gen-enum-doc", "Generate dialect enum documentation",
-                    [](const RecordKeeper &records, raw_ostream &os) {
+                    [](RecordKeeper &records, raw_ostream &os) {
                       emitEnumDoc(records, os);
                       return false;
                     });
 
 static mlir::GenRegistration
     genRegister("gen-dialect-doc", "Generate dialect documentation",
-                [](const RecordKeeper &records, raw_ostream &os) {
+                [](RecordKeeper &records, raw_ostream &os) {
                   return emitDialectDoc(records, os);
                 });
diff --git a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
index 4b06b92fbc8a8e..00f21a1cefbdd8 100644
--- a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
@@ -62,8 +62,7 @@ static void emitMethodNameAndArgs(const InterfaceMethod &method,
 /// Get an array of all OpInterface definitions but exclude those subclassing
 /// "DeclareOpInterfaceMethods".
 static std::vector<llvm::Record *>
-getAllInterfaceDefinitions(const llvm::RecordKeeper &recordKeeper,
-                           StringRef name) {
+getAllInterfaceDefinitions(llvm::RecordKeeper &recordKeeper, StringRef name) {
   std::vector<llvm::Record *> defs =
       recordKeeper.getAllDerivedDefinitions((name + "Interface").str());
 
@@ -118,7 +117,7 @@ class InterfaceGenerator {
 
 /// A specialized generator for attribute interfaces.
 struct AttrInterfaceGenerator : public InterfaceGenerator {
-  AttrInterfaceGenerator(const llvm::RecordKeeper &records, raw_ostream &os)
+  AttrInterfaceGenerator(llvm::RecordKeeper &records, raw_ostream &os)
       : InterfaceGenerator(getAllInterfaceDefinitions(records, "Attr"), os) {
     valueType = "::mlir::Attribute";
     interfaceBaseType = "AttributeInterface";
@@ -133,7 +132,7 @@ struct AttrInterfaceGenerator : public InterfaceGenerator {
 };
 /// A specialized generator for operation interfaces.
 struct OpInterfaceGenerator : public InterfaceGenerator {
-  OpInterfaceGenerator(const llvm::RecordKeeper &records, raw_ostream &os)
+  OpInterfaceGenerator(llvm::RecordKeeper &records, raw_ostream &os)
       : InterfaceGenerator(getAllInterfaceDefinitions(records, "Op"), os) {
     valueType = "::mlir::Operation *";
     interfaceBaseType = "OpInterface";
@@ -149,7 +148,7 @@ struct OpInterfaceGenerator : public InterfaceGenerator {
 };
 /// A specialized generator for type interfaces.
 struct TypeInterfaceGenerator : public InterfaceGenerator {
-  TypeInterfaceGenerator(const llvm::RecordKeeper &records, raw_ostream &os)
+  TypeInterfaceGenerator(llvm::RecordKeeper &records, raw_ostream &os)
       : InterfaceGenerator(getAllInterfaceDefinitions(records, "Type"), os) {
     valueType = "::mlir::Type";
     interfaceBaseType = "TypeInterface";
@@ -684,15 +683,15 @@ struct InterfaceGenRegistration {
         genDefDesc(("Generate " + genDesc + " interface definitions").str()),
         genDocDesc(("Generate " + genDesc + " interface documentation").str()),
         genDecls(genDeclArg, genDeclDesc,
-                 [](const llvm::RecordKeeper &records, raw_ostream &os) {
+                 [](llvm::RecordKeeper &records, raw_ostream &os) {
                    return GeneratorT(records, os).emitInterfaceDecls();
                  }),
         genDefs(genDefArg, genDefDesc,
-                [](const llvm::RecordKeeper &records, raw_ostream &os) {
+                [](llvm::RecordKeeper &records, raw_ostream &os) {
                   return GeneratorT(records, os).emitInterfaceDefs();
                 }),
         genDocs(genDocArg, genDocDesc,
-                [](const llvm::RecordKeeper &records, raw_ostream &os) {
+                [](llvm::RecordKeeper &records, raw_ostream &os) {
                   return GeneratorT(records, os).emitInterfaceDocs();
                 }) {}
 
diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
index 2c79ba2cd6353e..401f02246ed235 100644
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -268,7 +268,7 @@ class PatternEmitter {
 // inlining them.
 class StaticMatcherHelper {
 public:
-  StaticMatcherHelper(raw_ostream &os, const RecordKeeper &recordKeeper,
+  StaticMatcherHelper(raw_ostream &os, RecordKeeper &recordKeeper,
                       RecordOperatorMap &mapper);
 
   // Determine if we should inline the match logic or delegate to a static
@@ -1886,7 +1886,7 @@ void PatternEmitter::createAggregateLocalVarsForOpArgs(
 }
 
 StaticMatcherHelper::StaticMatcherHelper(raw_ostream &os,
-                                         const RecordKeeper &recordKeeper,
+                                         RecordKeeper &recordKeeper,
                                          RecordOperatorMap &mapper)
     : opMap(mapper), staticVerifierEmitter(os, recordKeeper) {}
 
@@ -1951,7 +1951,7 @@ StringRef StaticMatcherHelper::getVerifierName(DagLeaf leaf) {
   return staticVerifierEmitter.getTypeConstraintFn(leaf.getAsConstraint());
 }
 
-static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) {
+static void emitRewriters(RecordKeeper &recordKeeper, raw_ostream &os) {
   emitSourceFileHeader("Rewriters", os, recordKeeper);
 
   const auto &patterns = recordKeeper.getAllDerivedDefinitions("Pattern");
@@ -2001,7 +2001,7 @@ static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) {
 
 static mlir::GenRegistration
     genRewriters("gen-rewriters", "Generate pattern rewriters",
-                 [](const RecordKeeper &records, raw_ostream &os) {
+                 [](RecordKeeper &records, raw_ostream &os) {
                    emitRewriters(records, os);
                    return false;
                  });



More information about the Mlir-commits mailing list