[Mlir-commits] [mlir] [MLIR][TableGen] Use const pointers for various Init classes. (PR #112316)

Rahul Joshi llvmlistbot at llvm.org
Mon Oct 14 23:28:16 PDT 2024


https://github.com/jurahul created https://github.com/llvm/llvm-project/pull/112316

Use const pointers for various Init classes.

From 6958b2433fed41d37a509d40180a8ae173c2a48a Mon Sep 17 00:00:00 2001
From: Rahul Joshi <rjoshi at nvidia.com>
Date: Mon, 14 Oct 2024 23:25:48 -0700
Subject: [PATCH] [MLIR][TableGen] Use const Init pointer

---
 mlir/include/mlir/TableGen/AttrOrTypeDef.h    |  2 +-
 mlir/include/mlir/TableGen/Dialect.h          |  2 +-
 mlir/include/mlir/TableGen/Operator.h         | 15 ++++----
 mlir/lib/TableGen/AttrOrTypeDef.cpp           | 12 +++---
 mlir/lib/TableGen/Attribute.cpp               |  2 +-
 mlir/lib/TableGen/Dialect.cpp                 |  2 +-
 mlir/lib/TableGen/Interfaces.cpp              |  6 +--
 mlir/lib/TableGen/Operator.cpp                | 21 +++++-----
 mlir/lib/TableGen/Pattern.cpp                 |  2 +-
 mlir/lib/TableGen/Type.cpp                    |  2 +-
 mlir/lib/Tools/mlir-tblgen/MlirTblgenMain.cpp | 16 ++++----
 mlir/tools/mlir-tblgen/BytecodeDialectGen.cpp | 38 ++++++++++---------
 mlir/tools/mlir-tblgen/DialectGen.cpp         |  9 +++--
 mlir/tools/mlir-tblgen/OmpOpGen.cpp           | 19 ++++++----
 14 files changed, 80 insertions(+), 68 deletions(-)

diff --git a/mlir/include/mlir/TableGen/AttrOrTypeDef.h b/mlir/include/mlir/TableGen/AttrOrTypeDef.h
index 36744c85bc7086..c3d730e42ef70e 100644
--- a/mlir/include/mlir/TableGen/AttrOrTypeDef.h
+++ b/mlir/include/mlir/TableGen/AttrOrTypeDef.h
@@ -105,7 +105,7 @@ class AttrOrTypeParameter {
   std::optional<StringRef> getDefaultValue() const;
 
   /// Return the underlying def of this parameter.
-  llvm::Init *getDef() const;
+  const llvm::Init *getDef() const;
 
   /// The parameter is pointer-comparable.
   bool operator==(const AttrOrTypeParameter &other) const {
diff --git a/mlir/include/mlir/TableGen/Dialect.h b/mlir/include/mlir/TableGen/Dialect.h
index 3530d240c976c6..ea8f40555e4451 100644
--- a/mlir/include/mlir/TableGen/Dialect.h
+++ b/mlir/include/mlir/TableGen/Dialect.h
@@ -92,7 +92,7 @@ class Dialect {
   /// dialect.
   bool usePropertiesForAttributes() const;
 
-  llvm::DagInit *getDiscardableAttributes() const;
+  const llvm::DagInit *getDiscardableAttributes() const;
 
   const llvm::Record *getDef() const { return def; }
 
diff --git a/mlir/include/mlir/TableGen/Operator.h b/mlir/include/mlir/TableGen/Operator.h
index 768291a3a7267b..9e570373d9cd32 100644
--- a/mlir/include/mlir/TableGen/Operator.h
+++ b/mlir/include/mlir/TableGen/Operator.h
@@ -119,14 +119,15 @@ class Operator {
 
   /// A utility iterator over a list of variable decorators.
   struct VariableDecoratorIterator
-      : public llvm::mapped_iterator<llvm::Init *const *,
-                                     VariableDecorator (*)(llvm::Init *)> {
+      : public llvm::mapped_iterator<const llvm::Init *const *,
+                                     VariableDecorator (*)(
+                                         const llvm::Init *)> {
     /// Initializes the iterator to the specified iterator.
-    VariableDecoratorIterator(llvm::Init *const *it)
-        : llvm::mapped_iterator<llvm::Init *const *,
-                                VariableDecorator (*)(llvm::Init *)>(it,
-                                                                     &unwrap) {}
-    static VariableDecorator unwrap(llvm::Init *init);
+    VariableDecoratorIterator(const llvm::Init *const *it)
+        : llvm::mapped_iterator<const llvm::Init *const *,
+                                VariableDecorator (*)(const llvm::Init *)>(
+              it, &unwrap) {}
+    static VariableDecorator unwrap(const llvm::Init *init);
   };
   using var_decorator_iterator = VariableDecoratorIterator;
   using var_decorator_range = llvm::iterator_range<VariableDecoratorIterator>;
diff --git a/mlir/lib/TableGen/AttrOrTypeDef.cpp b/mlir/lib/TableGen/AttrOrTypeDef.cpp
index 9b9d9fd2317d99..e72ca155bcf765 100644
--- a/mlir/lib/TableGen/AttrOrTypeDef.cpp
+++ b/mlir/lib/TableGen/AttrOrTypeDef.cpp
@@ -40,7 +40,7 @@ AttrOrTypeDef::AttrOrTypeDef(const llvm::Record *def) : def(def) {
   auto *builderList =
       dyn_cast_or_null<llvm::ListInit>(def->getValueInit("builders"));
   if (builderList && !builderList->empty()) {
-    for (llvm::Init *init : builderList->getValues()) {
+    for (const llvm::Init *init : builderList->getValues()) {
       AttrOrTypeBuilder builder(cast<llvm::DefInit>(init)->getDef(),
                                 def->getLoc());
 
@@ -58,8 +58,8 @@ AttrOrTypeDef::AttrOrTypeDef(const llvm::Record *def) : def(def) {
   if (auto *traitList = def->getValueAsListInit("traits")) {
     SmallPtrSet<const llvm::Init *, 32> traitSet;
     traits.reserve(traitSet.size());
-    llvm::unique_function<void(llvm::ListInit *)> processTraitList =
-        [&](llvm::ListInit *traitList) {
+    llvm::unique_function<void(const llvm::ListInit *)> processTraitList =
+        [&](const llvm::ListInit *traitList) {
           for (auto *traitInit : *traitList) {
             if (!traitSet.insert(traitInit).second)
               continue;
@@ -335,7 +335,9 @@ std::optional<StringRef> AttrOrTypeParameter::getDefaultValue() const {
   return result && !result->empty() ? result : std::nullopt;
 }
 
-llvm::Init *AttrOrTypeParameter::getDef() const { return def->getArg(index); }
+const llvm::Init *AttrOrTypeParameter::getDef() const {
+  return def->getArg(index);
+}
 
 std::optional<Constraint> AttrOrTypeParameter::getConstraint() const {
   if (auto *param = dyn_cast<llvm::DefInit>(getDef()))
@@ -349,7 +351,7 @@ std::optional<Constraint> AttrOrTypeParameter::getConstraint() const {
 //===----------------------------------------------------------------------===//
 
 bool AttributeSelfTypeParameter::classof(const AttrOrTypeParameter *param) {
-  llvm::Init *paramDef = param->getDef();
+  const llvm::Init *paramDef = param->getDef();
   if (auto *paramDefInit = dyn_cast<llvm::DefInit>(paramDef))
     return paramDefInit->getDef()->isSubClassOf("AttributeSelfTypeParameter");
   return false;
diff --git a/mlir/lib/TableGen/Attribute.cpp b/mlir/lib/TableGen/Attribute.cpp
index de930cb4007032..887553bca66102 100644
--- a/mlir/lib/TableGen/Attribute.cpp
+++ b/mlir/lib/TableGen/Attribute.cpp
@@ -126,7 +126,7 @@ StringRef Attribute::getDerivedCodeBody() const {
 Dialect Attribute::getDialect() const {
   const llvm::RecordVal *record = def->getValue("dialect");
   if (record && record->getValue()) {
-    if (DefInit *init = dyn_cast<DefInit>(record->getValue()))
+    if (const DefInit *init = dyn_cast<DefInit>(record->getValue()))
       return Dialect(init->getDef());
   }
   return Dialect(nullptr);
diff --git a/mlir/lib/TableGen/Dialect.cpp b/mlir/lib/TableGen/Dialect.cpp
index 081f6e56f9ded4..ef39818e439b3e 100644
--- a/mlir/lib/TableGen/Dialect.cpp
+++ b/mlir/lib/TableGen/Dialect.cpp
@@ -106,7 +106,7 @@ bool Dialect::usePropertiesForAttributes() const {
   return def->getValueAsBit("usePropertiesForAttributes");
 }
 
-llvm::DagInit *Dialect::getDiscardableAttributes() const {
+const llvm::DagInit *Dialect::getDiscardableAttributes() const {
   return def->getValueAsDag("discardableAttrs");
 }
 
diff --git a/mlir/lib/TableGen/Interfaces.cpp b/mlir/lib/TableGen/Interfaces.cpp
index a209b003b0f3bb..4a6709a43d0a8f 100644
--- a/mlir/lib/TableGen/Interfaces.cpp
+++ b/mlir/lib/TableGen/Interfaces.cpp
@@ -22,7 +22,7 @@ using namespace mlir::tblgen;
 //===----------------------------------------------------------------------===//
 
 InterfaceMethod::InterfaceMethod(const llvm::Record *def) : def(def) {
-  llvm::DagInit *args = def->getValueAsDag("arguments");
+  const llvm::DagInit *args = def->getValueAsDag("arguments");
   for (unsigned i = 0, e = args->getNumArgs(); i != e; ++i) {
     arguments.push_back(
         {llvm::cast<llvm::StringInit>(args->getArg(i))->getValue(),
@@ -78,7 +78,7 @@ Interface::Interface(const llvm::Record *def) : def(def) {
 
   // Initialize the interface methods.
   auto *listInit = dyn_cast<llvm::ListInit>(def->getValueInit("methods"));
-  for (llvm::Init *init : listInit->getValues())
+  for (const llvm::Init *init : listInit->getValues())
     methods.emplace_back(cast<llvm::DefInit>(init)->getDef());
 
   // Initialize the interface base classes.
@@ -98,7 +98,7 @@ Interface::Interface(const llvm::Record *def) : def(def) {
         baseInterfaces.push_back(std::make_unique<Interface>(baseInterface));
         basesAdded.insert(baseInterface.getName());
       };
-  for (llvm::Init *init : basesInit->getValues())
+  for (const llvm::Init *init : basesInit->getValues())
     addBaseInterfaceFn(Interface(cast<llvm::DefInit>(init)->getDef()));
 }
 
diff --git a/mlir/lib/TableGen/Operator.cpp b/mlir/lib/TableGen/Operator.cpp
index 6a33ff5ecd6721..86670e9f87127c 100644
--- a/mlir/lib/TableGen/Operator.cpp
+++ b/mlir/lib/TableGen/Operator.cpp
@@ -161,7 +161,7 @@ std::string Operator::getQualCppClassName() const {
 StringRef Operator::getCppNamespace() const { return cppNamespace; }
 
 int Operator::getNumResults() const {
-  DagInit *results = def.getValueAsDag("results");
+  const DagInit *results = def.getValueAsDag("results");
   return results->getNumArgs();
 }
 
@@ -198,12 +198,12 @@ auto Operator::getResults() const -> const_value_range {
 }
 
 TypeConstraint Operator::getResultTypeConstraint(int index) const {
-  DagInit *results = def.getValueAsDag("results");
+  const DagInit *results = def.getValueAsDag("results");
   return TypeConstraint(cast<DefInit>(results->getArg(index)));
 }
 
 StringRef Operator::getResultName(int index) const {
-  DagInit *results = def.getValueAsDag("results");
+  const DagInit *results = def.getValueAsDag("results");
   return results->getArgNameStr(index);
 }
 
@@ -241,7 +241,7 @@ Operator::arg_range Operator::getArgs() const {
 }
 
 StringRef Operator::getArgName(int index) const {
-  DagInit *argumentValues = def.getValueAsDag("arguments");
+  const DagInit *argumentValues = def.getValueAsDag("arguments");
   return argumentValues->getArgNameStr(index);
 }
 
@@ -557,7 +557,7 @@ void Operator::populateOpStructure() {
   auto *opVarClass = recordKeeper.getClass("OpVariable");
   numNativeAttributes = 0;
 
-  DagInit *argumentValues = def.getValueAsDag("arguments");
+  const DagInit *argumentValues = def.getValueAsDag("arguments");
   unsigned numArgs = argumentValues->getNumArgs();
 
   // Mapping from name of to argument or result index. Arguments are indexed
@@ -721,8 +721,8 @@ void Operator::populateOpStructure() {
                   " to precede it in traits list");
     };
 
-    std::function<void(llvm::ListInit *)> insert;
-    insert = [&](llvm::ListInit *traitList) {
+    std::function<void(const llvm::ListInit *)> insert;
+    insert = [&](const llvm::ListInit *traitList) {
       for (auto *traitInit : *traitList) {
         auto *def = cast<DefInit>(traitInit)->getDef();
         if (def->isSubClassOf("TraitList")) {
@@ -780,7 +780,7 @@ void Operator::populateOpStructure() {
   auto *builderList =
       dyn_cast_or_null<llvm::ListInit>(def.getValueInit("builders"));
   if (builderList && !builderList->empty()) {
-    for (llvm::Init *init : builderList->getValues())
+    for (const llvm::Init *init : builderList->getValues())
       builders.emplace_back(cast<llvm::DefInit>(init)->getDef(), def.getLoc());
   } else if (skipDefaultBuilders()) {
     PrintFatalError(
@@ -818,7 +818,8 @@ bool Operator::hasAssemblyFormat() const {
 }
 
 StringRef Operator::getAssemblyFormat() const {
-  return TypeSwitch<llvm::Init *, StringRef>(def.getValueInit("assemblyFormat"))
+  return TypeSwitch<const llvm::Init *, StringRef>(
+             def.getValueInit("assemblyFormat"))
       .Case<llvm::StringInit>([&](auto *init) { return init->getValue(); });
 }
 
@@ -832,7 +833,7 @@ void Operator::print(llvm::raw_ostream &os) const {
   }
 }
 
-auto Operator::VariableDecoratorIterator::unwrap(llvm::Init *init)
+auto Operator::VariableDecoratorIterator::unwrap(const llvm::Init *init)
     -> VariableDecorator {
   return VariableDecorator(cast<llvm::DefInit>(init)->getDef());
 }
diff --git a/mlir/lib/TableGen/Pattern.cpp b/mlir/lib/TableGen/Pattern.cpp
index 6437839ef20849..bee20354387fd6 100644
--- a/mlir/lib/TableGen/Pattern.cpp
+++ b/mlir/lib/TableGen/Pattern.cpp
@@ -700,7 +700,7 @@ int Pattern::getBenefit() const {
   // The initial benefit value is a heuristic with number of ops in the source
   // pattern.
   int initBenefit = getSourcePattern().getNumOps();
-  llvm::DagInit *delta = def.getValueAsDag("benefitDelta");
+  const llvm::DagInit *delta = def.getValueAsDag("benefitDelta");
   if (delta->getNumArgs() != 1 || !isa<llvm::IntInit>(delta->getArg(0))) {
     PrintFatalError(&def,
                     "The 'addBenefit' takes and only takes one integer value");
diff --git a/mlir/lib/TableGen/Type.cpp b/mlir/lib/TableGen/Type.cpp
index cda752297988bb..c3b813ec598d0a 100644
--- a/mlir/lib/TableGen/Type.cpp
+++ b/mlir/lib/TableGen/Type.cpp
@@ -50,7 +50,7 @@ std::optional<StringRef> TypeConstraint::getBuilderCall() const {
   const llvm::RecordVal *builderCall = baseType->getValue("builderCall");
   if (!builderCall || !builderCall->getValue())
     return std::nullopt;
-  return TypeSwitch<llvm::Init *, std::optional<StringRef>>(
+  return TypeSwitch<const llvm::Init *, std::optional<StringRef>>(
              builderCall->getValue())
       .Case<llvm::StringInit>([&](auto *init) {
         StringRef value = init->getValue();
diff --git a/mlir/lib/Tools/mlir-tblgen/MlirTblgenMain.cpp b/mlir/lib/Tools/mlir-tblgen/MlirTblgenMain.cpp
index 7119324dd125d5..20ad4292a548bf 100644
--- a/mlir/lib/Tools/mlir-tblgen/MlirTblgenMain.cpp
+++ b/mlir/lib/Tools/mlir-tblgen/MlirTblgenMain.cpp
@@ -30,8 +30,8 @@ enum DeprecatedAction { None, Warn, Error };
 static DeprecatedAction actionOnDeprecatedValue;
 
 // Returns if there is a use of `deprecatedInit` in `field`.
-static bool findUse(Init *field, Init *deprecatedInit,
-                    llvm::DenseMap<Init *, bool> &known) {
+static bool findUse(const Init *field, const Init *deprecatedInit,
+                    llvm::DenseMap<const Init *, bool> &known) {
   if (field == deprecatedInit)
     return true;
 
@@ -64,13 +64,13 @@ static bool findUse(Init *field, Init *deprecatedInit,
     if (findUse(dagInit->getOperator(), deprecatedInit, known))
       return memoize(true);
 
-    return memoize(llvm::any_of(dagInit->getArgs(), [&](Init *arg) {
+    return memoize(llvm::any_of(dagInit->getArgs(), [&](const Init *arg) {
       return findUse(arg, deprecatedInit, known);
     }));
   }
 
-  if (ListInit *li = dyn_cast<ListInit>(field)) {
-    return memoize(llvm::any_of(li->getValues(), [&](Init *jt) {
+  if (const ListInit *li = dyn_cast<ListInit>(field)) {
+    return memoize(llvm::any_of(li->getValues(), [&](const Init *jt) {
       return findUse(jt, deprecatedInit, known);
     }));
   }
@@ -83,8 +83,8 @@ static bool findUse(Init *field, Init *deprecatedInit,
 }
 
 // Returns if there is a use of `deprecatedInit` in `record`.
-static bool findUse(Record &record, Init *deprecatedInit,
-                    llvm::DenseMap<Init *, bool> &known) {
+static bool findUse(Record &record, const Init *deprecatedInit,
+                    llvm::DenseMap<const Init *, bool> &known) {
   return llvm::any_of(record.getValues(), [&](const RecordVal &val) {
     return findUse(val.getValue(), deprecatedInit, known);
   });
@@ -100,7 +100,7 @@ static void warnOfDeprecatedUses(const RecordKeeper &records) {
     if (!r || !r->getValue())
       continue;
 
-    llvm::DenseMap<Init *, bool> hasUse;
+    llvm::DenseMap<const Init *, bool> hasUse;
     if (auto *si = dyn_cast<StringInit>(r->getValue())) {
       for (auto &jt : records.getDefs()) {
         // Skip anonymous defs.
diff --git a/mlir/tools/mlir-tblgen/BytecodeDialectGen.cpp b/mlir/tools/mlir-tblgen/BytecodeDialectGen.cpp
index 86ebaf2cf27dfe..6a3d5a25e28cd9 100644
--- a/mlir/tools/mlir-tblgen/BytecodeDialectGen.cpp
+++ b/mlir/tools/mlir-tblgen/BytecodeDialectGen.cpp
@@ -46,8 +46,9 @@ class Generator {
 private:
   /// Emits parse calls to construct given kind.
   void emitParseHelper(StringRef kind, StringRef returnType, StringRef builder,
-                       ArrayRef<Init *> args, ArrayRef<std::string> argNames,
-                       StringRef failure, mlir::raw_indented_ostream &ios);
+                       ArrayRef<const Init *> args,
+                       ArrayRef<std::string> argNames, StringRef failure,
+                       mlir::raw_indented_ostream &ios);
 
   /// Emits print instructions.
   void emitPrintHelper(const Record *memberRec, StringRef kind,
@@ -135,10 +136,12 @@ void Generator::emitParse(StringRef kind, const Record &x) {
       R"(static {0} read{1}(MLIRContext* context, DialectBytecodeReader &reader) )";
   mlir::raw_indented_ostream os(output);
   std::string returnType = getCType(&x);
-  os << formatv(head, kind == "attribute" ? "::mlir::Attribute" : "::mlir::Type", x.getName());
-  DagInit *members = x.getValueAsDag("members");
-  SmallVector<std::string> argNames =
-      llvm::to_vector(map_range(members->getArgNames(), [](StringInit *init) {
+  os << formatv(head,
+                kind == "attribute" ? "::mlir::Attribute" : "::mlir::Type",
+                x.getName());
+  const DagInit *members = x.getValueAsDag("members");
+  SmallVector<std::string> argNames = llvm::to_vector(
+      map_range(members->getArgNames(), [](const StringInit *init) {
         return init->getAsUnquotedString();
       }));
   StringRef builder = x.getValueAsString("cBuilder").trim();
@@ -148,7 +151,7 @@ void Generator::emitParse(StringRef kind, const Record &x) {
 }
 
 void printParseConditional(mlir::raw_indented_ostream &ios,
-                           ArrayRef<Init *> args,
+                           ArrayRef<const Init *> args,
                            ArrayRef<std::string> argNames) {
   ios << "if ";
   auto parenScope = ios.scope("(", ") {");
@@ -159,7 +162,7 @@ void printParseConditional(mlir::raw_indented_ostream &ios,
   };
 
   auto parsedArgs =
-      llvm::to_vector(make_filter_range(args, [](Init *const attr) {
+      llvm::to_vector(make_filter_range(args, [](const Init *const attr) {
         const Record *def = cast<DefInit>(attr)->getDef();
         if (def->isSubClassOf("Array"))
           return true;
@@ -168,7 +171,7 @@ void printParseConditional(mlir::raw_indented_ostream &ios,
 
   interleave(
       zip(parsedArgs, argNames),
-      [&](std::tuple<llvm::Init *&, const std::string &> it) {
+      [&](std::tuple<const Init *&, const std::string &> it) {
         const Record *attr = cast<DefInit>(std::get<0>(it))->getDef();
         std::string parser;
         if (auto optParser = attr->getValueAsOptionalString("cParser")) {
@@ -196,7 +199,7 @@ void printParseConditional(mlir::raw_indented_ostream &ios,
 }
 
 void Generator::emitParseHelper(StringRef kind, StringRef returnType,
-                                StringRef builder, ArrayRef<Init *> args,
+                                StringRef builder, ArrayRef<const Init *> args,
                                 ArrayRef<std::string> argNames,
                                 StringRef failure,
                                 mlir::raw_indented_ostream &ios) {
@@ -210,7 +213,7 @@ void Generator::emitParseHelper(StringRef kind, StringRef returnType,
   // Print decls.
   std::string lastCType = "";
   for (auto [arg, name] : zip(args, argNames)) {
-    DefInit *first = dyn_cast<DefInit>(arg);
+    const DefInit *first = dyn_cast<DefInit>(arg);
     if (!first)
       PrintFatalError("Unexpected type for " + name);
     const Record *def = first->getDef();
@@ -251,13 +254,14 @@ void Generator::emitParseHelper(StringRef kind, StringRef returnType,
     std::string returnType = getCType(def);
     ios << "auto " << listHelperName(name) << " = [&]() -> FailureOr<"
         << returnType << "> ";
-    SmallVector<Init *> args;
+    SmallVector<const Init *> args;
     SmallVector<std::string> argNames;
     if (def->isSubClassOf("CompositeBytecode")) {
-      DagInit *members = def->getValueAsDag("members");
-      args = llvm::to_vector(members->getArgs());
+      const DagInit *members = def->getValueAsDag("members");
+      args = llvm::to_vector(map_range(
+          members->getArgs(), [](Init *init) { return (const Init *)init; }));
       argNames = llvm::to_vector(
-          map_range(members->getArgNames(), [](StringInit *init) {
+          map_range(members->getArgNames(), [](const StringInit *init) {
             return init->getAsUnquotedString();
           }));
     } else {
@@ -332,7 +336,7 @@ void Generator::emitPrint(StringRef kind, StringRef type,
     auto *members = rec->getValueAsDag("members");
     for (auto [arg, name] :
          llvm::zip(members->getArgs(), members->getArgNames())) {
-      DefInit *def = dyn_cast<DefInit>(arg);
+      const DefInit *def = dyn_cast<DefInit>(arg);
       assert(def);
       const Record *memberRec = def->getDef();
       emitPrintHelper(memberRec, kind, kind, name->getAsUnquotedString(), os);
@@ -385,7 +389,7 @@ void Generator::emitPrintHelper(const Record *memberRec, StringRef kind,
     auto *members = memberRec->getValueAsDag("members");
     for (auto [arg, argName] :
          zip(members->getArgs(), members->getArgNames())) {
-      DefInit *def = dyn_cast<DefInit>(arg);
+      const DefInit *def = dyn_cast<DefInit>(arg);
       assert(def);
       emitPrintHelper(def->getDef(), kind, parent,
                       argName->getAsUnquotedString(), ios);
diff --git a/mlir/tools/mlir-tblgen/DialectGen.cpp b/mlir/tools/mlir-tblgen/DialectGen.cpp
index 55c3d9da259005..414cad5e1dcc2e 100644
--- a/mlir/tools/mlir-tblgen/DialectGen.cpp
+++ b/mlir/tools/mlir-tblgen/DialectGen.cpp
@@ -46,10 +46,10 @@ using DialectFilterIterator =
 } // namespace
 
 static void populateDiscardableAttributes(
-    Dialect &dialect, llvm::DagInit *discardableAttrDag,
+    Dialect &dialect, const llvm::DagInit *discardableAttrDag,
     SmallVector<std::pair<std::string, std::string>> &discardableAttributes) {
   for (int i : llvm::seq<int>(0, discardableAttrDag->getNumArgs())) {
-    llvm::Init *arg = discardableAttrDag->getArg(i);
+    const llvm::Init *arg = discardableAttrDag->getArg(i);
 
     StringRef givenName = discardableAttrDag->getArgNameStr(i);
     if (givenName.empty())
@@ -271,7 +271,8 @@ static void emitDialectDecl(Dialect &dialect, raw_ostream &os) {
     if (dialect.hasOperationInterfaceFallback())
       os << operationInterfaceFallbackDecl;
 
-    llvm::DagInit *discardableAttrDag = dialect.getDiscardableAttributes();
+    const llvm::DagInit *discardableAttrDag =
+        dialect.getDiscardableAttributes();
     SmallVector<std::pair<std::string, std::string>> discardableAttributes;
     populateDiscardableAttributes(dialect, discardableAttrDag,
                                   discardableAttributes);
@@ -370,7 +371,7 @@ static void emitDialectDef(Dialect &dialect, const RecordKeeper &records,
   StringRef superClassName =
       dialect.isExtensible() ? "ExtensibleDialect" : "Dialect";
 
-  llvm::DagInit *discardableAttrDag = dialect.getDiscardableAttributes();
+  const llvm::DagInit *discardableAttrDag = dialect.getDiscardableAttributes();
   SmallVector<std::pair<std::string, std::string>> discardableAttributes;
   populateDiscardableAttributes(dialect, discardableAttrDag,
                                 discardableAttributes);
diff --git a/mlir/tools/mlir-tblgen/OmpOpGen.cpp b/mlir/tools/mlir-tblgen/OmpOpGen.cpp
index 1c20a6a9bcf4e8..8716667723a373 100644
--- a/mlir/tools/mlir-tblgen/OmpOpGen.cpp
+++ b/mlir/tools/mlir-tblgen/OmpOpGen.cpp
@@ -102,11 +102,13 @@ static StringRef extractOmpClauseName(const Record *clause) {
 
 /// Check that the given argument, identified by its name and initialization
 /// value, is present in the \c arguments `dag`.
-static bool verifyArgument(DagInit *arguments, StringRef argName,
-                           Init *argInit) {
+static bool verifyArgument(const DagInit *arguments, StringRef argName,
+                           const Init *argInit) {
   auto range = zip_equal(arguments->getArgNames(), arguments->getArgs());
   return llvm::any_of(
-      range, [&](std::tuple<llvm::StringInit *const &, llvm::Init *const &> v) {
+      range,
+      [&](std::tuple<const llvm::StringInit *const &, const llvm::Init *const &>
+              v) {
         return std::get<0>(v)->getAsUnquotedString() == argName &&
                std::get<1>(v) == argInit;
       });
@@ -141,8 +143,8 @@ static void verifyClause(const Record *op, const Record *clause) {
   StringRef clauseClassName = extractOmpClauseName(clause);
 
   if (!clause->getValueAsBit("ignoreArgs")) {
-    DagInit *opArguments = op->getValueAsDag("arguments");
-    DagInit *arguments = clause->getValueAsDag("arguments");
+    const DagInit *opArguments = op->getValueAsDag("arguments");
+    const DagInit *arguments = clause->getValueAsDag("arguments");
 
     for (auto [name, arg] :
          zip(arguments->getArgNames(), arguments->getArgs())) {
@@ -208,8 +210,9 @@ static void verifyClause(const Record *op, const Record *clause) {
 ///
 /// \return the name of the base type to represent elements of the argument
 ///         type.
-static StringRef translateArgumentType(ArrayRef<SMLoc> loc, StringInit *name,
-                                       Init *init, int &nest, int &rank) {
+static StringRef translateArgumentType(ArrayRef<SMLoc> loc,
+                                       const StringInit *name, const Init *init,
+                                       int &nest, int &rank) {
   const Record *def = cast<DefInit>(init)->getDef();
 
   llvm::StringSet<> superClasses;
@@ -282,7 +285,7 @@ static void genClauseOpsStruct(const Record *clause, raw_ostream &os) {
   StringRef clauseName = extractOmpClauseName(clause);
   os << "struct " << clauseName << "ClauseOps {\n";
 
-  DagInit *arguments = clause->getValueAsDag("arguments");
+  const DagInit *arguments = clause->getValueAsDag("arguments");
   for (auto [name, arg] :
        zip_equal(arguments->getArgNames(), arguments->getArgs())) {
     int nest = 0, rank = 1;



More information about the Mlir-commits mailing list