[Mlir-commits] [mlir] [mlir][NFC] Move and rename EnumAttrCase, EnumAttr C++ classes (PR #132650)

Krzysztof Drewniak llvmlistbot at llvm.org
Sun Mar 23 17:11:11 PDT 2025


https://github.com/krzysz00 created https://github.com/llvm/llvm-project/pull/132650

This moves the EnumAttrCase and EnumAttr classes from Attribute.h/.cpp to a new EnumInfo.h/cpp and renames them to EnumCase and EnumInfo, respectively.

This doesn't change any of the TableGen files or any user-facing aspects of the enum attribute generation system; it just reorganizes code in order to make the main PR (#132148) shorter.

>From d130635269f41fa489fb51ea1aae8d2067c4ccca Mon Sep 17 00:00:00 2001
From: Krzysztof Drewniak <krzysdrewniak at gmail.com>
Date: Sun, 23 Mar 2025 17:01:13 -0700
Subject: [PATCH] [mlir][NFC] Move and rename EnumAttrCase, EnumAttr C++ classes

This moves the EnumAttrCase and EnumAttr classes from Attribute.h/.cpp
to a new EnumInfo.h/cpp and renames them to EnumCase and EnumInfo,
respectively.

This doesn't change any of the TableGen files or any user-facing
aspects of the enum attribute generation system; it just reorganizes
code in order to make the main PR (#132148) shorter.
---
 mlir/include/mlir/TableGen/Attribute.h        |  74 --------
 mlir/include/mlir/TableGen/EnumInfo.h         | 135 ++++++++++++++
 mlir/include/mlir/TableGen/Pattern.h          |  11 +-
 mlir/lib/TableGen/Attribute.cpp               |  94 ----------
 mlir/lib/TableGen/CMakeLists.txt              |   1 +
 mlir/lib/TableGen/EnumInfo.cpp                | 130 ++++++++++++++
 mlir/lib/TableGen/Pattern.cpp                 |  12 +-
 .../mlir-tblgen/EnumPythonBindingGen.cpp      |  47 ++---
 mlir/tools/mlir-tblgen/EnumsGen.cpp           | 166 +++++++++---------
 .../tools/mlir-tblgen/LLVMIRConversionGen.cpp |  75 ++++----
 mlir/tools/mlir-tblgen/OpDocGen.cpp           |  17 +-
 mlir/tools/mlir-tblgen/OpFormatGen.cpp        |  35 ++--
 mlir/tools/mlir-tblgen/RewriterGen.cpp        |   8 +-
 mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp      |  91 +++++-----
 mlir/tools/mlir-tblgen/TosaUtilsGen.cpp       |   5 +-
 15 files changed, 506 insertions(+), 395 deletions(-)
 create mode 100644 mlir/include/mlir/TableGen/EnumInfo.h
 create mode 100644 mlir/lib/TableGen/EnumInfo.cpp

diff --git a/mlir/include/mlir/TableGen/Attribute.h b/mlir/include/mlir/TableGen/Attribute.h
index 62720e74849fc..dee81880bacab 100644
--- a/mlir/include/mlir/TableGen/Attribute.h
+++ b/mlir/include/mlir/TableGen/Attribute.h
@@ -16,7 +16,6 @@
 
 #include "mlir/Support/LLVM.h"
 #include "mlir/TableGen/Constraint.h"
-#include "llvm/ADT/StringRef.h"
 
 namespace llvm {
 class DefInit;
@@ -136,79 +135,6 @@ class ConstantAttr {
   const llvm::Record *def;
 };
 
-// Wrapper class providing helper methods for accessing enum attribute cases
-// defined in TableGen. This is used for enum attribute case backed by both
-// StringAttr and IntegerAttr.
-class EnumAttrCase : public Attribute {
-public:
-  explicit EnumAttrCase(const llvm::Record *record);
-  explicit EnumAttrCase(const llvm::DefInit *init);
-
-  // Returns the symbol of this enum attribute case.
-  StringRef getSymbol() const;
-
-  // Returns the textual representation of this enum attribute case.
-  StringRef getStr() const;
-
-  // Returns the value of this enum attribute case.
-  int64_t getValue() const;
-
-  // Returns the TableGen definition this EnumAttrCase was constructed from.
-  const llvm::Record &getDef() const;
-};
-
-// Wrapper class providing helper methods for accessing enum attributes defined
-// in TableGen.This is used for enum attribute case backed by both StringAttr
-// and IntegerAttr.
-class EnumAttr : public Attribute {
-public:
-  explicit EnumAttr(const llvm::Record *record);
-  explicit EnumAttr(const llvm::Record &record);
-  explicit EnumAttr(const llvm::DefInit *init);
-
-  static bool classof(const Attribute *attr);
-
-  // Returns true if this is a bit enum attribute.
-  bool isBitEnum() const;
-
-  // Returns the enum class name.
-  StringRef getEnumClassName() const;
-
-  // Returns the C++ namespaces this enum class should be placed in.
-  StringRef getCppNamespace() const;
-
-  // Returns the underlying type.
-  StringRef getUnderlyingType() const;
-
-  // Returns the name of the utility function that converts a value of the
-  // underlying type to the corresponding symbol.
-  StringRef getUnderlyingToSymbolFnName() const;
-
-  // Returns the name of the utility function that converts a string to the
-  // corresponding symbol.
-  StringRef getStringToSymbolFnName() const;
-
-  // Returns the name of the utility function that converts a symbol to the
-  // corresponding string.
-  StringRef getSymbolToStringFnName() const;
-
-  // Returns the return type of the utility function that converts a symbol to
-  // the corresponding string.
-  StringRef getSymbolToStringFnRetType() const;
-
-  // Returns the name of the utilit function that returns the max enum value
-  // used within the enum class.
-  StringRef getMaxEnumValFnName() const;
-
-  // Returns all allowed cases for this enum attribute.
-  std::vector<EnumAttrCase> getAllCases() const;
-
-  bool genSpecializedAttr() const;
-  const llvm::Record *getBaseAttrClass() const;
-  StringRef getSpecializedAttrClassName() const;
-  bool printBitEnumPrimaryGroups() const;
-};
-
 // Name of infer type op interface.
 extern const char *inferTypeOpInterface;
 
diff --git a/mlir/include/mlir/TableGen/EnumInfo.h b/mlir/include/mlir/TableGen/EnumInfo.h
new file mode 100644
index 0000000000000..196267864f325
--- /dev/null
+++ b/mlir/include/mlir/TableGen/EnumInfo.h
@@ -0,0 +1,135 @@
+//===- EnumInfo.h - EnumInfo wrapper class ----------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Wrapper classes that simplify using TableGen records defining enums
+// (EnumInfo) and their cases (EnumCase).
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TABLEGEN_ENUMINFO_H_
+#define MLIR_TABLEGEN_ENUMINFO_H_
+
+#include "mlir/Support/LLVM.h"
+#include "mlir/TableGen/Attribute.h"
+#include "llvm/ADT/StringRef.h"
+
+namespace llvm {
+class DefInit;
+class Record;
+} // namespace llvm
+
+namespace mlir {
+namespace tblgen {
+
+// Wrapper class around an enum case defined in TableGen.
+class EnumCase {
+public:
+  explicit EnumCase(const llvm::Record *record);
+  explicit EnumCase(const llvm::DefInit *init);
+
+  // Returns the symbol of this enum case.
+  StringRef getSymbol() const;
+
+  // Returns the textual representation of this enum case.
+  StringRef getStr() const;
+
+  // Returns the value of this enum case.
+  int64_t getValue() const;
+
+  // Returns the TableGen definition this EnumCase was constructed from.
+  const llvm::Record &getDef() const;
+
+protected:
+  // The TableGen definition of this enum case.
+  const llvm::Record *def;
+};
+
+// Wrapper class providing helper methods for accessing enums defined
+// in TableGen using EnumInfo. Some methods are only applicable when
+// the enum is also an attribute, or only when it is a bit enum.
+class EnumInfo {
+public:
+  explicit EnumInfo(const llvm::Record *record);
+  explicit EnumInfo(const llvm::Record &record);
+  explicit EnumInfo(const llvm::DefInit *init);
+
+  // Returns true if the underlying TableGen record is a subclass of the named
+  // TableGen class.
+  bool isSubClassOf(StringRef className) const;
+
+  // Returns true if this enum is an EnumAttrInfo, thus making it define an
+  // attribute.
+  bool isEnumAttr() const;
+
+  // Create the `Attribute` wrapper around this EnumInfo if it is defining an
+  // attribute.
+  std::optional<Attribute> asEnumAttr() const;
+
+  // Returns true if this is a bit enum.
+  bool isBitEnum() const;
+
+  // Returns the enum class name.
+  StringRef getEnumClassName() const;
+
+  // Returns the C++ namespaces this enum class should be placed in.
+  StringRef getCppNamespace() const;
+
+  // Returns the summary of the enum.
+  StringRef getSummary() const;
+
+  // Returns the description of the enum.
+  StringRef getDescription() const;
+
+  // Returns the underlying type.
+  StringRef getUnderlyingType() const;
+
+  // Returns the name of the utility function that converts a value of the
+  // underlying type to the corresponding symbol.
+  StringRef getUnderlyingToSymbolFnName() const;
+
+  // Returns the name of the utility function that converts a string to the
+  // corresponding symbol.
+  StringRef getStringToSymbolFnName() const;
+
+  // Returns the name of the utility function that converts a symbol to the
+  // corresponding string.
+  StringRef getSymbolToStringFnName() const;
+
+  // Returns the return type of the utility function that converts a symbol to
+  // the corresponding string.
+  StringRef getSymbolToStringFnRetType() const;
+
+  // Returns the name of the utility function that returns the max enum value
+  // used within the enum class.
+  StringRef getMaxEnumValFnName() const;
+
+  // Returns all allowed cases for this enum.
+  std::vector<EnumCase> getAllCases() const;
+
+  // Only applicable for enum attributes.
+
+  bool genSpecializedAttr() const;
+  const llvm::Record *getBaseAttrClass() const;
+  StringRef getSpecializedAttrClassName() const;
+
+  // Only applicable for bit enums.
+
+  bool printBitEnumPrimaryGroups() const;
+
+  // Returns the TableGen definition this EnumInfo was constructed from.
+  const llvm::Record &getDef() const;
+
+protected:
+  // The TableGen definition of this enum.
+  const llvm::Record *def;
+};
+
+} // namespace tblgen
+} // namespace mlir
+
+#endif // MLIR_TABLEGEN_ENUMINFO_H_
diff --git a/mlir/include/mlir/TableGen/Pattern.h b/mlir/include/mlir/TableGen/Pattern.h
index 80f38fdeffee0..1c9e128f0a0fb 100644
--- a/mlir/include/mlir/TableGen/Pattern.h
+++ b/mlir/include/mlir/TableGen/Pattern.h
@@ -16,6 +16,7 @@
 
 #include "mlir/Support/LLVM.h"
 #include "mlir/TableGen/Argument.h"
+#include "mlir/TableGen/EnumInfo.h"
 #include "mlir/TableGen/Operator.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/Hashing.h"
@@ -78,8 +79,8 @@ class DagLeaf {
   // Returns true if this DAG leaf is specifying a constant attribute.
   bool isConstantAttr() const;
 
-  // Returns true if this DAG leaf is specifying an enum attribute case.
-  bool isEnumAttrCase() const;
+  // Returns true if this DAG leaf is specifying an enum case.
+  bool isEnumCase() const;
 
   // Returns true if this DAG leaf is specifying a string attribute.
   bool isStringAttr() const;
@@ -90,9 +91,9 @@ class DagLeaf {
   // Returns this DAG leaf as an constant attribute. Asserts if fails.
   ConstantAttr getAsConstantAttr() const;
 
-  // Returns this DAG leaf as an enum attribute case.
-  // Precondition: isEnumAttrCase()
-  EnumAttrCase getAsEnumAttrCase() const;
+  // Returns this DAG leaf as an enum case.
+  // Precondition: isEnumCase()
+  EnumCase getAsEnumCase() const;
 
   // Returns the matching condition template inside this DAG leaf. Assumes the
   // leaf is an operand/attribute matcher and asserts otherwise.
diff --git a/mlir/lib/TableGen/Attribute.cpp b/mlir/lib/TableGen/Attribute.cpp
index f9fc58a40f334..142d194260942 100644
--- a/mlir/lib/TableGen/Attribute.cpp
+++ b/mlir/lib/TableGen/Attribute.cpp
@@ -146,98 +146,4 @@ StringRef ConstantAttr::getConstantValue() const {
   return def->getValueAsString("value");
 }
 
-EnumAttrCase::EnumAttrCase(const Record *record) : Attribute(record) {
-  assert(isSubClassOf("EnumAttrCaseInfo") &&
-         "must be subclass of TableGen 'EnumAttrInfo' class");
-}
-
-EnumAttrCase::EnumAttrCase(const DefInit *init)
-    : EnumAttrCase(init->getDef()) {}
-
-StringRef EnumAttrCase::getSymbol() const {
-  return def->getValueAsString("symbol");
-}
-
-StringRef EnumAttrCase::getStr() const { return def->getValueAsString("str"); }
-
-int64_t EnumAttrCase::getValue() const { return def->getValueAsInt("value"); }
-
-const Record &EnumAttrCase::getDef() const { return *def; }
-
-EnumAttr::EnumAttr(const Record *record) : Attribute(record) {
-  assert(isSubClassOf("EnumAttrInfo") &&
-         "must be subclass of TableGen 'EnumAttr' class");
-}
-
-EnumAttr::EnumAttr(const Record &record) : Attribute(&record) {}
-
-EnumAttr::EnumAttr(const DefInit *init) : EnumAttr(init->getDef()) {}
-
-bool EnumAttr::classof(const Attribute *attr) {
-  return attr->isSubClassOf("EnumAttrInfo");
-}
-
-bool EnumAttr::isBitEnum() const { return isSubClassOf("BitEnumAttr"); }
-
-StringRef EnumAttr::getEnumClassName() const {
-  return def->getValueAsString("className");
-}
-
-StringRef EnumAttr::getCppNamespace() const {
-  return def->getValueAsString("cppNamespace");
-}
-
-StringRef EnumAttr::getUnderlyingType() const {
-  return def->getValueAsString("underlyingType");
-}
-
-StringRef EnumAttr::getUnderlyingToSymbolFnName() const {
-  return def->getValueAsString("underlyingToSymbolFnName");
-}
-
-StringRef EnumAttr::getStringToSymbolFnName() const {
-  return def->getValueAsString("stringToSymbolFnName");
-}
-
-StringRef EnumAttr::getSymbolToStringFnName() const {
-  return def->getValueAsString("symbolToStringFnName");
-}
-
-StringRef EnumAttr::getSymbolToStringFnRetType() const {
-  return def->getValueAsString("symbolToStringFnRetType");
-}
-
-StringRef EnumAttr::getMaxEnumValFnName() const {
-  return def->getValueAsString("maxEnumValFnName");
-}
-
-std::vector<EnumAttrCase> EnumAttr::getAllCases() const {
-  const auto *inits = def->getValueAsListInit("enumerants");
-
-  std::vector<EnumAttrCase> cases;
-  cases.reserve(inits->size());
-
-  for (const Init *init : *inits) {
-    cases.emplace_back(cast<DefInit>(init));
-  }
-
-  return cases;
-}
-
-bool EnumAttr::genSpecializedAttr() const {
-  return def->getValueAsBit("genSpecializedAttr");
-}
-
-const Record *EnumAttr::getBaseAttrClass() const {
-  return def->getValueAsDef("baseAttrClass");
-}
-
-StringRef EnumAttr::getSpecializedAttrClassName() const {
-  return def->getValueAsString("specializedAttrClassName");
-}
-
-bool EnumAttr::printBitEnumPrimaryGroups() const {
-  return def->getValueAsBit("printBitEnumPrimaryGroups");
-}
-
 const char * ::mlir::tblgen::inferTypeOpInterface = "InferTypeOpInterface";
diff --git a/mlir/lib/TableGen/CMakeLists.txt b/mlir/lib/TableGen/CMakeLists.txt
index c4104e644147c..a90c55847718e 100644
--- a/mlir/lib/TableGen/CMakeLists.txt
+++ b/mlir/lib/TableGen/CMakeLists.txt
@@ -20,6 +20,7 @@ llvm_add_library(MLIRTableGen STATIC
   CodeGenHelpers.cpp
   Constraint.cpp
   Dialect.cpp
+  EnumInfo.cpp
   Format.cpp
   GenInfo.cpp
   Interfaces.cpp
diff --git a/mlir/lib/TableGen/EnumInfo.cpp b/mlir/lib/TableGen/EnumInfo.cpp
new file mode 100644
index 0000000000000..9f491d30f0e7f
--- /dev/null
+++ b/mlir/lib/TableGen/EnumInfo.cpp
@@ -0,0 +1,130 @@
+//===- EnumInfo.cpp - EnumInfo wrapper class ------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/TableGen/EnumInfo.h"
+#include "mlir/TableGen/Attribute.h"
+#include "llvm/TableGen/Record.h"
+
+using namespace mlir;
+using namespace mlir::tblgen;
+
+using llvm::DefInit;
+using llvm::Init;
+using llvm::Record;
+
+EnumCase::EnumCase(const Record *record) : def(record) {
+  assert(def->isSubClassOf("EnumAttrCaseInfo") &&
+         "must be subclass of TableGen 'EnumAttrCaseInfo' class");
+}
+
+EnumCase::EnumCase(const DefInit *init) : EnumCase(init->getDef()) {}
+
+StringRef EnumCase::getSymbol() const {
+  return def->getValueAsString("symbol");
+}
+
+StringRef EnumCase::getStr() const { return def->getValueAsString("str"); }
+
+int64_t EnumCase::getValue() const { return def->getValueAsInt("value"); }
+
+const Record &EnumCase::getDef() const { return *def; }
+
+EnumInfo::EnumInfo(const Record *record) : def(record) {
+  assert(isSubClassOf("EnumAttrInfo") &&
+         "must be subclass of TableGen 'EnumAttrInfo' class");
+}
+
+EnumInfo::EnumInfo(const Record &record) : EnumInfo(&record) {}
+
+EnumInfo::EnumInfo(const DefInit *init) : EnumInfo(init->getDef()) {}
+
+bool EnumInfo::isSubClassOf(StringRef className) const {
+  return def->isSubClassOf(className);
+}
+
+bool EnumInfo::isEnumAttr() const { return isSubClassOf("EnumAttrInfo"); }
+
+std::optional<Attribute> EnumInfo::asEnumAttr() const {
+  if (isEnumAttr())
+    return Attribute(def);
+  return std::nullopt;
+}
+
+bool EnumInfo::isBitEnum() const { return isSubClassOf("BitEnumAttr"); }
+
+StringRef EnumInfo::getEnumClassName() const {
+  return def->getValueAsString("className");
+}
+
+StringRef EnumInfo::getSummary() const {
+  return def->getValueAsString("summary");
+}
+
+StringRef EnumInfo::getDescription() const {
+  return def->getValueAsString("description");
+}
+
+StringRef EnumInfo::getCppNamespace() const {
+  return def->getValueAsString("cppNamespace");
+}
+
+StringRef EnumInfo::getUnderlyingType() const {
+  return def->getValueAsString("underlyingType");
+}
+
+StringRef EnumInfo::getUnderlyingToSymbolFnName() const {
+  return def->getValueAsString("underlyingToSymbolFnName");
+}
+
+StringRef EnumInfo::getStringToSymbolFnName() const {
+  return def->getValueAsString("stringToSymbolFnName");
+}
+
+StringRef EnumInfo::getSymbolToStringFnName() const {
+  return def->getValueAsString("symbolToStringFnName");
+}
+
+StringRef EnumInfo::getSymbolToStringFnRetType() const {
+  return def->getValueAsString("symbolToStringFnRetType");
+}
+
+StringRef EnumInfo::getMaxEnumValFnName() const {
+  return def->getValueAsString("maxEnumValFnName");
+}
+
+std::vector<EnumCase> EnumInfo::getAllCases() const {
+  const auto *inits = def->getValueAsListInit("enumerants");
+
+  std::vector<EnumCase> cases;
+  cases.reserve(inits->size());
+
+  for (const Init *init : *inits) {
+    cases.emplace_back(cast<DefInit>(init));
+  }
+
+  return cases;
+}
+
+bool EnumInfo::genSpecializedAttr() const {
+  return isSubClassOf("EnumAttrInfo") &&
+         def->getValueAsBit("genSpecializedAttr");
+}
+
+const Record *EnumInfo::getBaseAttrClass() const {
+  return def->getValueAsDef("baseAttrClass");
+}
+
+StringRef EnumInfo::getSpecializedAttrClassName() const {
+  return def->getValueAsString("specializedAttrClassName");
+}
+
+bool EnumInfo::printBitEnumPrimaryGroups() const {
+  return isBitEnum() && def->getValueAsBit("printBitEnumPrimaryGroups");
+}
+
+const Record &EnumInfo::getDef() const { return *def; }
diff --git a/mlir/lib/TableGen/Pattern.cpp b/mlir/lib/TableGen/Pattern.cpp
index ac8c49c72d384..73e2803c21dae 100644
--- a/mlir/lib/TableGen/Pattern.cpp
+++ b/mlir/lib/TableGen/Pattern.cpp
@@ -57,9 +57,7 @@ bool DagLeaf::isNativeCodeCall() const {
 
 bool DagLeaf::isConstantAttr() const { return isSubClassOf("ConstantAttr"); }
 
-bool DagLeaf::isEnumAttrCase() const {
-  return isSubClassOf("EnumAttrCaseInfo");
-}
+bool DagLeaf::isEnumCase() const { return isSubClassOf("EnumAttrCaseInfo"); }
 
 bool DagLeaf::isStringAttr() const { return isa<llvm::StringInit>(def); }
 
@@ -74,9 +72,9 @@ ConstantAttr DagLeaf::getAsConstantAttr() const {
   return ConstantAttr(cast<DefInit>(def));
 }
 
-EnumAttrCase DagLeaf::getAsEnumAttrCase() const {
-  assert(isEnumAttrCase() && "the DAG leaf must be an enum attribute case");
-  return EnumAttrCase(cast<DefInit>(def));
+EnumCase DagLeaf::getAsEnumCase() const {
+  assert(isEnumCase() && "the DAG leaf must be an enum attribute case");
+  return EnumCase(cast<DefInit>(def));
 }
 
 std::string DagLeaf::getConditionTemplate() const {
@@ -776,7 +774,7 @@ void Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
           verifyBind(infoMap.bindValue(treeArgName), treeArgName);
         } else {
           auto constraint = leaf.getAsConstraint();
-          bool isAttr = leaf.isAttrMatcher() || leaf.isEnumAttrCase() ||
+          bool isAttr = leaf.isAttrMatcher() || leaf.isEnumCase() ||
                         leaf.isConstantAttr() ||
                         constraint.getKind() == Constraint::Kind::CK_Attr;
 
diff --git a/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp
index 3f660ae151c74..5d4d9e90fff67 100644
--- a/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp
@@ -15,6 +15,7 @@
 #include "mlir/TableGen/AttrOrTypeDef.h"
 #include "mlir/TableGen/Attribute.h"
 #include "mlir/TableGen/Dialect.h"
+#include "mlir/TableGen/EnumInfo.h"
 #include "mlir/TableGen/GenInfo.h"
 #include "llvm/Support/FormatVariadic.h"
 #include "llvm/TableGen/Record.h"
@@ -44,14 +45,14 @@ static std::string makePythonEnumCaseName(StringRef name) {
 }
 
 /// Emits the Python class for the given enum.
-static void emitEnumClass(EnumAttr enumAttr, raw_ostream &os) {
-  os << formatv("class {0}({1}):\n", enumAttr.getEnumClassName(),
-                enumAttr.isBitEnum() ? "IntFlag" : "IntEnum");
-  if (!enumAttr.getSummary().empty())
-    os << formatv("    \"\"\"{0}\"\"\"\n", enumAttr.getSummary());
+static void emitEnumClass(EnumInfo enumInfo, raw_ostream &os) {
+  os << formatv("class {0}({1}):\n", enumInfo.getEnumClassName(),
+                enumInfo.isBitEnum() ? "IntFlag" : "IntEnum");
+  if (!enumInfo.getSummary().empty())
+    os << formatv("    \"\"\"{0}\"\"\"\n", enumInfo.getSummary());
   os << "\n";
 
-  for (const EnumAttrCase &enumCase : enumAttr.getAllCases()) {
+  for (const EnumCase &enumCase : enumInfo.getAllCases()) {
     os << formatv("    {0} = {1}\n",
                   makePythonEnumCaseName(enumCase.getSymbol()),
                   enumCase.getValue() >= 0 ? std::to_string(enumCase.getValue())
@@ -60,7 +61,7 @@ static void emitEnumClass(EnumAttr enumAttr, raw_ostream &os) {
 
   os << "\n";
 
-  if (enumAttr.isBitEnum()) {
+  if (enumInfo.isBitEnum()) {
     os << formatv("    def __iter__(self):\n"
                   "        return iter([case for case in type(self) if "
                   "(self & case) is case])\n");
@@ -70,17 +71,17 @@ static void emitEnumClass(EnumAttr enumAttr, raw_ostream &os) {
   }
 
   os << formatv("    def __str__(self):\n");
-  if (enumAttr.isBitEnum())
+  if (enumInfo.isBitEnum())
     os << formatv("        if len(self) > 1:\n"
                   "            return \"{0}\".join(map(str, self))\n",
-                  enumAttr.getDef().getValueAsString("separator"));
-  for (const EnumAttrCase &enumCase : enumAttr.getAllCases()) {
-    os << formatv("        if self is {0}.{1}:\n", enumAttr.getEnumClassName(),
+                  enumInfo.getDef().getValueAsString("separator"));
+  for (const EnumCase &enumCase : enumInfo.getAllCases()) {
+    os << formatv("        if self is {0}.{1}:\n", enumInfo.getEnumClassName(),
                   makePythonEnumCaseName(enumCase.getSymbol()));
     os << formatv("            return \"{0}\"\n", enumCase.getStr());
   }
   os << formatv("        raise ValueError(\"Unknown {0} enum entry.\")\n\n\n",
-                enumAttr.getEnumClassName());
+                enumInfo.getEnumClassName());
   os << "\n";
 }
 
@@ -98,17 +99,21 @@ static bool extractUIntBitwidth(StringRef uintType, int64_t &bitwidth) {
 /// Emits an attribute builder for the given enum attribute to support automatic
 /// conversion between enum values and attributes in Python. Returns
 /// `false` on success, `true` on failure.
-static bool emitAttributeBuilder(const EnumAttr &enumAttr, raw_ostream &os) {
+static bool emitAttributeBuilder(const EnumInfo &enumInfo, raw_ostream &os) {
+  std::optional<Attribute> enumAttrInfo = enumInfo.asEnumAttr();
+  if (!enumAttrInfo)
+    return false;
+
   int64_t bitwidth;
-  if (extractUIntBitwidth(enumAttr.getUnderlyingType(), bitwidth)) {
+  if (extractUIntBitwidth(enumInfo.getUnderlyingType(), bitwidth)) {
     llvm::errs() << "failed to identify bitwidth of "
-                 << enumAttr.getUnderlyingType();
+                 << enumInfo.getUnderlyingType();
     return true;
   }
-
   os << formatv("@register_attribute_builder(\"{0}\")\n",
-                enumAttr.getAttrDefName());
-  os << formatv("def _{0}(x, context):\n", enumAttr.getAttrDefName().lower());
+                enumAttrInfo->getAttrDefName());
+  os << formatv("def _{0}(x, context):\n",
+                enumAttrInfo->getAttrDefName().lower());
   os << formatv("    return "
                 "_ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless({0}, "
                 "context=context), int(x))\n\n",
@@ -136,9 +141,9 @@ static bool emitPythonEnums(const RecordKeeper &records, raw_ostream &os) {
   os << fileHeader;
   for (const Record *it :
        records.getAllDerivedDefinitionsIfDefined("EnumAttrInfo")) {
-    EnumAttr enumAttr(*it);
-    emitEnumClass(enumAttr, os);
-    emitAttributeBuilder(enumAttr, os);
+    EnumInfo enumInfo(*it);
+    emitEnumClass(enumInfo, os);
+    emitAttributeBuilder(enumInfo, os);
   }
   for (const Record *it :
        records.getAllDerivedDefinitionsIfDefined("EnumAttr")) {
diff --git a/mlir/tools/mlir-tblgen/EnumsGen.cpp b/mlir/tools/mlir-tblgen/EnumsGen.cpp
index d11aa9b27c2d8..fa6fad156b747 100644
--- a/mlir/tools/mlir-tblgen/EnumsGen.cpp
+++ b/mlir/tools/mlir-tblgen/EnumsGen.cpp
@@ -12,6 +12,7 @@
 
 #include "FormatGen.h"
 #include "mlir/TableGen/Attribute.h"
+#include "mlir/TableGen/EnumInfo.h"
 #include "mlir/TableGen/Format.h"
 #include "mlir/TableGen/GenInfo.h"
 #include "llvm/ADT/BitVector.h"
@@ -30,8 +31,8 @@ using llvm::Record;
 using llvm::RecordKeeper;
 using namespace mlir;
 using mlir::tblgen::Attribute;
-using mlir::tblgen::EnumAttr;
-using mlir::tblgen::EnumAttrCase;
+using mlir::tblgen::EnumCase;
+using mlir::tblgen::EnumInfo;
 using mlir::tblgen::FmtContext;
 using mlir::tblgen::tgfmt;
 
@@ -45,7 +46,7 @@ static std::string makeIdentifier(StringRef str) {
 
 static void emitEnumClass(const Record &enumDef, StringRef enumName,
                           StringRef underlyingType, StringRef description,
-                          const std::vector<EnumAttrCase> &enumerants,
+                          const std::vector<EnumCase> &enumerants,
                           raw_ostream &os) {
   os << "// " << description << "\n";
   os << "enum class " << enumName;
@@ -66,12 +67,13 @@ static void emitEnumClass(const Record &enumDef, StringRef enumName,
   os << "};\n\n";
 }
 
-static void emitParserPrinter(const EnumAttr &enumAttr, StringRef qualName,
+static void emitParserPrinter(const EnumInfo &enumInfo, StringRef qualName,
                               StringRef cppNamespace, raw_ostream &os) {
-  if (enumAttr.getUnderlyingType().empty() ||
-      enumAttr.getConstBuilderTemplate().empty())
+  std::optional<Attribute> enumAttrInfo = enumInfo.asEnumAttr();
+  if (enumInfo.getUnderlyingType().empty() ||
+      (enumAttrInfo && enumAttrInfo->getConstBuilderTemplate().empty()))
     return;
-  auto cases = enumAttr.getAllCases();
+  auto cases = enumInfo.getAllCases();
 
   // Check which cases shouldn't be printed using a keyword.
   llvm::BitVector nonKeywordCases(cases.size());
@@ -128,8 +130,9 @@ namespace llvm {
 inline ::llvm::raw_ostream &operator<<(::llvm::raw_ostream &p, {0} value) {{
   auto valueStr = stringifyEnum(value);
 )";
+
   os << formatv(parsedAndPrinterStart, qualName, cppNamespace,
-                enumAttr.getSummary());
+                enumInfo.getSummary());
 
   // If all cases require a string, always wrap.
   if (nonKeywordCases.all()) {
@@ -157,9 +160,9 @@ inline ::llvm::raw_ostream &operator<<(::llvm::raw_ostream &p, {0} value) {{
 
     // If this is a bit enum, conservatively print the string form if the value
     // is not a power of two (i.e. not a single bit case) and not a known case.
-  } else if (enumAttr.isBitEnum()) {
+  } else if (enumInfo.isBitEnum()) {
     // Process the known multi-bit cases that use valid keywords.
-    SmallVector<EnumAttrCase *> validMultiBitCases;
+    SmallVector<EnumCase *> validMultiBitCases;
     for (auto [index, caseVal] : llvm::enumerate(cases)) {
       uint64_t value = caseVal.getValue();
       if (value && !llvm::has_single_bit(value) && !nonKeywordCases.test(index))
@@ -167,7 +170,7 @@ inline ::llvm::raw_ostream &operator<<(::llvm::raw_ostream &p, {0} value) {{
     }
     if (!validMultiBitCases.empty()) {
       os << "  switch (value) {\n";
-      for (EnumAttrCase *caseVal : validMultiBitCases) {
+      for (EnumCase *caseVal : validMultiBitCases) {
         StringRef symbol = caseVal->getSymbol();
         os << llvm::formatv("  case {0}::{1}:\n", qualName,
                             llvm::isDigit(symbol.front()) ? ("_" + symbol)
@@ -224,9 +227,9 @@ template<> struct DenseMapInfo<{0}> {{
 }
 
 static void emitMaxValueFn(const Record &enumDef, raw_ostream &os) {
-  EnumAttr enumAttr(enumDef);
-  StringRef maxEnumValFnName = enumAttr.getMaxEnumValFnName();
-  auto enumerants = enumAttr.getAllCases();
+  EnumInfo enumInfo(enumDef);
+  StringRef maxEnumValFnName = enumInfo.getMaxEnumValFnName();
+  auto enumerants = enumInfo.getAllCases();
 
   unsigned maxEnumVal = 0;
   for (const auto &enumerant : enumerants) {
@@ -245,10 +248,10 @@ static void emitMaxValueFn(const Record &enumDef, raw_ostream &os) {
   os << "}\n\n";
 }
 
-// Returns the EnumAttrCase whose value is zero if exists; returns std::nullopt
+// Returns the EnumCase whose value is zero if one exists; returns std::nullopt
 // otherwise.
-static std::optional<EnumAttrCase>
-getAllBitsUnsetCase(llvm::ArrayRef<EnumAttrCase> cases) {
+static std::optional<EnumCase>
+getAllBitsUnsetCase(llvm::ArrayRef<EnumCase> cases) {
   for (auto attrCase : cases) {
     if (attrCase.getValue() == 0)
       return attrCase;
@@ -268,9 +271,9 @@ getAllBitsUnsetCase(llvm::ArrayRef<EnumAttrCase> cases) {
 // inline constexpr <enum-type> bitEnumSet(<enum-type> bits, <enum-type> bit,
 // bool value=true);
 static void emitOperators(const Record &enumDef, raw_ostream &os) {
-  EnumAttr enumAttr(enumDef);
-  StringRef enumName = enumAttr.getEnumClassName();
-  std::string underlyingType = std::string(enumAttr.getUnderlyingType());
+  EnumInfo enumInfo(enumDef);
+  StringRef enumName = enumInfo.getEnumClassName();
+  std::string underlyingType = std::string(enumInfo.getUnderlyingType());
   int64_t validBits = enumDef.getValueAsInt("validBits");
   const char *const operators = R"(
 inline constexpr {0} operator|({0} a, {0} b) {{
@@ -303,11 +306,11 @@ inline constexpr {0} bitEnumSet({0} bits, {0} bit, /*optional*/bool value=true)
 }
 
 static void emitSymToStrFnForIntEnum(const Record &enumDef, raw_ostream &os) {
-  EnumAttr enumAttr(enumDef);
-  StringRef enumName = enumAttr.getEnumClassName();
-  StringRef symToStrFnName = enumAttr.getSymbolToStringFnName();
-  StringRef symToStrFnRetType = enumAttr.getSymbolToStringFnRetType();
-  auto enumerants = enumAttr.getAllCases();
+  EnumInfo enumInfo(enumDef);
+  StringRef enumName = enumInfo.getEnumClassName();
+  StringRef symToStrFnName = enumInfo.getSymbolToStringFnName();
+  StringRef symToStrFnRetType = enumInfo.getSymbolToStringFnRetType();
+  auto enumerants = enumInfo.getAllCases();
 
   os << formatv("{2} {1}({0} val) {{\n", enumName, symToStrFnName,
                 symToStrFnRetType);
@@ -324,19 +327,19 @@ static void emitSymToStrFnForIntEnum(const Record &enumDef, raw_ostream &os) {
 }
 
 static void emitSymToStrFnForBitEnum(const Record &enumDef, raw_ostream &os) {
-  EnumAttr enumAttr(enumDef);
-  StringRef enumName = enumAttr.getEnumClassName();
-  StringRef symToStrFnName = enumAttr.getSymbolToStringFnName();
-  StringRef symToStrFnRetType = enumAttr.getSymbolToStringFnRetType();
+  EnumInfo enumInfo(enumDef);
+  StringRef enumName = enumInfo.getEnumClassName();
+  StringRef symToStrFnName = enumInfo.getSymbolToStringFnName();
+  StringRef symToStrFnRetType = enumInfo.getSymbolToStringFnRetType();
   StringRef separator = enumDef.getValueAsString("separator");
-  auto enumerants = enumAttr.getAllCases();
+  auto enumerants = enumInfo.getAllCases();
   auto allBitsUnsetCase = getAllBitsUnsetCase(enumerants);
 
   os << formatv("{2} {1}({0} symbol) {{\n", enumName, symToStrFnName,
                 symToStrFnRetType);
 
   os << formatv("  auto val = static_cast<{0}>(symbol);\n",
-                enumAttr.getUnderlyingType());
+                enumInfo.getUnderlyingType());
   // If we have unknown bit set, return an empty string to signal errors.
   int64_t validBits = enumDef.getValueAsInt("validBits");
   os << formatv("  assert({0}u == ({0}u | val) && \"invalid bits set in bit "
@@ -365,21 +368,23 @@ static void emitSymToStrFnForBitEnum(const Record &enumDef, raw_ostream &os) {
 )";
   // Optionally elide bits that are members of groups that will also be printed
   // for more concise output.
-  if (enumAttr.printBitEnumPrimaryGroups()) {
+  if (enumInfo.printBitEnumPrimaryGroups()) {
     os << "  // Print bit enum groups before individual bits\n";
     // Emit comparisons for group bit cases in reverse tablegen declaration
     // order, removing bits for groups with all bits present.
     for (const auto &enumerant : llvm::reverse(enumerants)) {
       if ((enumerant.getValue() != 0) &&
-          enumerant.getDef().isSubClassOf("BitEnumAttrCaseGroup")) {
+          (enumerant.getDef().isSubClassOf("BitEnumCaseGroup") ||
+           enumerant.getDef().isSubClassOf("BitEnumAttrCaseGroup"))) {
         os << formatv(formatCompareRemove, enumerant.getValue(),
-                      enumerant.getStr(), enumAttr.getUnderlyingType());
+                      enumerant.getStr(), enumInfo.getUnderlyingType());
       }
     }
     // Emit comparisons for individual bit cases in tablegen declaration order.
     for (const auto &enumerant : enumerants) {
       if ((enumerant.getValue() != 0) &&
-          enumerant.getDef().isSubClassOf("BitEnumAttrCaseBit"))
+          (enumerant.getDef().isSubClassOf("BitEnumCaseBit") ||
+           enumerant.getDef().isSubClassOf("BitEnumAttrCaseBit")))
         os << formatv(formatCompare, enumerant.getValue(), enumerant.getStr());
     }
   } else {
@@ -396,10 +401,10 @@ static void emitSymToStrFnForBitEnum(const Record &enumDef, raw_ostream &os) {
 }
 
 static void emitStrToSymFnForIntEnum(const Record &enumDef, raw_ostream &os) {
-  EnumAttr enumAttr(enumDef);
-  StringRef enumName = enumAttr.getEnumClassName();
-  StringRef strToSymFnName = enumAttr.getStringToSymbolFnName();
-  auto enumerants = enumAttr.getAllCases();
+  EnumInfo enumInfo(enumDef);
+  StringRef enumName = enumInfo.getEnumClassName();
+  StringRef strToSymFnName = enumInfo.getStringToSymbolFnName();
+  auto enumerants = enumInfo.getAllCases();
 
   os << formatv("::std::optional<{0}> {1}(::llvm::StringRef str) {{\n",
                 enumName, strToSymFnName);
@@ -416,13 +421,13 @@ static void emitStrToSymFnForIntEnum(const Record &enumDef, raw_ostream &os) {
 }
 
 static void emitStrToSymFnForBitEnum(const Record &enumDef, raw_ostream &os) {
-  EnumAttr enumAttr(enumDef);
-  StringRef enumName = enumAttr.getEnumClassName();
-  std::string underlyingType = std::string(enumAttr.getUnderlyingType());
-  StringRef strToSymFnName = enumAttr.getStringToSymbolFnName();
+  EnumInfo enumInfo(enumDef);
+  StringRef enumName = enumInfo.getEnumClassName();
+  std::string underlyingType = std::string(enumInfo.getUnderlyingType());
+  StringRef strToSymFnName = enumInfo.getStringToSymbolFnName();
   StringRef separator = enumDef.getValueAsString("separator");
   StringRef separatorTrimmed = separator.trim();
-  auto enumerants = enumAttr.getAllCases();
+  auto enumerants = enumInfo.getAllCases();
   auto allBitsUnsetCase = getAllBitsUnsetCase(enumerants);
 
   os << formatv("::std::optional<{0}> {1}(::llvm::StringRef str) {{\n",
@@ -463,17 +468,16 @@ static void emitStrToSymFnForBitEnum(const Record &enumDef, raw_ostream &os) {
 
 static void emitUnderlyingToSymFnForIntEnum(const Record &enumDef,
                                             raw_ostream &os) {
-  EnumAttr enumAttr(enumDef);
-  StringRef enumName = enumAttr.getEnumClassName();
-  std::string underlyingType = std::string(enumAttr.getUnderlyingType());
-  StringRef underlyingToSymFnName = enumAttr.getUnderlyingToSymbolFnName();
-  auto enumerants = enumAttr.getAllCases();
+  EnumInfo enumInfo(enumDef);
+  StringRef enumName = enumInfo.getEnumClassName();
+  std::string underlyingType = std::string(enumInfo.getUnderlyingType());
+  StringRef underlyingToSymFnName = enumInfo.getUnderlyingToSymbolFnName();
+  auto enumerants = enumInfo.getAllCases();
 
   // Avoid generating the underlying value to symbol conversion function if
   // there is an enumerant without explicit value.
-  if (llvm::any_of(enumerants, [](EnumAttrCase enumerant) {
-        return enumerant.getValue() < 0;
-      }))
+  if (llvm::any_of(enumerants,
+                   [](EnumCase enumerant) { return enumerant.getValue() < 0; }))
     return;
 
   os << formatv("::std::optional<{0}> {1}({2} value) {{\n", enumName,
@@ -493,10 +497,10 @@ static void emitUnderlyingToSymFnForIntEnum(const Record &enumDef,
 }
 
 static void emitSpecializedAttrDef(const Record &enumDef, raw_ostream &os) {
-  EnumAttr enumAttr(enumDef);
-  StringRef enumName = enumAttr.getEnumClassName();
-  StringRef attrClassName = enumAttr.getSpecializedAttrClassName();
-  const Record *baseAttrDef = enumAttr.getBaseAttrClass();
+  EnumInfo enumInfo(enumDef);
+  StringRef enumName = enumInfo.getEnumClassName();
+  StringRef attrClassName = enumInfo.getSpecializedAttrClassName();
+  const Record *baseAttrDef = enumInfo.getBaseAttrClass();
   Attribute baseAttr(baseAttrDef);
 
   // Emit classof method
@@ -520,7 +524,7 @@ static void emitSpecializedAttrDef(const Record &enumDef, raw_ostream &os) {
   os << formatv("{0} {0}::get(::mlir::MLIRContext *context, {1} val) {{\n",
                 attrClassName, enumName);
 
-  StringRef underlyingType = enumAttr.getUnderlyingType();
+  StringRef underlyingType = enumInfo.getUnderlyingType();
 
   // Assuming that it is IntegerAttr constraint
   int64_t bitwidth = 64;
@@ -552,11 +556,11 @@ static void emitSpecializedAttrDef(const Record &enumDef, raw_ostream &os) {
 
 static void emitUnderlyingToSymFnForBitEnum(const Record &enumDef,
                                             raw_ostream &os) {
-  EnumAttr enumAttr(enumDef);
-  StringRef enumName = enumAttr.getEnumClassName();
-  std::string underlyingType = std::string(enumAttr.getUnderlyingType());
-  StringRef underlyingToSymFnName = enumAttr.getUnderlyingToSymbolFnName();
-  auto enumerants = enumAttr.getAllCases();
+  EnumInfo enumInfo(enumDef);
+  StringRef enumName = enumInfo.getEnumClassName();
+  std::string underlyingType = std::string(enumInfo.getUnderlyingType());
+  StringRef underlyingToSymFnName = enumInfo.getUnderlyingToSymbolFnName();
+  auto enumerants = enumInfo.getAllCases();
   auto allBitsUnsetCase = getAllBitsUnsetCase(enumerants);
 
   os << formatv("::std::optional<{0}> {1}({2} value) {{\n", enumName,
@@ -574,16 +578,16 @@ static void emitUnderlyingToSymFnForBitEnum(const Record &enumDef,
 }
 
 static void emitEnumDecl(const Record &enumDef, raw_ostream &os) {
-  EnumAttr enumAttr(enumDef);
-  StringRef enumName = enumAttr.getEnumClassName();
-  StringRef cppNamespace = enumAttr.getCppNamespace();
-  std::string underlyingType = std::string(enumAttr.getUnderlyingType());
-  StringRef description = enumAttr.getSummary();
-  StringRef strToSymFnName = enumAttr.getStringToSymbolFnName();
-  StringRef symToStrFnName = enumAttr.getSymbolToStringFnName();
-  StringRef symToStrFnRetType = enumAttr.getSymbolToStringFnRetType();
-  StringRef underlyingToSymFnName = enumAttr.getUnderlyingToSymbolFnName();
-  auto enumerants = enumAttr.getAllCases();
+  EnumInfo enumInfo(enumDef);
+  StringRef enumName = enumInfo.getEnumClassName();
+  StringRef cppNamespace = enumInfo.getCppNamespace();
+  std::string underlyingType = std::string(enumInfo.getUnderlyingType());
+  StringRef description = enumInfo.getSummary();
+  StringRef strToSymFnName = enumInfo.getStringToSymbolFnName();
+  StringRef symToStrFnName = enumInfo.getSymbolToStringFnName();
+  StringRef symToStrFnRetType = enumInfo.getSymbolToStringFnRetType();
+  StringRef underlyingToSymFnName = enumInfo.getUnderlyingToSymbolFnName();
+  auto enumerants = enumInfo.getAllCases();
 
   SmallVector<StringRef, 2> namespaces;
   llvm::SplitString(cppNamespace, namespaces, "::");
@@ -595,7 +599,7 @@ static void emitEnumDecl(const Record &enumDef, raw_ostream &os) {
   emitEnumClass(enumDef, enumName, underlyingType, description, enumerants, os);
 
   // Emit conversion function declarations
-  if (llvm::all_of(enumerants, [](EnumAttrCase enumerant) {
+  if (llvm::all_of(enumerants, [](EnumCase enumerant) {
         return enumerant.getValue() >= 0;
       })) {
     os << formatv(
@@ -606,7 +610,7 @@ static void emitEnumDecl(const Record &enumDef, raw_ostream &os) {
   os << formatv("::std::optional<{0}> {1}(::llvm::StringRef);\n", enumName,
                 strToSymFnName);
 
-  if (enumAttr.isBitEnum()) {
+  if (enumInfo.isBitEnum()) {
     emitOperators(enumDef, os);
   } else {
     emitMaxValueFn(enumDef, os);
@@ -644,8 +648,8 @@ class {1} : public ::mlir::{2} {
   {0} getValue() const;
 };
 )";
-  if (enumAttr.genSpecializedAttr()) {
-    StringRef attrClassName = enumAttr.getSpecializedAttrClassName();
+  if (enumInfo.genSpecializedAttr()) {
+    StringRef attrClassName = enumInfo.getSpecializedAttrClassName();
     StringRef baseAttrClassName = "IntegerAttr";
     os << formatv(attrClassDecl, enumName, attrClassName, baseAttrClassName);
   }
@@ -656,7 +660,7 @@ class {1} : public ::mlir::{2} {
   // Generate a generic parser and printer for the enum.
   std::string qualName =
       std::string(formatv("{0}::{1}", cppNamespace, enumName));
-  emitParserPrinter(enumAttr, qualName, cppNamespace, os);
+  emitParserPrinter(enumInfo, qualName, cppNamespace, os);
 
   // Emit DenseMapInfo for this enum class
   emitDenseMapInfo(qualName, underlyingType, cppNamespace, os);
@@ -673,8 +677,8 @@ static bool emitEnumDecls(const RecordKeeper &records, raw_ostream &os) {
 }
 
 static void emitEnumDef(const Record &enumDef, raw_ostream &os) {
-  EnumAttr enumAttr(enumDef);
-  StringRef cppNamespace = enumAttr.getCppNamespace();
+  EnumInfo enumInfo(enumDef);
+  StringRef cppNamespace = enumInfo.getCppNamespace();
 
   SmallVector<StringRef, 2> namespaces;
   llvm::SplitString(cppNamespace, namespaces, "::");
@@ -682,7 +686,7 @@ static void emitEnumDef(const Record &enumDef, raw_ostream &os) {
   for (auto ns : namespaces)
     os << "namespace " << ns << " {\n";
 
-  if (enumAttr.isBitEnum()) {
+  if (enumInfo.isBitEnum()) {
     emitSymToStrFnForBitEnum(enumDef, os);
     emitStrToSymFnForBitEnum(enumDef, os);
     emitUnderlyingToSymFnForBitEnum(enumDef, os);
@@ -692,7 +696,7 @@ static void emitEnumDef(const Record &enumDef, raw_ostream &os) {
     emitUnderlyingToSymFnForIntEnum(enumDef, os);
   }
 
-  if (enumAttr.genSpecializedAttr())
+  if (enumInfo.genSpecializedAttr())
     emitSpecializedAttrDef(enumDef, os);
 
   for (auto ns : llvm::reverse(namespaces))
diff --git a/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp b/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp
index 9e19f479d673a..96af14d36817b 100644
--- a/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp
+++ b/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp
@@ -13,6 +13,7 @@
 
 #include "mlir/TableGen/Argument.h"
 #include "mlir/TableGen/Attribute.h"
+#include "mlir/TableGen/EnumInfo.h"
 #include "mlir/TableGen/GenInfo.h"
 #include "mlir/TableGen/Operator.h"
 
@@ -335,13 +336,13 @@ static bool emitOpMLIRBuilders(const RecordKeeper &records, raw_ostream &os) {
 
 namespace {
 // Wrapper class around a Tablegen definition of an LLVM enum attribute case.
-class LLVMEnumAttrCase : public tblgen::EnumAttrCase {
+class LLVMEnumCase : public tblgen::EnumCase {
 public:
-  using tblgen::EnumAttrCase::EnumAttrCase;
+  using tblgen::EnumCase::EnumCase;
 
   // Constructs a case from a non LLVM-specific enum attribute case.
-  explicit LLVMEnumAttrCase(const tblgen::EnumAttrCase &other)
-      : tblgen::EnumAttrCase(&other.getDef()) {}
+  explicit LLVMEnumCase(const tblgen::EnumCase &other)
+      : tblgen::EnumCase(&other.getDef()) {}
 
   // Returns the C++ enumerant for the LLVM API.
   StringRef getLLVMEnumerant() const {
@@ -350,9 +351,9 @@ class LLVMEnumAttrCase : public tblgen::EnumAttrCase {
 };
 
 // Wraper class around a Tablegen definition of an LLVM enum attribute.
-class LLVMEnumAttr : public tblgen::EnumAttr {
+class LLVMEnumInfo : public tblgen::EnumInfo {
 public:
-  using tblgen::EnumAttr::EnumAttr;
+  using tblgen::EnumInfo::EnumInfo;
 
   // Returns the C++ enum name for the LLVM API.
   StringRef getLLVMClassName() const {
@@ -360,19 +361,19 @@ class LLVMEnumAttr : public tblgen::EnumAttr {
   }
 
   // Returns all associated cases viewed as LLVM-specific enum cases.
-  std::vector<LLVMEnumAttrCase> getAllCases() const {
-    std::vector<LLVMEnumAttrCase> cases;
+  std::vector<LLVMEnumCase> getAllCases() const {
+    std::vector<LLVMEnumCase> cases;
 
-    for (auto &c : tblgen::EnumAttr::getAllCases())
+    for (auto &c : tblgen::EnumInfo::getAllCases())
       cases.emplace_back(c);
 
     return cases;
   }
 
-  std::vector<LLVMEnumAttrCase> getAllUnsupportedCases() const {
+  std::vector<LLVMEnumCase> getAllUnsupportedCases() const {
     const auto *inits = def->getValueAsListInit("unsupported");
 
-    std::vector<LLVMEnumAttrCase> cases;
+    std::vector<LLVMEnumCase> cases;
     cases.reserve(inits->size());
 
     for (const llvm::Init *init : *inits)
@@ -383,9 +384,9 @@ class LLVMEnumAttr : public tblgen::EnumAttr {
 };
 
 // Wraper class around a Tablegen definition of a C-style LLVM enum attribute.
-class LLVMCEnumAttr : public tblgen::EnumAttr {
+class LLVMCEnumInfo : public tblgen::EnumInfo {
 public:
-  using tblgen::EnumAttr::EnumAttr;
+  using tblgen::EnumInfo::EnumInfo;
 
   // Returns the C++ enum name for the LLVM API.
   StringRef getLLVMClassName() const {
@@ -393,10 +394,10 @@ class LLVMCEnumAttr : public tblgen::EnumAttr {
   }
 
   // Returns all associated cases viewed as LLVM-specific enum cases.
-  std::vector<LLVMEnumAttrCase> getAllCases() const {
-    std::vector<LLVMEnumAttrCase> cases;
+  std::vector<LLVMEnumCase> getAllCases() const {
+    std::vector<LLVMEnumCase> cases;
 
-    for (auto &c : tblgen::EnumAttr::getAllCases())
+    for (auto &c : tblgen::EnumInfo::getAllCases())
       cases.emplace_back(c);
 
     return cases;
@@ -408,10 +409,10 @@ class LLVMCEnumAttr : public tblgen::EnumAttr {
 // switch-based logic to convert from the MLIR LLVM dialect enum attribute case
 // (Enum) to the corresponding LLVM API enumerant
 static void emitOneEnumToConversion(const Record *record, raw_ostream &os) {
-  LLVMEnumAttr enumAttr(record);
-  StringRef llvmClass = enumAttr.getLLVMClassName();
-  StringRef cppClassName = enumAttr.getEnumClassName();
-  StringRef cppNamespace = enumAttr.getCppNamespace();
+  LLVMEnumInfo enumInfo(record);
+  StringRef llvmClass = enumInfo.getLLVMClassName();
+  StringRef cppClassName = enumInfo.getEnumClassName();
+  StringRef cppNamespace = enumInfo.getCppNamespace();
 
   // Emit the function converting the enum attribute to its LLVM counterpart.
   os << formatv(
@@ -419,7 +420,7 @@ static void emitOneEnumToConversion(const Record *record, raw_ostream &os) {
       llvmClass, cppClassName, cppNamespace);
   os << "  switch (value) {\n";
 
-  for (const auto &enumerant : enumAttr.getAllCases()) {
+  for (const auto &enumerant : enumInfo.getAllCases()) {
     StringRef llvmEnumerant = enumerant.getLLVMEnumerant();
     StringRef cppEnumerant = enumerant.getSymbol();
     os << formatv("  case {0}::{1}::{2}:\n", cppNamespace, cppClassName,
@@ -429,7 +430,7 @@ static void emitOneEnumToConversion(const Record *record, raw_ostream &os) {
 
   os << "  }\n";
   os << formatv("  llvm_unreachable(\"unknown {0} type\");\n",
-                enumAttr.getEnumClassName());
+                enumInfo.getEnumClassName());
   os << "}\n\n";
 }
 
@@ -437,7 +438,7 @@ static void emitOneEnumToConversion(const Record *record, raw_ostream &os) {
 // switch-based logic to convert from the MLIR LLVM dialect enum attribute case
 // (Enum) to the corresponding LLVM API C-style enumerant
 static void emitOneCEnumToConversion(const Record *record, raw_ostream &os) {
-  LLVMCEnumAttr enumAttr(record);
+  LLVMCEnumInfo enumAttr(record);
   StringRef llvmClass = enumAttr.getLLVMClassName();
   StringRef cppClassName = enumAttr.getEnumClassName();
   StringRef cppNamespace = enumAttr.getCppNamespace();
@@ -467,10 +468,10 @@ static void emitOneCEnumToConversion(const Record *record, raw_ostream &os) {
 // containing switch-based logic to convert from the LLVM API enumerant to MLIR
 // LLVM dialect enum attribute (Enum).
 static void emitOneEnumFromConversion(const Record *record, raw_ostream &os) {
-  LLVMEnumAttr enumAttr(record);
-  StringRef llvmClass = enumAttr.getLLVMClassName();
-  StringRef cppClassName = enumAttr.getEnumClassName();
-  StringRef cppNamespace = enumAttr.getCppNamespace();
+  LLVMEnumInfo enumInfo(record);
+  StringRef llvmClass = enumInfo.getLLVMClassName();
+  StringRef cppClassName = enumInfo.getEnumClassName();
+  StringRef cppNamespace = enumInfo.getCppNamespace();
 
   // Emit the function converting the enum attribute from its LLVM counterpart.
   os << formatv("inline LLVM_ATTRIBUTE_UNUSED {0}::{1} convert{1}FromLLVM({2} "
@@ -478,23 +479,23 @@ static void emitOneEnumFromConversion(const Record *record, raw_ostream &os) {
                 cppNamespace, cppClassName, llvmClass);
   os << "  switch (value) {\n";
 
-  for (const auto &enumerant : enumAttr.getAllCases()) {
+  for (const auto &enumerant : enumInfo.getAllCases()) {
     StringRef llvmEnumerant = enumerant.getLLVMEnumerant();
     StringRef cppEnumerant = enumerant.getSymbol();
     os << formatv("  case {0}::{1}:\n", llvmClass, llvmEnumerant);
     os << formatv("    return {0}::{1}::{2};\n", cppNamespace, cppClassName,
                   cppEnumerant);
   }
-  for (const auto &enumerant : enumAttr.getAllUnsupportedCases()) {
+  for (const auto &enumerant : enumInfo.getAllUnsupportedCases()) {
     StringRef llvmEnumerant = enumerant.getLLVMEnumerant();
     os << formatv("  case {0}::{1}:\n", llvmClass, llvmEnumerant);
     os << formatv("    llvm_unreachable(\"unsupported case {0}::{1}\");\n",
-                  enumAttr.getLLVMClassName(), llvmEnumerant);
+                  enumInfo.getLLVMClassName(), llvmEnumerant);
   }
 
   os << "  }\n";
   os << formatv("  llvm_unreachable(\"unknown {0} type\");",
-                enumAttr.getLLVMClassName());
+                enumInfo.getLLVMClassName());
   os << "}\n\n";
 }
 
@@ -502,10 +503,10 @@ static void emitOneEnumFromConversion(const Record *record, raw_ostream &os) {
 // containing switch-based logic to convert from the LLVM API C-style enumerant
 // to MLIR LLVM dialect enum attribute (Enum).
 static void emitOneCEnumFromConversion(const Record *record, raw_ostream &os) {
-  LLVMCEnumAttr enumAttr(record);
-  StringRef llvmClass = enumAttr.getLLVMClassName();
-  StringRef cppClassName = enumAttr.getEnumClassName();
-  StringRef cppNamespace = enumAttr.getCppNamespace();
+  LLVMCEnumInfo enumInfo(record);
+  StringRef llvmClass = enumInfo.getLLVMClassName();
+  StringRef cppClassName = enumInfo.getEnumClassName();
+  StringRef cppNamespace = enumInfo.getCppNamespace();
 
   // Emit the function converting the enum attribute from its LLVM counterpart.
   os << formatv(
@@ -514,7 +515,7 @@ static void emitOneCEnumFromConversion(const Record *record, raw_ostream &os) {
       cppNamespace, cppClassName);
   os << "  switch (value) {\n";
 
-  for (const auto &enumerant : enumAttr.getAllCases()) {
+  for (const auto &enumerant : enumInfo.getAllCases()) {
     StringRef llvmEnumerant = enumerant.getLLVMEnumerant();
     StringRef cppEnumerant = enumerant.getSymbol();
     os << formatv("  case static_cast<int64_t>({0}::{1}):\n", llvmClass,
@@ -525,7 +526,7 @@ static void emitOneCEnumFromConversion(const Record *record, raw_ostream &os) {
 
   os << "  }\n";
   os << formatv("  llvm_unreachable(\"unknown {0} type\");",
-                enumAttr.getLLVMClassName());
+                enumInfo.getLLVMClassName());
   os << "}\n\n";
 }
 
diff --git a/mlir/tools/mlir-tblgen/OpDocGen.cpp b/mlir/tools/mlir-tblgen/OpDocGen.cpp
index dbaad84cda5d6..f53aebb302dc9 100644
--- a/mlir/tools/mlir-tblgen/OpDocGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDocGen.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Support/IndentedOstream.h"
 #include "mlir/TableGen/AttrOrTypeDef.h"
 #include "mlir/TableGen/Attribute.h"
+#include "mlir/TableGen/EnumInfo.h"
 #include "mlir/TableGen/GenInfo.h"
 #include "mlir/TableGen/Operator.h"
 #include "llvm/ADT/DenseMap.h"
@@ -384,14 +385,14 @@ static void emitAttrOrTypeDefDoc(const RecordKeeper &records, raw_ostream &os,
 // Enum Documentation
 //===----------------------------------------------------------------------===//
 
-static void emitEnumDoc(const EnumAttr &def, raw_ostream &os) {
+static void emitEnumDoc(const EnumInfo &def, raw_ostream &os) {
   os << formatv("\n### {0}\n", def.getEnumClassName());
 
   // Emit the summary if present.
   emitSummary(def.getSummary(), os);
 
   // Emit case documentation.
-  std::vector<EnumAttrCase> cases = def.getAllCases();
+  std::vector<EnumCase> cases = def.getAllCases();
   os << "\n#### Cases:\n\n";
   os << "| Symbol | Value | String |\n"
      << "| :----: | :---: | ------ |";
@@ -406,7 +407,7 @@ static void emitEnumDoc(const EnumAttr &def, raw_ostream &os) {
 static void emitEnumDoc(const RecordKeeper &records, raw_ostream &os) {
   os << "<!-- Autogenerated by mlir-tblgen; don't manually edit -->\n";
   for (const Record *def : records.getAllDerivedDefinitions("EnumAttrInfo"))
-    emitEnumDoc(EnumAttr(def), os);
+    emitEnumDoc(EnumInfo(def), os);
 }
 
 //===----------------------------------------------------------------------===//
@@ -441,7 +442,7 @@ static void maybeNest(bool nest, llvm::function_ref<void(raw_ostream &os)> fn,
 static void emitBlock(ArrayRef<Attribute> attributes, StringRef inputFilename,
                       ArrayRef<AttrDef> attrDefs, ArrayRef<OpDocGroup> ops,
                       ArrayRef<Type> types, ArrayRef<TypeDef> typeDefs,
-                      ArrayRef<EnumAttr> enums, raw_ostream &os) {
+                      ArrayRef<EnumInfo> enums, raw_ostream &os) {
   if (!ops.empty()) {
     os << "\n## Operations\n";
     emitSourceLink(inputFilename, os);
@@ -490,7 +491,7 @@ static void emitBlock(ArrayRef<Attribute> attributes, StringRef inputFilename,
 
   if (!enums.empty()) {
     os << "\n## Enums\n";
-    for (const EnumAttr &def : enums)
+    for (const EnumInfo &def : enums)
       emitEnumDoc(def, os);
   }
 }
@@ -499,7 +500,7 @@ static void emitDialectDoc(const Dialect &dialect, StringRef inputFilename,
                            ArrayRef<Attribute> attributes,
                            ArrayRef<AttrDef> attrDefs, ArrayRef<OpDocGroup> ops,
                            ArrayRef<Type> types, ArrayRef<TypeDef> typeDefs,
-                           ArrayRef<EnumAttr> enums, raw_ostream &os) {
+                           ArrayRef<EnumInfo> enums, raw_ostream &os) {
   os << "\n# '" << dialect.getName() << "' Dialect\n";
   emitSummary(dialect.getSummary(), os);
   emitDescription(dialect.getDescription(), os);
@@ -532,7 +533,7 @@ static bool emitDialectDoc(const RecordKeeper &records, raw_ostream &os) {
   std::vector<OpDocGroup> dialectOps;
   std::vector<Type> dialectTypes;
   std::vector<TypeDef> dialectTypeDefs;
-  std::vector<EnumAttr> dialectEnums;
+  std::vector<EnumInfo> dialectEnums;
 
   SmallDenseSet<const Record *> seen;
   auto addIfNotSeen = [&](const Record *record, const auto &def, auto &vec) {
@@ -576,7 +577,7 @@ static bool emitDialectDoc(const RecordKeeper &records, raw_ostream &os) {
     addIfInDialect(def, Type(def), dialectTypes);
   dialectEnums.reserve(enumDefs.size());
   for (const Record *def : enumDefs)
-    addIfNotSeen(def, EnumAttr(def), dialectEnums);
+    addIfNotSeen(def, EnumInfo(def), dialectEnums);
 
   // Sort alphabetically ignorning dialect for ops and section name for
   // sections.
diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
index fe724e86d6707..3a7a7aaf3a5dd 100644
--- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
@@ -11,6 +11,7 @@
 #include "OpClass.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/TableGen/Class.h"
+#include "mlir/TableGen/EnumInfo.h"
 #include "mlir/TableGen/Format.h"
 #include "mlir/TableGen/Operator.h"
 #include "mlir/TableGen/Trait.h"
@@ -424,17 +425,17 @@ struct OperationFormat {
 //===----------------------------------------------------------------------===//
 // Parser Gen
 
-/// Returns true if we can format the given attribute as an EnumAttr in the
+/// Returns true if we can format the given attribute as an enum in the
 /// parser format.
 static bool canFormatEnumAttr(const NamedAttribute *attr) {
   Attribute baseAttr = attr->attr.getBaseAttr();
-  const EnumAttr *enumAttr = dyn_cast<EnumAttr>(&baseAttr);
-  if (!enumAttr)
+  if (!baseAttr.isEnumAttr())
     return false;
+  EnumInfo enumInfo(&baseAttr.getDef());
 
   // The attribute must have a valid underlying type and a constant builder.
-  return !enumAttr->getUnderlyingType().empty() &&
-         !enumAttr->getConstBuilderTemplate().empty();
+  return !enumInfo.getUnderlyingType().empty() &&
+         !baseAttr.getConstBuilderTemplate().empty();
 }
 
 /// Returns if we should format the given attribute as an SymbolNameAttr.
@@ -1150,21 +1151,21 @@ static void genEnumAttrParser(const NamedAttribute *var, MethodBody &body,
                               FmtContext &attrTypeCtx, bool parseAsOptional,
                               bool useProperties, StringRef opCppClassName) {
   Attribute baseAttr = var->attr.getBaseAttr();
-  const EnumAttr &enumAttr = cast<EnumAttr>(baseAttr);
-  std::vector<EnumAttrCase> cases = enumAttr.getAllCases();
+  EnumInfo enumInfo(&baseAttr.getDef());
+  std::vector<EnumCase> cases = enumInfo.getAllCases();
 
   // Generate the code for building an attribute for this enum.
   std::string attrBuilderStr;
   {
     llvm::raw_string_ostream os(attrBuilderStr);
-    os << tgfmt(enumAttr.getConstBuilderTemplate(), &attrTypeCtx,
+    os << tgfmt(baseAttr.getConstBuilderTemplate(), &attrTypeCtx,
                 "*attrOptional");
   }
 
   // Build a string containing the cases that can be formatted as a keyword.
   std::string validCaseKeywordsStr = "{";
   llvm::raw_string_ostream validCaseKeywordsOS(validCaseKeywordsStr);
-  for (const EnumAttrCase &attrCase : cases)
+  for (const EnumCase &attrCase : cases)
     if (canFormatStringAsKeyword(attrCase.getStr()))
       validCaseKeywordsOS << '"' << attrCase.getStr() << "\",";
   validCaseKeywordsOS.str().back() = '}';
@@ -1194,8 +1195,8 @@ static void genEnumAttrParser(const NamedAttribute *var, MethodBody &body,
         formatv("result.addAttribute(\"{0}\", {0}Attr);", var->name);
   }
 
-  body << formatv(enumAttrParserCode, var->name, enumAttr.getCppNamespace(),
-                  enumAttr.getStringToSymbolFnName(), attrBuilderStr,
+  body << formatv(enumAttrParserCode, var->name, enumInfo.getCppNamespace(),
+                  enumInfo.getStringToSymbolFnName(), attrBuilderStr,
                   validCaseKeywordsStr, errorMessage, attrAssignment);
 }
 
@@ -2264,13 +2265,13 @@ static MethodBody &genTypeOperandPrinter(FormatElement *arg, const Operator &op,
 static void genEnumAttrPrinter(const NamedAttribute *var, const Operator &op,
                                MethodBody &body) {
   Attribute baseAttr = var->attr.getBaseAttr();
-  const EnumAttr &enumAttr = cast<EnumAttr>(baseAttr);
-  std::vector<EnumAttrCase> cases = enumAttr.getAllCases();
+  const EnumInfo enumInfo(&baseAttr.getDef());
+  std::vector<EnumCase> cases = enumInfo.getAllCases();
 
   body << formatv(enumAttrBeginPrinterCode,
                   (var->attr.isOptional() ? "*" : "") +
                       op.getGetterName(var->name),
-                  enumAttr.getSymbolToStringFnName());
+                  enumInfo.getSymbolToStringFnName());
 
   // Get a string containing all of the cases that can't be represented with a
   // keyword.
@@ -2283,7 +2284,7 @@ static void genEnumAttrPrinter(const NamedAttribute *var, const Operator &op,
   // Otherwise if this is a bit enum attribute, don't allow cases that may
   // overlap with other cases. For simplicity sake, only allow cases with a
   // single bit value.
-  if (enumAttr.isBitEnum()) {
+  if (enumInfo.isBitEnum()) {
     for (auto it : llvm::enumerate(cases)) {
       int64_t value = it.value().getValue();
       if (value < 0 || !llvm::isPowerOf2_64(value))
@@ -2295,8 +2296,8 @@ static void genEnumAttrPrinter(const NamedAttribute *var, const Operator &op,
   // case value to determine when to print in the string form.
   if (nonKeywordCases.any()) {
     body << "    switch (caseValue) {\n";
-    StringRef cppNamespace = enumAttr.getCppNamespace();
-    StringRef enumName = enumAttr.getEnumClassName();
+    StringRef cppNamespace = enumInfo.getCppNamespace();
+    StringRef enumName = enumInfo.getEnumClassName();
     for (auto it : llvm::enumerate(cases)) {
       if (nonKeywordCases.test(it.index()))
         continue;
diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
index c74cb9943671e..c8a12b9e21b90 100644
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -1377,12 +1377,12 @@ std::string PatternEmitter::handleOpArgument(DagLeaf leaf,
     return handleConstantAttr(constAttr.getAttribute(),
                               constAttr.getConstantValue());
   }
-  if (leaf.isEnumAttrCase()) {
-    auto enumCase = leaf.getAsEnumAttrCase();
+  if (leaf.isEnumCase()) {
+    auto enumCase = leaf.getAsEnumCase();
     // This is an enum case backed by an IntegerAttr. We need to get its value
     // to build the constant.
     std::string val = std::to_string(enumCase.getValue());
-    return handleConstantAttr(enumCase, val);
+    return handleConstantAttr(Attribute(&enumCase.getDef()), val);
   }
 
   LLVM_DEBUG(llvm::dbgs() << "handle argument '" << patArgName << "'\n");
@@ -1782,7 +1782,7 @@ void PatternEmitter::supplyValuesForOpArgs(
       auto leaf = node.getArgAsLeaf(argIndex);
       // The argument in the result DAG pattern.
       auto patArgName = node.getArgName(argIndex);
-      if (leaf.isConstantAttr() || leaf.isEnumAttrCase()) {
+      if (leaf.isConstantAttr() || leaf.isEnumCase()) {
         // TODO: Refactor out into map to avoid recomputing these.
         if (!isa<NamedAttribute *>(opArg))
           PrintFatalError(loc, Twine("expected attribute ") + Twine(argIndex));
diff --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
index 75b8829be4da5..7a6189c09f426 100644
--- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
+++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
@@ -13,6 +13,7 @@
 
 #include "mlir/TableGen/Attribute.h"
 #include "mlir/TableGen/CodeGenHelpers.h"
+#include "mlir/TableGen/EnumInfo.h"
 #include "mlir/TableGen/Format.h"
 #include "mlir/TableGen/GenInfo.h"
 #include "mlir/TableGen/Operator.h"
@@ -45,8 +46,8 @@ using llvm::SMLoc;
 using llvm::StringMap;
 using llvm::StringRef;
 using mlir::tblgen::Attribute;
-using mlir::tblgen::EnumAttr;
-using mlir::tblgen::EnumAttrCase;
+using mlir::tblgen::EnumCase;
+using mlir::tblgen::EnumInfo;
 using mlir::tblgen::NamedAttribute;
 using mlir::tblgen::NamedTypeConstraint;
 using mlir::tblgen::NamespaceEmitter;
@@ -335,18 +336,18 @@ static mlir::GenRegistration
 
 static void emitAvailabilityQueryForIntEnum(const Record &enumDef,
                                             raw_ostream &os) {
-  EnumAttr enumAttr(enumDef);
-  StringRef enumName = enumAttr.getEnumClassName();
-  std::vector<EnumAttrCase> enumerants = enumAttr.getAllCases();
+  EnumInfo enumInfo(enumDef);
+  StringRef enumName = enumInfo.getEnumClassName();
+  std::vector<EnumCase> enumerants = enumInfo.getAllCases();
 
   // Mapping from availability class name to (enumerant, availability
   // specification) pairs.
-  llvm::StringMap<llvm::SmallVector<std::pair<EnumAttrCase, Availability>, 1>>
+  llvm::StringMap<llvm::SmallVector<std::pair<EnumCase, Availability>, 1>>
       classCaseMap;
 
   // Place all availability specifications to their corresponding
   // availability classes.
-  for (const EnumAttrCase &enumerant : enumerants)
+  for (const EnumCase &enumerant : enumerants)
     for (const Availability &avail : getAvailabilities(enumerant.getDef()))
       classCaseMap[avail.getClass()].push_back({enumerant, avail});
 
@@ -359,14 +360,14 @@ static void emitAvailabilityQueryForIntEnum(const Record &enumDef,
 
     os << "  switch (value) {\n";
     for (const auto &caseSpecPair : classCasePair.getValue()) {
-      EnumAttrCase enumerant = caseSpecPair.first;
+      EnumCase enumerant = caseSpecPair.first;
       Availability avail = caseSpecPair.second;
       os << formatv("  case {0}::{1}: { {2} return {3}({4}); }\n", enumName,
                     enumerant.getSymbol(), avail.getMergeInstancePreparation(),
                     avail.getMergeInstanceType(), avail.getMergeInstance());
     }
     // Only emit default if uncovered cases.
-    if (classCasePair.getValue().size() < enumAttr.getAllCases().size())
+    if (classCasePair.getValue().size() < enumInfo.getAllCases().size())
       os << "  default: break;\n";
     os << "  }\n"
        << "  return std::nullopt;\n"
@@ -376,19 +377,19 @@ static void emitAvailabilityQueryForIntEnum(const Record &enumDef,
 
 static void emitAvailabilityQueryForBitEnum(const Record &enumDef,
                                             raw_ostream &os) {
-  EnumAttr enumAttr(enumDef);
-  StringRef enumName = enumAttr.getEnumClassName();
-  std::string underlyingType = std::string(enumAttr.getUnderlyingType());
-  std::vector<EnumAttrCase> enumerants = enumAttr.getAllCases();
+  EnumInfo enumInfo(enumDef);
+  StringRef enumName = enumInfo.getEnumClassName();
+  std::string underlyingType = std::string(enumInfo.getUnderlyingType());
+  std::vector<EnumCase> enumerants = enumInfo.getAllCases();
 
   // Mapping from availability class name to (enumerant, availability
   // specification) pairs.
-  llvm::StringMap<llvm::SmallVector<std::pair<EnumAttrCase, Availability>, 1>>
+  llvm::StringMap<llvm::SmallVector<std::pair<EnumCase, Availability>, 1>>
       classCaseMap;
 
   // Place all availability specifications to their corresponding
   // availability classes.
-  for (const EnumAttrCase &enumerant : enumerants)
+  for (const EnumCase &enumerant : enumerants)
     for (const Availability &avail : getAvailabilities(enumerant.getDef()))
       classCaseMap[avail.getClass()].push_back({enumerant, avail});
 
@@ -406,7 +407,7 @@ static void emitAvailabilityQueryForBitEnum(const Record &enumDef,
 
     os << "  switch (value) {\n";
     for (const auto &caseSpecPair : classCasePair.getValue()) {
-      EnumAttrCase enumerant = caseSpecPair.first;
+      EnumCase enumerant = caseSpecPair.first;
       Availability avail = caseSpecPair.second;
       os << formatv("  case {0}::{1}: { {2} return {3}({4}); }\n", enumName,
                     enumerant.getSymbol(), avail.getMergeInstancePreparation(),
@@ -420,10 +421,10 @@ static void emitAvailabilityQueryForBitEnum(const Record &enumDef,
 }
 
 static void emitEnumDecl(const Record &enumDef, raw_ostream &os) {
-  EnumAttr enumAttr(enumDef);
-  StringRef enumName = enumAttr.getEnumClassName();
-  StringRef cppNamespace = enumAttr.getCppNamespace();
-  auto enumerants = enumAttr.getAllCases();
+  EnumInfo enumInfo(enumDef);
+  StringRef enumName = enumInfo.getEnumClassName();
+  StringRef cppNamespace = enumInfo.getCppNamespace();
+  auto enumerants = enumInfo.getAllCases();
 
   llvm::SmallVector<StringRef, 2> namespaces;
   llvm::SplitString(cppNamespace, namespaces, "::");
@@ -435,7 +436,7 @@ static void emitEnumDecl(const Record &enumDef, raw_ostream &os) {
 
   // Place all availability specifications to their corresponding
   // availability classes.
-  for (const EnumAttrCase &enumerant : enumerants)
+  for (const EnumCase &enumerant : enumerants)
     for (const Availability &avail : getAvailabilities(enumerant.getDef())) {
       StringRef className = avail.getClass();
       if (handledClasses.count(className))
@@ -462,8 +463,8 @@ static bool emitEnumDecls(const RecordKeeper &records, raw_ostream &os) {
 }
 
 static void emitEnumDef(const Record &enumDef, raw_ostream &os) {
-  EnumAttr enumAttr(enumDef);
-  StringRef cppNamespace = enumAttr.getCppNamespace();
+  EnumInfo enumInfo(enumDef);
+  StringRef cppNamespace = enumInfo.getCppNamespace();
 
   llvm::SmallVector<StringRef, 2> namespaces;
   llvm::SplitString(cppNamespace, namespaces, "::");
@@ -471,7 +472,7 @@ static void emitEnumDef(const Record &enumDef, raw_ostream &os) {
   for (auto ns : namespaces)
     os << "namespace " << ns << " {\n";
 
-  if (enumAttr.isBitEnum()) {
+  if (enumInfo.isBitEnum()) {
     emitAvailabilityQueryForBitEnum(enumDef, os);
   } else {
     emitAvailabilityQueryForIntEnum(enumDef, os);
@@ -535,7 +536,7 @@ static void emitAttributeSerialization(const Attribute &attr,
   os << tabs
      << formatv("if (auto attr = {0}->getAttr(\"{1}\")) {{\n", opVar, attrName);
   if (llvm::is_contained(constantIdEnumAttrs, attr.getAttrDefName())) {
-    EnumAttr baseEnum(attr.getDef().getValueAsDef("enum"));
+    EnumInfo baseEnum(attr.getDef().getValueAsDef("enum"));
     os << tabs
        << formatv("  {0}.push_back(prepareConstantInt({1}.getLoc(), "
                   "Builder({1}).getI32IntegerAttr(static_cast<uint32_t>("
@@ -544,7 +545,7 @@ static void emitAttributeSerialization(const Attribute &attr,
                   baseEnum.getEnumClassName());
   } else if (attr.isSubClassOf("SPIRV_BitEnumAttr") ||
              attr.isSubClassOf("SPIRV_I32EnumAttr")) {
-    EnumAttr baseEnum(attr.getDef().getValueAsDef("enum"));
+    EnumInfo baseEnum(attr.getDef().getValueAsDef("enum"));
     os << tabs
        << formatv("  {0}.push_back(static_cast<uint32_t>("
                   "::llvm::cast<{1}::{2}Attr>(attr).getValue()));\n",
@@ -831,7 +832,7 @@ static void emitAttributeDeserialization(const Attribute &attr,
                                          StringRef words, StringRef wordIndex,
                                          raw_ostream &os) {
   if (llvm::is_contained(constantIdEnumAttrs, attr.getAttrDefName())) {
-    EnumAttr baseEnum(attr.getDef().getValueAsDef("enum"));
+    EnumInfo baseEnum(attr.getDef().getValueAsDef("enum"));
     os << tabs
        << formatv("{0}.push_back(opBuilder.getNamedAttr(\"{1}\", "
                   "opBuilder.getAttr<{2}::{3}Attr>(static_cast<{2}::{3}>("
@@ -840,7 +841,7 @@ static void emitAttributeDeserialization(const Attribute &attr,
                   baseEnum.getEnumClassName(), words, wordIndex);
   } else if (attr.isSubClassOf("SPIRV_BitEnumAttr") ||
              attr.isSubClassOf("SPIRV_I32EnumAttr")) {
-    EnumAttr baseEnum(attr.getDef().getValueAsDef("enum"));
+    EnumInfo baseEnum(attr.getDef().getValueAsDef("enum"));
     os << tabs
        << formatv("  {0}.push_back(opBuilder.getNamedAttr(\"{1}\", "
                   "opBuilder.getAttr<{2}::{3}Attr>("
@@ -1246,9 +1247,9 @@ static void emitEnumGetAttrNameFnDecl(raw_ostream &os) {
                 "attributeName();\n");
 }
 
-static void emitEnumGetAttrNameFnDefn(const EnumAttr &enumAttr,
+static void emitEnumGetAttrNameFnDefn(const EnumInfo &enumInfo,
                                       raw_ostream &os) {
-  auto enumName = enumAttr.getEnumClassName();
+  auto enumName = enumInfo.getEnumClassName();
   os << formatv("template <> inline StringRef attributeName<{0}>() {{\n",
                 enumName);
   os << "  "
@@ -1266,8 +1267,8 @@ static bool emitAttrUtils(const RecordKeeper &records, raw_ostream &os) {
   os << "#define MLIR_DIALECT_SPIRV_IR_ATTR_UTILS_H_\n";
   emitEnumGetAttrNameFnDecl(os);
   for (const auto *def : defs) {
-    EnumAttr enumAttr(*def);
-    emitEnumGetAttrNameFnDefn(enumAttr, os);
+    EnumInfo enumInfo(*def);
+    emitEnumGetAttrNameFnDefn(enumInfo, os);
   }
   os << "#endif // MLIR_DIALECT_SPIRV_IR_ATTR_UTILS_H\n";
   return false;
@@ -1306,9 +1307,9 @@ static void emitAvailabilityImpl(const Operator &srcOp, raw_ostream &os) {
     if (!namedAttr.attr.isSubClassOf("SPIRV_BitEnumAttr") &&
         !namedAttr.attr.isSubClassOf("SPIRV_I32EnumAttr"))
       continue;
-    EnumAttr enumAttr(namedAttr.attr.getDef().getValueAsDef("enum"));
+    EnumInfo enumInfo(namedAttr.attr.getDef().getValueAsDef("enum"));
 
-    for (const EnumAttrCase &enumerant : enumAttr.getAllCases())
+    for (const EnumCase &enumerant : enumInfo.getAllCases())
       for (const Availability &caseAvail :
            getAvailabilities(enumerant.getDef()))
         availClasses.try_emplace(caseAvail.getClass(), caseAvail);
@@ -1348,14 +1349,14 @@ static void emitAvailabilityImpl(const Operator &srcOp, raw_ostream &os) {
       if (!namedAttr.attr.isSubClassOf("SPIRV_BitEnumAttr") &&
           !namedAttr.attr.isSubClassOf("SPIRV_I32EnumAttr"))
         continue;
-      EnumAttr enumAttr(namedAttr.attr.getDef().getValueAsDef("enum"));
+      EnumInfo enumInfo(namedAttr.attr.getDef().getValueAsDef("enum"));
 
       // (enumerant, availability specification) pairs for this availability
       // class.
-      SmallVector<std::pair<EnumAttrCase, Availability>, 1> caseSpecs;
+      SmallVector<std::pair<EnumCase, Availability>, 1> caseSpecs;
 
       // Collect all cases' availability specs.
-      for (const EnumAttrCase &enumerant : enumAttr.getAllCases())
+      for (const EnumCase &enumerant : enumInfo.getAllCases())
         for (const Availability &caseAvail :
              getAvailabilities(enumerant.getDef()))
           if (availClassName == caseAvail.getClass())
@@ -1366,19 +1367,19 @@ static void emitAvailabilityImpl(const Operator &srcOp, raw_ostream &os) {
       if (caseSpecs.empty())
         continue;
 
-      if (enumAttr.isBitEnum()) {
+      if (enumInfo.isBitEnum()) {
         // For BitEnumAttr, we need to iterate over each bit to query its
         // availability spec.
         os << formatv("  for (unsigned i = 0; "
                       "i < std::numeric_limits<{0}>::digits; ++i) {{\n",
-                      enumAttr.getUnderlyingType());
+                      enumInfo.getUnderlyingType());
         os << formatv("    {0}::{1} tblgen_attrVal = this->{2}() & "
                       "static_cast<{0}::{1}>(1 << i);\n",
-                      enumAttr.getCppNamespace(), enumAttr.getEnumClassName(),
+                      enumInfo.getCppNamespace(), enumInfo.getEnumClassName(),
                       srcOp.getGetterName(namedAttr.name));
         os << formatv(
             "    if (static_cast<{0}>(tblgen_attrVal) == 0) continue;\n",
-            enumAttr.getUnderlyingType());
+            enumInfo.getUnderlyingType());
       } else {
         // For IntEnumAttr, we just need to query the value as a whole.
         os << "  {\n";
@@ -1386,7 +1387,7 @@ static void emitAvailabilityImpl(const Operator &srcOp, raw_ostream &os) {
                       srcOp.getGetterName(namedAttr.name));
       }
       os << formatv("    auto tblgen_instance = {0}::{1}(tblgen_attrVal);\n",
-                    enumAttr.getCppNamespace(), avail.getQueryFnName());
+                    enumInfo.getCppNamespace(), avail.getQueryFnName());
       os << "    if (tblgen_instance) "
          // TODO` here once ODS supports
          // dialect-specific contents so that we can use not implementing the
@@ -1434,14 +1435,14 @@ static bool emitCapabilityImplication(const RecordKeeper &records,
                                       raw_ostream &os) {
   llvm::emitSourceFileHeader("SPIR-V Capability Implication", os, records);
 
-  EnumAttr enumAttr(
+  EnumInfo enumInfo(
       records.getDef("SPIRV_CapabilityAttr")->getValueAsDef("enum"));
 
   os << "ArrayRef<spirv::Capability> "
         "spirv::getDirectImpliedCapabilities(spirv::Capability cap) {\n"
      << "  switch (cap) {\n"
      << "  default: return {};\n";
-  for (const EnumAttrCase &enumerant : enumAttr.getAllCases()) {
+  for (const EnumCase &enumerant : enumInfo.getAllCases()) {
     const Record &def = enumerant.getDef();
     if (!def.getValue("implies"))
       continue;
@@ -1452,7 +1453,7 @@ static bool emitCapabilityImplication(const RecordKeeper &records,
        << ": {static const spirv::Capability implies[" << impliedCapsDefs.size()
        << "] = {";
     llvm::interleaveComma(impliedCapsDefs, os, [&](const Record *capDef) {
-      os << "spirv::Capability::" << EnumAttrCase(capDef).getSymbol();
+      os << "spirv::Capability::" << EnumCase(capDef).getSymbol();
     });
     os << "}; return ArrayRef<spirv::Capability>(implies, "
        << impliedCapsDefs.size() << "); }\n";
diff --git a/mlir/tools/mlir-tblgen/TosaUtilsGen.cpp b/mlir/tools/mlir-tblgen/TosaUtilsGen.cpp
index 491f9143edb02..ddc149810ebd8 100644
--- a/mlir/tools/mlir-tblgen/TosaUtilsGen.cpp
+++ b/mlir/tools/mlir-tblgen/TosaUtilsGen.cpp
@@ -12,6 +12,7 @@
 
 #include "mlir/TableGen/Attribute.h"
 #include "mlir/TableGen/CodeGenHelpers.h"
+#include "mlir/TableGen/EnumInfo.h"
 #include "mlir/TableGen/Format.h"
 #include "mlir/TableGen/GenInfo.h"
 #include "mlir/TableGen/Operator.h"
@@ -42,8 +43,8 @@ using llvm::SMLoc;
 using llvm::StringMap;
 using llvm::StringRef;
 using mlir::tblgen::Attribute;
-using mlir::tblgen::EnumAttr;
-using mlir::tblgen::EnumAttrCase;
+using mlir::tblgen::EnumCase;
+using mlir::tblgen::EnumInfo;
 using mlir::tblgen::NamedAttribute;
 using mlir::tblgen::NamedTypeConstraint;
 using mlir::tblgen::NamespaceEmitter;



More information about the Mlir-commits mailing list