[Mlir-commits] [mlir] [mlir] Decouple enum generation from attributes, adding EnumInfo and EnumCase (PR #132148)

Krzysztof Drewniak llvmlistbot at llvm.org
Sun Mar 23 18:33:07 PDT 2025


https://github.com/krzysz00 updated https://github.com/llvm/llvm-project/pull/132148

>From d130635269f41fa489fb51ea1aae8d2067c4ccca Mon Sep 17 00:00:00 2001
From: Krzysztof Drewniak <krzysdrewniak at gmail.com>
Date: Sun, 23 Mar 2025 17:01:13 -0700
Subject: [PATCH 1/2] [mlir][NFC] Move and rename EnumAttrCase, EnumAttr C++
 classes

This moves the EnumAttrCase and EnumAttr classes from Attribute.h/.cpp
to a new EnumInfo.h/cpp and renames them to EnumCase and EnumInfo,
respectively.

This doesn't change any of the tablegen files or any user-facing
aspects of the enum attribute generation system, just reorganizes code
in order to make the main PR (#132148) shorter.
---
 mlir/include/mlir/TableGen/Attribute.h        |  74 --------
 mlir/include/mlir/TableGen/EnumInfo.h         | 135 ++++++++++++++
 mlir/include/mlir/TableGen/Pattern.h          |  11 +-
 mlir/lib/TableGen/Attribute.cpp               |  94 ----------
 mlir/lib/TableGen/CMakeLists.txt              |   1 +
 mlir/lib/TableGen/EnumInfo.cpp                | 130 ++++++++++++++
 mlir/lib/TableGen/Pattern.cpp                 |  12 +-
 .../mlir-tblgen/EnumPythonBindingGen.cpp      |  47 ++---
 mlir/tools/mlir-tblgen/EnumsGen.cpp           | 166 +++++++++---------
 .../tools/mlir-tblgen/LLVMIRConversionGen.cpp |  75 ++++----
 mlir/tools/mlir-tblgen/OpDocGen.cpp           |  17 +-
 mlir/tools/mlir-tblgen/OpFormatGen.cpp        |  35 ++--
 mlir/tools/mlir-tblgen/RewriterGen.cpp        |   8 +-
 mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp      |  91 +++++-----
 mlir/tools/mlir-tblgen/TosaUtilsGen.cpp       |   5 +-
 15 files changed, 506 insertions(+), 395 deletions(-)
 create mode 100644 mlir/include/mlir/TableGen/EnumInfo.h
 create mode 100644 mlir/lib/TableGen/EnumInfo.cpp

diff --git a/mlir/include/mlir/TableGen/Attribute.h b/mlir/include/mlir/TableGen/Attribute.h
index 62720e74849fc..dee81880bacab 100644
--- a/mlir/include/mlir/TableGen/Attribute.h
+++ b/mlir/include/mlir/TableGen/Attribute.h
@@ -16,7 +16,6 @@
 
 #include "mlir/Support/LLVM.h"
 #include "mlir/TableGen/Constraint.h"
-#include "llvm/ADT/StringRef.h"
 
 namespace llvm {
 class DefInit;
@@ -136,79 +135,6 @@ class ConstantAttr {
   const llvm::Record *def;
 };
 
-// Wrapper class providing helper methods for accessing enum attribute cases
-// defined in TableGen. This is used for enum attribute case backed by both
-// StringAttr and IntegerAttr.
-class EnumAttrCase : public Attribute {
-public:
-  explicit EnumAttrCase(const llvm::Record *record);
-  explicit EnumAttrCase(const llvm::DefInit *init);
-
-  // Returns the symbol of this enum attribute case.
-  StringRef getSymbol() const;
-
-  // Returns the textual representation of this enum attribute case.
-  StringRef getStr() const;
-
-  // Returns the value of this enum attribute case.
-  int64_t getValue() const;
-
-  // Returns the TableGen definition this EnumAttrCase was constructed from.
-  const llvm::Record &getDef() const;
-};
-
-// Wrapper class providing helper methods for accessing enum attributes defined
-// in TableGen.This is used for enum attribute case backed by both StringAttr
-// and IntegerAttr.
-class EnumAttr : public Attribute {
-public:
-  explicit EnumAttr(const llvm::Record *record);
-  explicit EnumAttr(const llvm::Record &record);
-  explicit EnumAttr(const llvm::DefInit *init);
-
-  static bool classof(const Attribute *attr);
-
-  // Returns true if this is a bit enum attribute.
-  bool isBitEnum() const;
-
-  // Returns the enum class name.
-  StringRef getEnumClassName() const;
-
-  // Returns the C++ namespaces this enum class should be placed in.
-  StringRef getCppNamespace() const;
-
-  // Returns the underlying type.
-  StringRef getUnderlyingType() const;
-
-  // Returns the name of the utility function that converts a value of the
-  // underlying type to the corresponding symbol.
-  StringRef getUnderlyingToSymbolFnName() const;
-
-  // Returns the name of the utility function that converts a string to the
-  // corresponding symbol.
-  StringRef getStringToSymbolFnName() const;
-
-  // Returns the name of the utility function that converts a symbol to the
-  // corresponding string.
-  StringRef getSymbolToStringFnName() const;
-
-  // Returns the return type of the utility function that converts a symbol to
-  // the corresponding string.
-  StringRef getSymbolToStringFnRetType() const;
-
-  // Returns the name of the utilit function that returns the max enum value
-  // used within the enum class.
-  StringRef getMaxEnumValFnName() const;
-
-  // Returns all allowed cases for this enum attribute.
-  std::vector<EnumAttrCase> getAllCases() const;
-
-  bool genSpecializedAttr() const;
-  const llvm::Record *getBaseAttrClass() const;
-  StringRef getSpecializedAttrClassName() const;
-  bool printBitEnumPrimaryGroups() const;
-};
-
 // Name of infer type op interface.
 extern const char *inferTypeOpInterface;
 
diff --git a/mlir/include/mlir/TableGen/EnumInfo.h b/mlir/include/mlir/TableGen/EnumInfo.h
new file mode 100644
index 0000000000000..196267864f325
--- /dev/null
+++ b/mlir/include/mlir/TableGen/EnumInfo.h
@@ -0,0 +1,135 @@
+//===- EnumInfo.h - EnumInfo wrapper class --------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// EnumInfo wrapper to simplify using a TableGen Record defining an Enum
+// via EnumInfo and its `EnumCase`s.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TABLEGEN_ENUMINFO_H_
+#define MLIR_TABLEGEN_ENUMINFO_H_
+
+#include "mlir/Support/LLVM.h"
+#include "mlir/TableGen/Attribute.h"
+#include "llvm/ADT/StringRef.h"
+
+namespace llvm {
+class DefInit;
+class Record;
+} // namespace llvm
+
+namespace mlir {
+namespace tblgen {
+
+// Wrapper class around enum cases defined in TableGen.
+class EnumCase {
+public:
+  explicit EnumCase(const llvm::Record *record);
+  explicit EnumCase(const llvm::DefInit *init);
+
+  // Returns the symbol of this enum attribute case.
+  StringRef getSymbol() const;
+
+  // Returns the textual representation of this enum attribute case.
+  StringRef getStr() const;
+
+  // Returns the value of this enum attribute case.
+  int64_t getValue() const;
+
+  // Returns the TableGen definition this EnumCase was constructed from.
+  const llvm::Record &getDef() const;
+
+protected:
+  // The TableGen definition of this enum case.
+  const llvm::Record *def;
+};
+
+// Wrapper class providing helper methods for accessing enums defined
+// in TableGen using EnumInfo. Some methods are only applicable when
+// the enum is also an attribute, or only when it is a bit enum.
+class EnumInfo {
+public:
+  explicit EnumInfo(const llvm::Record *record);
+  explicit EnumInfo(const llvm::Record &record);
+  explicit EnumInfo(const llvm::DefInit *init);
+
+  // Returns true if the given EnumInfo is a subclass of the named TableGen
+  // class.
+  bool isSubClassOf(StringRef className) const;
+
+  // Returns true if this enum is an EnumAttrInfo, thus making it define an
+  // attribute.
+  bool isEnumAttr() const;
+
+  // Create the `Attribute` wrapper around this EnumInfo if it is defining an
+  // attribute.
+  std::optional<Attribute> asEnumAttr() const;
+
+  // Returns true if this is a bit enum.
+  bool isBitEnum() const;
+
+  // Returns the enum class name.
+  StringRef getEnumClassName() const;
+
+  // Returns the C++ namespaces this enum class should be placed in.
+  StringRef getCppNamespace() const;
+
+  // Returns the summary of the enum.
+  StringRef getSummary() const;
+
+  // Returns the description of the enum.
+  StringRef getDescription() const;
+
+  // Returns the underlying type.
+  StringRef getUnderlyingType() const;
+
+  // Returns the name of the utility function that converts a value of the
+  // underlying type to the corresponding symbol.
+  StringRef getUnderlyingToSymbolFnName() const;
+
+  // Returns the name of the utility function that converts a string to the
+  // corresponding symbol.
+  StringRef getStringToSymbolFnName() const;
+
+  // Returns the name of the utility function that converts a symbol to the
+  // corresponding string.
+  StringRef getSymbolToStringFnName() const;
+
+  // Returns the return type of the utility function that converts a symbol to
+  // the corresponding string.
+  StringRef getSymbolToStringFnRetType() const;
+
+  // Returns the name of the utility function that returns the max enum value
+  // used within the enum class.
+  StringRef getMaxEnumValFnName() const;
+
+  // Returns all allowed cases for this enum.
+  std::vector<EnumCase> getAllCases() const;
+
+  // Only applicable for enum attributes.
+
+  bool genSpecializedAttr() const;
+  const llvm::Record *getBaseAttrClass() const;
+  StringRef getSpecializedAttrClassName() const;
+
+  // Only applicable for bit enums.
+
+  bool printBitEnumPrimaryGroups() const;
+
+  // Returns the TableGen definition this EnumInfo was constructed from.
+  const llvm::Record &getDef() const;
+
+protected:
+  // The TableGen definition of this enum.
+  const llvm::Record *def;
+};
+
+} // namespace tblgen
+} // namespace mlir
+
+#endif
diff --git a/mlir/include/mlir/TableGen/Pattern.h b/mlir/include/mlir/TableGen/Pattern.h
index 80f38fdeffee0..1c9e128f0a0fb 100644
--- a/mlir/include/mlir/TableGen/Pattern.h
+++ b/mlir/include/mlir/TableGen/Pattern.h
@@ -16,6 +16,7 @@
 
 #include "mlir/Support/LLVM.h"
 #include "mlir/TableGen/Argument.h"
+#include "mlir/TableGen/EnumInfo.h"
 #include "mlir/TableGen/Operator.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/Hashing.h"
@@ -78,8 +79,8 @@ class DagLeaf {
   // Returns true if this DAG leaf is specifying a constant attribute.
   bool isConstantAttr() const;
 
-  // Returns true if this DAG leaf is specifying an enum attribute case.
-  bool isEnumAttrCase() const;
+  // Returns true if this DAG leaf is specifying an enum case.
+  bool isEnumCase() const;
 
   // Returns true if this DAG leaf is specifying a string attribute.
   bool isStringAttr() const;
@@ -90,9 +91,9 @@ class DagLeaf {
   // Returns this DAG leaf as an constant attribute. Asserts if fails.
   ConstantAttr getAsConstantAttr() const;
 
-  // Returns this DAG leaf as an enum attribute case.
-  // Precondition: isEnumAttrCase()
-  EnumAttrCase getAsEnumAttrCase() const;
+  // Returns this DAG leaf as an enum case.
+  // Precondition: isEnumCase()
+  EnumCase getAsEnumCase() const;
 
   // Returns the matching condition template inside this DAG leaf. Assumes the
   // leaf is an operand/attribute matcher and asserts otherwise.
diff --git a/mlir/lib/TableGen/Attribute.cpp b/mlir/lib/TableGen/Attribute.cpp
index f9fc58a40f334..142d194260942 100644
--- a/mlir/lib/TableGen/Attribute.cpp
+++ b/mlir/lib/TableGen/Attribute.cpp
@@ -146,98 +146,4 @@ StringRef ConstantAttr::getConstantValue() const {
   return def->getValueAsString("value");
 }
 
-EnumAttrCase::EnumAttrCase(const Record *record) : Attribute(record) {
-  assert(isSubClassOf("EnumAttrCaseInfo") &&
-         "must be subclass of TableGen 'EnumAttrInfo' class");
-}
-
-EnumAttrCase::EnumAttrCase(const DefInit *init)
-    : EnumAttrCase(init->getDef()) {}
-
-StringRef EnumAttrCase::getSymbol() const {
-  return def->getValueAsString("symbol");
-}
-
-StringRef EnumAttrCase::getStr() const { return def->getValueAsString("str"); }
-
-int64_t EnumAttrCase::getValue() const { return def->getValueAsInt("value"); }
-
-const Record &EnumAttrCase::getDef() const { return *def; }
-
-EnumAttr::EnumAttr(const Record *record) : Attribute(record) {
-  assert(isSubClassOf("EnumAttrInfo") &&
-         "must be subclass of TableGen 'EnumAttr' class");
-}
-
-EnumAttr::EnumAttr(const Record &record) : Attribute(&record) {}
-
-EnumAttr::EnumAttr(const DefInit *init) : EnumAttr(init->getDef()) {}
-
-bool EnumAttr::classof(const Attribute *attr) {
-  return attr->isSubClassOf("EnumAttrInfo");
-}
-
-bool EnumAttr::isBitEnum() const { return isSubClassOf("BitEnumAttr"); }
-
-StringRef EnumAttr::getEnumClassName() const {
-  return def->getValueAsString("className");
-}
-
-StringRef EnumAttr::getCppNamespace() const {
-  return def->getValueAsString("cppNamespace");
-}
-
-StringRef EnumAttr::getUnderlyingType() const {
-  return def->getValueAsString("underlyingType");
-}
-
-StringRef EnumAttr::getUnderlyingToSymbolFnName() const {
-  return def->getValueAsString("underlyingToSymbolFnName");
-}
-
-StringRef EnumAttr::getStringToSymbolFnName() const {
-  return def->getValueAsString("stringToSymbolFnName");
-}
-
-StringRef EnumAttr::getSymbolToStringFnName() const {
-  return def->getValueAsString("symbolToStringFnName");
-}
-
-StringRef EnumAttr::getSymbolToStringFnRetType() const {
-  return def->getValueAsString("symbolToStringFnRetType");
-}
-
-StringRef EnumAttr::getMaxEnumValFnName() const {
-  return def->getValueAsString("maxEnumValFnName");
-}
-
-std::vector<EnumAttrCase> EnumAttr::getAllCases() const {
-  const auto *inits = def->getValueAsListInit("enumerants");
-
-  std::vector<EnumAttrCase> cases;
-  cases.reserve(inits->size());
-
-  for (const Init *init : *inits) {
-    cases.emplace_back(cast<DefInit>(init));
-  }
-
-  return cases;
-}
-
-bool EnumAttr::genSpecializedAttr() const {
-  return def->getValueAsBit("genSpecializedAttr");
-}
-
-const Record *EnumAttr::getBaseAttrClass() const {
-  return def->getValueAsDef("baseAttrClass");
-}
-
-StringRef EnumAttr::getSpecializedAttrClassName() const {
-  return def->getValueAsString("specializedAttrClassName");
-}
-
-bool EnumAttr::printBitEnumPrimaryGroups() const {
-  return def->getValueAsBit("printBitEnumPrimaryGroups");
-}
-
 const char * ::mlir::tblgen::inferTypeOpInterface = "InferTypeOpInterface";
diff --git a/mlir/lib/TableGen/CMakeLists.txt b/mlir/lib/TableGen/CMakeLists.txt
index c4104e644147c..a90c55847718e 100644
--- a/mlir/lib/TableGen/CMakeLists.txt
+++ b/mlir/lib/TableGen/CMakeLists.txt
@@ -20,6 +20,7 @@ llvm_add_library(MLIRTableGen STATIC
   CodeGenHelpers.cpp
   Constraint.cpp
   Dialect.cpp
+  EnumInfo.cpp
   Format.cpp
   GenInfo.cpp
   Interfaces.cpp
diff --git a/mlir/lib/TableGen/EnumInfo.cpp b/mlir/lib/TableGen/EnumInfo.cpp
new file mode 100644
index 0000000000000..9f491d30f0e7f
--- /dev/null
+++ b/mlir/lib/TableGen/EnumInfo.cpp
@@ -0,0 +1,130 @@
+//===- EnumInfo.cpp - EnumInfo wrapper class ----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/TableGen/EnumInfo.h"
+#include "mlir/TableGen/Attribute.h"
+#include "llvm/TableGen/Record.h"
+
+using namespace mlir;
+using namespace mlir::tblgen;
+
+using llvm::DefInit;
+using llvm::Init;
+using llvm::Record;
+
+EnumCase::EnumCase(const Record *record) : def(record) {
+  assert(def->isSubClassOf("EnumAttrCaseInfo") &&
+         "must be subclass of TableGen 'EnumAttrCaseInfo' class");
+}
+
+EnumCase::EnumCase(const DefInit *init) : EnumCase(init->getDef()) {}
+
+StringRef EnumCase::getSymbol() const {
+  return def->getValueAsString("symbol");
+}
+
+StringRef EnumCase::getStr() const { return def->getValueAsString("str"); }
+
+int64_t EnumCase::getValue() const { return def->getValueAsInt("value"); }
+
+const Record &EnumCase::getDef() const { return *def; }
+
+EnumInfo::EnumInfo(const Record *record) : def(record) {
+  assert(isSubClassOf("EnumAttrInfo") &&
+         "must be subclass of TableGen 'EnumAttrInfo' class");
+}
+
+EnumInfo::EnumInfo(const Record &record) : EnumInfo(&record) {}
+
+EnumInfo::EnumInfo(const DefInit *init) : EnumInfo(init->getDef()) {}
+
+bool EnumInfo::isSubClassOf(StringRef className) const {
+  return def->isSubClassOf(className);
+}
+
+bool EnumInfo::isEnumAttr() const { return isSubClassOf("EnumAttrInfo"); }
+
+std::optional<Attribute> EnumInfo::asEnumAttr() const {
+  if (isEnumAttr())
+    return Attribute(def);
+  return std::nullopt;
+}
+
+bool EnumInfo::isBitEnum() const { return isSubClassOf("BitEnumAttr"); }
+
+StringRef EnumInfo::getEnumClassName() const {
+  return def->getValueAsString("className");
+}
+
+StringRef EnumInfo::getSummary() const {
+  return def->getValueAsString("summary");
+}
+
+StringRef EnumInfo::getDescription() const {
+  return def->getValueAsString("description");
+}
+
+StringRef EnumInfo::getCppNamespace() const {
+  return def->getValueAsString("cppNamespace");
+}
+
+StringRef EnumInfo::getUnderlyingType() const {
+  return def->getValueAsString("underlyingType");
+}
+
+StringRef EnumInfo::getUnderlyingToSymbolFnName() const {
+  return def->getValueAsString("underlyingToSymbolFnName");
+}
+
+StringRef EnumInfo::getStringToSymbolFnName() const {
+  return def->getValueAsString("stringToSymbolFnName");
+}
+
+StringRef EnumInfo::getSymbolToStringFnName() const {
+  return def->getValueAsString("symbolToStringFnName");
+}
+
+StringRef EnumInfo::getSymbolToStringFnRetType() const {
+  return def->getValueAsString("symbolToStringFnRetType");
+}
+
+StringRef EnumInfo::getMaxEnumValFnName() const {
+  return def->getValueAsString("maxEnumValFnName");
+}
+
+std::vector<EnumCase> EnumInfo::getAllCases() const {
+  const auto *inits = def->getValueAsListInit("enumerants");
+
+  std::vector<EnumCase> cases;
+  cases.reserve(inits->size());
+
+  for (const Init *init : *inits) {
+    cases.emplace_back(cast<DefInit>(init));
+  }
+
+  return cases;
+}
+
+bool EnumInfo::genSpecializedAttr() const {
+  return isSubClassOf("EnumAttrInfo") &&
+         def->getValueAsBit("genSpecializedAttr");
+}
+
+const Record *EnumInfo::getBaseAttrClass() const {
+  return def->getValueAsDef("baseAttrClass");
+}
+
+StringRef EnumInfo::getSpecializedAttrClassName() const {
+  return def->getValueAsString("specializedAttrClassName");
+}
+
+bool EnumInfo::printBitEnumPrimaryGroups() const {
+  return def->getValueAsBit("printBitEnumPrimaryGroups");
+}
+
+const Record &EnumInfo::getDef() const { return *def; }
diff --git a/mlir/lib/TableGen/Pattern.cpp b/mlir/lib/TableGen/Pattern.cpp
index ac8c49c72d384..73e2803c21dae 100644
--- a/mlir/lib/TableGen/Pattern.cpp
+++ b/mlir/lib/TableGen/Pattern.cpp
@@ -57,9 +57,7 @@ bool DagLeaf::isNativeCodeCall() const {
 
 bool DagLeaf::isConstantAttr() const { return isSubClassOf("ConstantAttr"); }
 
-bool DagLeaf::isEnumAttrCase() const {
-  return isSubClassOf("EnumAttrCaseInfo");
-}
+bool DagLeaf::isEnumCase() const { return isSubClassOf("EnumAttrCaseInfo"); }
 
 bool DagLeaf::isStringAttr() const { return isa<llvm::StringInit>(def); }
 
@@ -74,9 +72,9 @@ ConstantAttr DagLeaf::getAsConstantAttr() const {
   return ConstantAttr(cast<DefInit>(def));
 }
 
-EnumAttrCase DagLeaf::getAsEnumAttrCase() const {
-  assert(isEnumAttrCase() && "the DAG leaf must be an enum attribute case");
-  return EnumAttrCase(cast<DefInit>(def));
+EnumCase DagLeaf::getAsEnumCase() const {
+  assert(isEnumCase() && "the DAG leaf must be an enum attribute case");
+  return EnumCase(cast<DefInit>(def));
 }
 
 std::string DagLeaf::getConditionTemplate() const {
@@ -776,7 +774,7 @@ void Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
           verifyBind(infoMap.bindValue(treeArgName), treeArgName);
         } else {
           auto constraint = leaf.getAsConstraint();
-          bool isAttr = leaf.isAttrMatcher() || leaf.isEnumAttrCase() ||
+          bool isAttr = leaf.isAttrMatcher() || leaf.isEnumCase() ||
                         leaf.isConstantAttr() ||
                         constraint.getKind() == Constraint::Kind::CK_Attr;
 
diff --git a/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp
index 3f660ae151c74..5d4d9e90fff67 100644
--- a/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp
@@ -15,6 +15,7 @@
 #include "mlir/TableGen/AttrOrTypeDef.h"
 #include "mlir/TableGen/Attribute.h"
 #include "mlir/TableGen/Dialect.h"
+#include "mlir/TableGen/EnumInfo.h"
 #include "mlir/TableGen/GenInfo.h"
 #include "llvm/Support/FormatVariadic.h"
 #include "llvm/TableGen/Record.h"
@@ -44,14 +45,14 @@ static std::string makePythonEnumCaseName(StringRef name) {
 }
 
 /// Emits the Python class for the given enum.
-static void emitEnumClass(EnumAttr enumAttr, raw_ostream &os) {
-  os << formatv("class {0}({1}):\n", enumAttr.getEnumClassName(),
-                enumAttr.isBitEnum() ? "IntFlag" : "IntEnum");
-  if (!enumAttr.getSummary().empty())
-    os << formatv("    \"\"\"{0}\"\"\"\n", enumAttr.getSummary());
+static void emitEnumClass(EnumInfo enumInfo, raw_ostream &os) {
+  os << formatv("class {0}({1}):\n", enumInfo.getEnumClassName(),
+                enumInfo.isBitEnum() ? "IntFlag" : "IntEnum");
+  if (!enumInfo.getSummary().empty())
+    os << formatv("    \"\"\"{0}\"\"\"\n", enumInfo.getSummary());
   os << "\n";
 
-  for (const EnumAttrCase &enumCase : enumAttr.getAllCases()) {
+  for (const EnumCase &enumCase : enumInfo.getAllCases()) {
     os << formatv("    {0} = {1}\n",
                   makePythonEnumCaseName(enumCase.getSymbol()),
                   enumCase.getValue() >= 0 ? std::to_string(enumCase.getValue())
@@ -60,7 +61,7 @@ static void emitEnumClass(EnumAttr enumAttr, raw_ostream &os) {
 
   os << "\n";
 
-  if (enumAttr.isBitEnum()) {
+  if (enumInfo.isBitEnum()) {
     os << formatv("    def __iter__(self):\n"
                   "        return iter([case for case in type(self) if "
                   "(self & case) is case])\n");
@@ -70,17 +71,17 @@ static void emitEnumClass(EnumAttr enumAttr, raw_ostream &os) {
   }
 
   os << formatv("    def __str__(self):\n");
-  if (enumAttr.isBitEnum())
+  if (enumInfo.isBitEnum())
     os << formatv("        if len(self) > 1:\n"
                   "            return \"{0}\".join(map(str, self))\n",
-                  enumAttr.getDef().getValueAsString("separator"));
-  for (const EnumAttrCase &enumCase : enumAttr.getAllCases()) {
-    os << formatv("        if self is {0}.{1}:\n", enumAttr.getEnumClassName(),
+                  enumInfo.getDef().getValueAsString("separator"));
+  for (const EnumCase &enumCase : enumInfo.getAllCases()) {
+    os << formatv("        if self is {0}.{1}:\n", enumInfo.getEnumClassName(),
                   makePythonEnumCaseName(enumCase.getSymbol()));
     os << formatv("            return \"{0}\"\n", enumCase.getStr());
   }
   os << formatv("        raise ValueError(\"Unknown {0} enum entry.\")\n\n\n",
-                enumAttr.getEnumClassName());
+                enumInfo.getEnumClassName());
   os << "\n";
 }
 
@@ -98,17 +99,21 @@ static bool extractUIntBitwidth(StringRef uintType, int64_t &bitwidth) {
 /// Emits an attribute builder for the given enum attribute to support automatic
 /// conversion between enum values and attributes in Python. Returns
 /// `false` on success, `true` on failure.
-static bool emitAttributeBuilder(const EnumAttr &enumAttr, raw_ostream &os) {
+static bool emitAttributeBuilder(const EnumInfo &enumInfo, raw_ostream &os) {
+  std::optional<Attribute> enumAttrInfo = enumInfo.asEnumAttr();
+  if (!enumAttrInfo)
+    return false;
+
   int64_t bitwidth;
-  if (extractUIntBitwidth(enumAttr.getUnderlyingType(), bitwidth)) {
+  if (extractUIntBitwidth(enumInfo.getUnderlyingType(), bitwidth)) {
     llvm::errs() << "failed to identify bitwidth of "
-                 << enumAttr.getUnderlyingType();
+                 << enumInfo.getUnderlyingType();
     return true;
   }
-
   os << formatv("@register_attribute_builder(\"{0}\")\n",
-                enumAttr.getAttrDefName());
-  os << formatv("def _{0}(x, context):\n", enumAttr.getAttrDefName().lower());
+                enumAttrInfo->getAttrDefName());
+  os << formatv("def _{0}(x, context):\n",
+                enumAttrInfo->getAttrDefName().lower());
   os << formatv("    return "
                 "_ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless({0}, "
                 "context=context), int(x))\n\n",
@@ -136,9 +141,9 @@ static bool emitPythonEnums(const RecordKeeper &records, raw_ostream &os) {
   os << fileHeader;
   for (const Record *it :
        records.getAllDerivedDefinitionsIfDefined("EnumAttrInfo")) {
-    EnumAttr enumAttr(*it);
-    emitEnumClass(enumAttr, os);
-    emitAttributeBuilder(enumAttr, os);
+    EnumInfo enumInfo(*it);
+    emitEnumClass(enumInfo, os);
+    emitAttributeBuilder(enumInfo, os);
   }
   for (const Record *it :
        records.getAllDerivedDefinitionsIfDefined("EnumAttr")) {
diff --git a/mlir/tools/mlir-tblgen/EnumsGen.cpp b/mlir/tools/mlir-tblgen/EnumsGen.cpp
index d11aa9b27c2d8..fa6fad156b747 100644
--- a/mlir/tools/mlir-tblgen/EnumsGen.cpp
+++ b/mlir/tools/mlir-tblgen/EnumsGen.cpp
@@ -12,6 +12,7 @@
 
 #include "FormatGen.h"
 #include "mlir/TableGen/Attribute.h"
+#include "mlir/TableGen/EnumInfo.h"
 #include "mlir/TableGen/Format.h"
 #include "mlir/TableGen/GenInfo.h"
 #include "llvm/ADT/BitVector.h"
@@ -30,8 +31,8 @@ using llvm::Record;
 using llvm::RecordKeeper;
 using namespace mlir;
 using mlir::tblgen::Attribute;
-using mlir::tblgen::EnumAttr;
-using mlir::tblgen::EnumAttrCase;
+using mlir::tblgen::EnumCase;
+using mlir::tblgen::EnumInfo;
 using mlir::tblgen::FmtContext;
 using mlir::tblgen::tgfmt;
 
@@ -45,7 +46,7 @@ static std::string makeIdentifier(StringRef str) {
 
 static void emitEnumClass(const Record &enumDef, StringRef enumName,
                           StringRef underlyingType, StringRef description,
-                          const std::vector<EnumAttrCase> &enumerants,
+                          const std::vector<EnumCase> &enumerants,
                           raw_ostream &os) {
   os << "// " << description << "\n";
   os << "enum class " << enumName;
@@ -66,12 +67,13 @@ static void emitEnumClass(const Record &enumDef, StringRef enumName,
   os << "};\n\n";
 }
 
-static void emitParserPrinter(const EnumAttr &enumAttr, StringRef qualName,
+static void emitParserPrinter(const EnumInfo &enumInfo, StringRef qualName,
                               StringRef cppNamespace, raw_ostream &os) {
-  if (enumAttr.getUnderlyingType().empty() ||
-      enumAttr.getConstBuilderTemplate().empty())
+  std::optional<Attribute> enumAttrInfo = enumInfo.asEnumAttr();
+  if (enumInfo.getUnderlyingType().empty() ||
+      (enumAttrInfo && enumAttrInfo->getConstBuilderTemplate().empty()))
     return;
-  auto cases = enumAttr.getAllCases();
+  auto cases = enumInfo.getAllCases();
 
   // Check which cases shouldn't be printed using a keyword.
   llvm::BitVector nonKeywordCases(cases.size());
@@ -128,8 +130,9 @@ namespace llvm {
 inline ::llvm::raw_ostream &operator<<(::llvm::raw_ostream &p, {0} value) {{
   auto valueStr = stringifyEnum(value);
 )";
+
   os << formatv(parsedAndPrinterStart, qualName, cppNamespace,
-                enumAttr.getSummary());
+                enumInfo.getSummary());
 
   // If all cases require a string, always wrap.
   if (nonKeywordCases.all()) {
@@ -157,9 +160,9 @@ inline ::llvm::raw_ostream &operator<<(::llvm::raw_ostream &p, {0} value) {{
 
     // If this is a bit enum, conservatively print the string form if the value
     // is not a power of two (i.e. not a single bit case) and not a known case.
-  } else if (enumAttr.isBitEnum()) {
+  } else if (enumInfo.isBitEnum()) {
     // Process the known multi-bit cases that use valid keywords.
-    SmallVector<EnumAttrCase *> validMultiBitCases;
+    SmallVector<EnumCase *> validMultiBitCases;
     for (auto [index, caseVal] : llvm::enumerate(cases)) {
       uint64_t value = caseVal.getValue();
       if (value && !llvm::has_single_bit(value) && !nonKeywordCases.test(index))
@@ -167,7 +170,7 @@ inline ::llvm::raw_ostream &operator<<(::llvm::raw_ostream &p, {0} value) {{
     }
     if (!validMultiBitCases.empty()) {
       os << "  switch (value) {\n";
-      for (EnumAttrCase *caseVal : validMultiBitCases) {
+      for (EnumCase *caseVal : validMultiBitCases) {
         StringRef symbol = caseVal->getSymbol();
         os << llvm::formatv("  case {0}::{1}:\n", qualName,
                             llvm::isDigit(symbol.front()) ? ("_" + symbol)
@@ -224,9 +227,9 @@ template<> struct DenseMapInfo<{0}> {{
 }
 
 static void emitMaxValueFn(const Record &enumDef, raw_ostream &os) {
-  EnumAttr enumAttr(enumDef);
-  StringRef maxEnumValFnName = enumAttr.getMaxEnumValFnName();
-  auto enumerants = enumAttr.getAllCases();
+  EnumInfo enumInfo(enumDef);
+  StringRef maxEnumValFnName = enumInfo.getMaxEnumValFnName();
+  auto enumerants = enumInfo.getAllCases();
 
   unsigned maxEnumVal = 0;
   for (const auto &enumerant : enumerants) {
@@ -245,10 +248,10 @@ static void emitMaxValueFn(const Record &enumDef, raw_ostream &os) {
   os << "}\n\n";
 }
 
-// Returns the EnumAttrCase whose value is zero if exists; returns std::nullopt
+// Returns the EnumCase whose value is zero if exists; returns std::nullopt
 // otherwise.
-static std::optional<EnumAttrCase>
-getAllBitsUnsetCase(llvm::ArrayRef<EnumAttrCase> cases) {
+static std::optional<EnumCase>
+getAllBitsUnsetCase(llvm::ArrayRef<EnumCase> cases) {
   for (auto attrCase : cases) {
     if (attrCase.getValue() == 0)
       return attrCase;
@@ -268,9 +271,9 @@ getAllBitsUnsetCase(llvm::ArrayRef<EnumAttrCase> cases) {
 // inline constexpr <enum-type> bitEnumSet(<enum-type> bits, <enum-type> bit,
 // bool value=true);
 static void emitOperators(const Record &enumDef, raw_ostream &os) {
-  EnumAttr enumAttr(enumDef);
-  StringRef enumName = enumAttr.getEnumClassName();
-  std::string underlyingType = std::string(enumAttr.getUnderlyingType());
+  EnumInfo enumInfo(enumDef);
+  StringRef enumName = enumInfo.getEnumClassName();
+  std::string underlyingType = std::string(enumInfo.getUnderlyingType());
   int64_t validBits = enumDef.getValueAsInt("validBits");
   const char *const operators = R"(
 inline constexpr {0} operator|({0} a, {0} b) {{
@@ -303,11 +306,11 @@ inline constexpr {0} bitEnumSet({0} bits, {0} bit, /*optional*/bool value=true)
 }
 
 static void emitSymToStrFnForIntEnum(const Record &enumDef, raw_ostream &os) {
-  EnumAttr enumAttr(enumDef);
-  StringRef enumName = enumAttr.getEnumClassName();
-  StringRef symToStrFnName = enumAttr.getSymbolToStringFnName();
-  StringRef symToStrFnRetType = enumAttr.getSymbolToStringFnRetType();
-  auto enumerants = enumAttr.getAllCases();
+  EnumInfo enumInfo(enumDef);
+  StringRef enumName = enumInfo.getEnumClassName();
+  StringRef symToStrFnName = enumInfo.getSymbolToStringFnName();
+  StringRef symToStrFnRetType = enumInfo.getSymbolToStringFnRetType();
+  auto enumerants = enumInfo.getAllCases();
 
   os << formatv("{2} {1}({0} val) {{\n", enumName, symToStrFnName,
                 symToStrFnRetType);
@@ -324,19 +327,19 @@ static void emitSymToStrFnForIntEnum(const Record &enumDef, raw_ostream &os) {
 }
 
 static void emitSymToStrFnForBitEnum(const Record &enumDef, raw_ostream &os) {
-  EnumAttr enumAttr(enumDef);
-  StringRef enumName = enumAttr.getEnumClassName();
-  StringRef symToStrFnName = enumAttr.getSymbolToStringFnName();
-  StringRef symToStrFnRetType = enumAttr.getSymbolToStringFnRetType();
+  EnumInfo enumInfo(enumDef);
+  StringRef enumName = enumInfo.getEnumClassName();
+  StringRef symToStrFnName = enumInfo.getSymbolToStringFnName();
+  StringRef symToStrFnRetType = enumInfo.getSymbolToStringFnRetType();
   StringRef separator = enumDef.getValueAsString("separator");
-  auto enumerants = enumAttr.getAllCases();
+  auto enumerants = enumInfo.getAllCases();
   auto allBitsUnsetCase = getAllBitsUnsetCase(enumerants);
 
   os << formatv("{2} {1}({0} symbol) {{\n", enumName, symToStrFnName,
                 symToStrFnRetType);
 
   os << formatv("  auto val = static_cast<{0}>(symbol);\n",
-                enumAttr.getUnderlyingType());
+                enumInfo.getUnderlyingType());
   // If we have unknown bit set, return an empty string to signal errors.
   int64_t validBits = enumDef.getValueAsInt("validBits");
   os << formatv("  assert({0}u == ({0}u | val) && \"invalid bits set in bit "
@@ -365,21 +368,23 @@ static void emitSymToStrFnForBitEnum(const Record &enumDef, raw_ostream &os) {
 )";
   // Optionally elide bits that are members of groups that will also be printed
   // for more concise output.
-  if (enumAttr.printBitEnumPrimaryGroups()) {
+  if (enumInfo.printBitEnumPrimaryGroups()) {
     os << "  // Print bit enum groups before individual bits\n";
     // Emit comparisons for group bit cases in reverse tablegen declaration
     // order, removing bits for groups with all bits present.
     for (const auto &enumerant : llvm::reverse(enumerants)) {
       if ((enumerant.getValue() != 0) &&
-          enumerant.getDef().isSubClassOf("BitEnumAttrCaseGroup")) {
+          (enumerant.getDef().isSubClassOf("BitEnumCaseGroup") ||
+           enumerant.getDef().isSubClassOf("BitEnumAttrCaseGroup"))) {
         os << formatv(formatCompareRemove, enumerant.getValue(),
-                      enumerant.getStr(), enumAttr.getUnderlyingType());
+                      enumerant.getStr(), enumInfo.getUnderlyingType());
       }
     }
     // Emit comparisons for individual bit cases in tablegen declaration order.
     for (const auto &enumerant : enumerants) {
       if ((enumerant.getValue() != 0) &&
-          enumerant.getDef().isSubClassOf("BitEnumAttrCaseBit"))
+          (enumerant.getDef().isSubClassOf("BitEnumCaseBit") ||
+           enumerant.getDef().isSubClassOf("BitEnumAttrCaseBit")))
         os << formatv(formatCompare, enumerant.getValue(), enumerant.getStr());
     }
   } else {
@@ -396,10 +401,10 @@ static void emitSymToStrFnForBitEnum(const Record &enumDef, raw_ostream &os) {
 }
 
 static void emitStrToSymFnForIntEnum(const Record &enumDef, raw_ostream &os) {
-  EnumAttr enumAttr(enumDef);
-  StringRef enumName = enumAttr.getEnumClassName();
-  StringRef strToSymFnName = enumAttr.getStringToSymbolFnName();
-  auto enumerants = enumAttr.getAllCases();
+  EnumInfo enumInfo(enumDef);
+  StringRef enumName = enumInfo.getEnumClassName();
+  StringRef strToSymFnName = enumInfo.getStringToSymbolFnName();
+  auto enumerants = enumInfo.getAllCases();
 
   os << formatv("::std::optional<{0}> {1}(::llvm::StringRef str) {{\n",
                 enumName, strToSymFnName);
@@ -416,13 +421,13 @@ static void emitStrToSymFnForIntEnum(const Record &enumDef, raw_ostream &os) {
 }
 
 static void emitStrToSymFnForBitEnum(const Record &enumDef, raw_ostream &os) {
-  EnumAttr enumAttr(enumDef);
-  StringRef enumName = enumAttr.getEnumClassName();
-  std::string underlyingType = std::string(enumAttr.getUnderlyingType());
-  StringRef strToSymFnName = enumAttr.getStringToSymbolFnName();
+  EnumInfo enumInfo(enumDef);
+  StringRef enumName = enumInfo.getEnumClassName();
+  std::string underlyingType = std::string(enumInfo.getUnderlyingType());
+  StringRef strToSymFnName = enumInfo.getStringToSymbolFnName();
   StringRef separator = enumDef.getValueAsString("separator");
   StringRef separatorTrimmed = separator.trim();
-  auto enumerants = enumAttr.getAllCases();
+  auto enumerants = enumInfo.getAllCases();
   auto allBitsUnsetCase = getAllBitsUnsetCase(enumerants);
 
   os << formatv("::std::optional<{0}> {1}(::llvm::StringRef str) {{\n",
@@ -463,17 +468,16 @@ static void emitStrToSymFnForBitEnum(const Record &enumDef, raw_ostream &os) {
 
 static void emitUnderlyingToSymFnForIntEnum(const Record &enumDef,
                                             raw_ostream &os) {
-  EnumAttr enumAttr(enumDef);
-  StringRef enumName = enumAttr.getEnumClassName();
-  std::string underlyingType = std::string(enumAttr.getUnderlyingType());
-  StringRef underlyingToSymFnName = enumAttr.getUnderlyingToSymbolFnName();
-  auto enumerants = enumAttr.getAllCases();
+  EnumInfo enumInfo(enumDef);
+  StringRef enumName = enumInfo.getEnumClassName();
+  std::string underlyingType = std::string(enumInfo.getUnderlyingType());
+  StringRef underlyingToSymFnName = enumInfo.getUnderlyingToSymbolFnName();
+  auto enumerants = enumInfo.getAllCases();
 
   // Avoid generating the underlying value to symbol conversion function if
   // there is an enumerant without explicit value.
-  if (llvm::any_of(enumerants, [](EnumAttrCase enumerant) {
-        return enumerant.getValue() < 0;
-      }))
+  if (llvm::any_of(enumerants,
+                   [](EnumCase enumerant) { return enumerant.getValue() < 0; }))
     return;
 
   os << formatv("::std::optional<{0}> {1}({2} value) {{\n", enumName,
@@ -493,10 +497,10 @@ static void emitUnderlyingToSymFnForIntEnum(const Record &enumDef,
 }
 
 static void emitSpecializedAttrDef(const Record &enumDef, raw_ostream &os) {
-  EnumAttr enumAttr(enumDef);
-  StringRef enumName = enumAttr.getEnumClassName();
-  StringRef attrClassName = enumAttr.getSpecializedAttrClassName();
-  const Record *baseAttrDef = enumAttr.getBaseAttrClass();
+  EnumInfo enumInfo(enumDef);
+  StringRef enumName = enumInfo.getEnumClassName();
+  StringRef attrClassName = enumInfo.getSpecializedAttrClassName();
+  const Record *baseAttrDef = enumInfo.getBaseAttrClass();
   Attribute baseAttr(baseAttrDef);
 
   // Emit classof method
@@ -520,7 +524,7 @@ static void emitSpecializedAttrDef(const Record &enumDef, raw_ostream &os) {
   os << formatv("{0} {0}::get(::mlir::MLIRContext *context, {1} val) {{\n",
                 attrClassName, enumName);
 
-  StringRef underlyingType = enumAttr.getUnderlyingType();
+  StringRef underlyingType = enumInfo.getUnderlyingType();
 
   // Assuming that it is IntegerAttr constraint
   int64_t bitwidth = 64;
@@ -552,11 +556,11 @@ static void emitSpecializedAttrDef(const Record &enumDef, raw_ostream &os) {
 
 static void emitUnderlyingToSymFnForBitEnum(const Record &enumDef,
                                             raw_ostream &os) {
-  EnumAttr enumAttr(enumDef);
-  StringRef enumName = enumAttr.getEnumClassName();
-  std::string underlyingType = std::string(enumAttr.getUnderlyingType());
-  StringRef underlyingToSymFnName = enumAttr.getUnderlyingToSymbolFnName();
-  auto enumerants = enumAttr.getAllCases();
+  EnumInfo enumInfo(enumDef);
+  StringRef enumName = enumInfo.getEnumClassName();
+  std::string underlyingType = std::string(enumInfo.getUnderlyingType());
+  StringRef underlyingToSymFnName = enumInfo.getUnderlyingToSymbolFnName();
+  auto enumerants = enumInfo.getAllCases();
   auto allBitsUnsetCase = getAllBitsUnsetCase(enumerants);
 
   os << formatv("::std::optional<{0}> {1}({2} value) {{\n", enumName,
@@ -574,16 +578,16 @@ static void emitUnderlyingToSymFnForBitEnum(const Record &enumDef,
 }
 
 static void emitEnumDecl(const Record &enumDef, raw_ostream &os) {
-  EnumAttr enumAttr(enumDef);
-  StringRef enumName = enumAttr.getEnumClassName();
-  StringRef cppNamespace = enumAttr.getCppNamespace();
-  std::string underlyingType = std::string(enumAttr.getUnderlyingType());
-  StringRef description = enumAttr.getSummary();
-  StringRef strToSymFnName = enumAttr.getStringToSymbolFnName();
-  StringRef symToStrFnName = enumAttr.getSymbolToStringFnName();
-  StringRef symToStrFnRetType = enumAttr.getSymbolToStringFnRetType();
-  StringRef underlyingToSymFnName = enumAttr.getUnderlyingToSymbolFnName();
-  auto enumerants = enumAttr.getAllCases();
+  EnumInfo enumInfo(enumDef);
+  StringRef enumName = enumInfo.getEnumClassName();
+  StringRef cppNamespace = enumInfo.getCppNamespace();
+  std::string underlyingType = std::string(enumInfo.getUnderlyingType());
+  StringRef description = enumInfo.getSummary();
+  StringRef strToSymFnName = enumInfo.getStringToSymbolFnName();
+  StringRef symToStrFnName = enumInfo.getSymbolToStringFnName();
+  StringRef symToStrFnRetType = enumInfo.getSymbolToStringFnRetType();
+  StringRef underlyingToSymFnName = enumInfo.getUnderlyingToSymbolFnName();
+  auto enumerants = enumInfo.getAllCases();
 
   SmallVector<StringRef, 2> namespaces;
   llvm::SplitString(cppNamespace, namespaces, "::");
@@ -595,7 +599,7 @@ static void emitEnumDecl(const Record &enumDef, raw_ostream &os) {
   emitEnumClass(enumDef, enumName, underlyingType, description, enumerants, os);
 
   // Emit conversion function declarations
-  if (llvm::all_of(enumerants, [](EnumAttrCase enumerant) {
+  if (llvm::all_of(enumerants, [](EnumCase enumerant) {
         return enumerant.getValue() >= 0;
       })) {
     os << formatv(
@@ -606,7 +610,7 @@ static void emitEnumDecl(const Record &enumDef, raw_ostream &os) {
   os << formatv("::std::optional<{0}> {1}(::llvm::StringRef);\n", enumName,
                 strToSymFnName);
 
-  if (enumAttr.isBitEnum()) {
+  if (enumInfo.isBitEnum()) {
     emitOperators(enumDef, os);
   } else {
     emitMaxValueFn(enumDef, os);
@@ -644,8 +648,8 @@ class {1} : public ::mlir::{2} {
   {0} getValue() const;
 };
 )";
-  if (enumAttr.genSpecializedAttr()) {
-    StringRef attrClassName = enumAttr.getSpecializedAttrClassName();
+  if (enumInfo.genSpecializedAttr()) {
+    StringRef attrClassName = enumInfo.getSpecializedAttrClassName();
     StringRef baseAttrClassName = "IntegerAttr";
     os << formatv(attrClassDecl, enumName, attrClassName, baseAttrClassName);
   }
@@ -656,7 +660,7 @@ class {1} : public ::mlir::{2} {
   // Generate a generic parser and printer for the enum.
   std::string qualName =
       std::string(formatv("{0}::{1}", cppNamespace, enumName));
-  emitParserPrinter(enumAttr, qualName, cppNamespace, os);
+  emitParserPrinter(enumInfo, qualName, cppNamespace, os);
 
   // Emit DenseMapInfo for this enum class
   emitDenseMapInfo(qualName, underlyingType, cppNamespace, os);
@@ -673,8 +677,8 @@ static bool emitEnumDecls(const RecordKeeper &records, raw_ostream &os) {
 }
 
 static void emitEnumDef(const Record &enumDef, raw_ostream &os) {
-  EnumAttr enumAttr(enumDef);
-  StringRef cppNamespace = enumAttr.getCppNamespace();
+  EnumInfo enumInfo(enumDef);
+  StringRef cppNamespace = enumInfo.getCppNamespace();
 
   SmallVector<StringRef, 2> namespaces;
   llvm::SplitString(cppNamespace, namespaces, "::");
@@ -682,7 +686,7 @@ static void emitEnumDef(const Record &enumDef, raw_ostream &os) {
   for (auto ns : namespaces)
     os << "namespace " << ns << " {\n";
 
-  if (enumAttr.isBitEnum()) {
+  if (enumInfo.isBitEnum()) {
     emitSymToStrFnForBitEnum(enumDef, os);
     emitStrToSymFnForBitEnum(enumDef, os);
     emitUnderlyingToSymFnForBitEnum(enumDef, os);
@@ -692,7 +696,7 @@ static void emitEnumDef(const Record &enumDef, raw_ostream &os) {
     emitUnderlyingToSymFnForIntEnum(enumDef, os);
   }
 
-  if (enumAttr.genSpecializedAttr())
+  if (enumInfo.genSpecializedAttr())
     emitSpecializedAttrDef(enumDef, os);
 
   for (auto ns : llvm::reverse(namespaces))
diff --git a/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp b/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp
index 9e19f479d673a..96af14d36817b 100644
--- a/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp
+++ b/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp
@@ -13,6 +13,7 @@
 
 #include "mlir/TableGen/Argument.h"
 #include "mlir/TableGen/Attribute.h"
+#include "mlir/TableGen/EnumInfo.h"
 #include "mlir/TableGen/GenInfo.h"
 #include "mlir/TableGen/Operator.h"
 
@@ -335,13 +336,13 @@ static bool emitOpMLIRBuilders(const RecordKeeper &records, raw_ostream &os) {
 
 namespace {
 // Wrapper class around a Tablegen definition of an LLVM enum attribute case.
-class LLVMEnumAttrCase : public tblgen::EnumAttrCase {
+class LLVMEnumCase : public tblgen::EnumCase {
 public:
-  using tblgen::EnumAttrCase::EnumAttrCase;
+  using tblgen::EnumCase::EnumCase;
 
   // Constructs a case from a non LLVM-specific enum attribute case.
-  explicit LLVMEnumAttrCase(const tblgen::EnumAttrCase &other)
-      : tblgen::EnumAttrCase(&other.getDef()) {}
+  explicit LLVMEnumCase(const tblgen::EnumCase &other)
+      : tblgen::EnumCase(&other.getDef()) {}
 
   // Returns the C++ enumerant for the LLVM API.
   StringRef getLLVMEnumerant() const {
@@ -350,9 +351,9 @@ class LLVMEnumAttrCase : public tblgen::EnumAttrCase {
 };
 
 // Wraper class around a Tablegen definition of an LLVM enum attribute.
-class LLVMEnumAttr : public tblgen::EnumAttr {
+class LLVMEnumInfo : public tblgen::EnumInfo {
 public:
-  using tblgen::EnumAttr::EnumAttr;
+  using tblgen::EnumInfo::EnumInfo;
 
   // Returns the C++ enum name for the LLVM API.
   StringRef getLLVMClassName() const {
@@ -360,19 +361,19 @@ class LLVMEnumAttr : public tblgen::EnumAttr {
   }
 
   // Returns all associated cases viewed as LLVM-specific enum cases.
-  std::vector<LLVMEnumAttrCase> getAllCases() const {
-    std::vector<LLVMEnumAttrCase> cases;
+  std::vector<LLVMEnumCase> getAllCases() const {
+    std::vector<LLVMEnumCase> cases;
 
-    for (auto &c : tblgen::EnumAttr::getAllCases())
+    for (auto &c : tblgen::EnumInfo::getAllCases())
       cases.emplace_back(c);
 
     return cases;
   }
 
-  std::vector<LLVMEnumAttrCase> getAllUnsupportedCases() const {
+  std::vector<LLVMEnumCase> getAllUnsupportedCases() const {
     const auto *inits = def->getValueAsListInit("unsupported");
 
-    std::vector<LLVMEnumAttrCase> cases;
+    std::vector<LLVMEnumCase> cases;
     cases.reserve(inits->size());
 
     for (const llvm::Init *init : *inits)
@@ -383,9 +384,9 @@ class LLVMEnumAttr : public tblgen::EnumAttr {
 };
 
 // Wraper class around a Tablegen definition of a C-style LLVM enum attribute.
-class LLVMCEnumAttr : public tblgen::EnumAttr {
+class LLVMCEnumInfo : public tblgen::EnumInfo {
 public:
-  using tblgen::EnumAttr::EnumAttr;
+  using tblgen::EnumInfo::EnumInfo;
 
   // Returns the C++ enum name for the LLVM API.
   StringRef getLLVMClassName() const {
@@ -393,10 +394,10 @@ class LLVMCEnumAttr : public tblgen::EnumAttr {
   }
 
   // Returns all associated cases viewed as LLVM-specific enum cases.
-  std::vector<LLVMEnumAttrCase> getAllCases() const {
-    std::vector<LLVMEnumAttrCase> cases;
+  std::vector<LLVMEnumCase> getAllCases() const {
+    std::vector<LLVMEnumCase> cases;
 
-    for (auto &c : tblgen::EnumAttr::getAllCases())
+    for (auto &c : tblgen::EnumInfo::getAllCases())
       cases.emplace_back(c);
 
     return cases;
@@ -408,10 +409,10 @@ class LLVMCEnumAttr : public tblgen::EnumAttr {
 // switch-based logic to convert from the MLIR LLVM dialect enum attribute case
 // (Enum) to the corresponding LLVM API enumerant
 static void emitOneEnumToConversion(const Record *record, raw_ostream &os) {
-  LLVMEnumAttr enumAttr(record);
-  StringRef llvmClass = enumAttr.getLLVMClassName();
-  StringRef cppClassName = enumAttr.getEnumClassName();
-  StringRef cppNamespace = enumAttr.getCppNamespace();
+  LLVMEnumInfo enumInfo(record);
+  StringRef llvmClass = enumInfo.getLLVMClassName();
+  StringRef cppClassName = enumInfo.getEnumClassName();
+  StringRef cppNamespace = enumInfo.getCppNamespace();
 
   // Emit the function converting the enum attribute to its LLVM counterpart.
   os << formatv(
@@ -419,7 +420,7 @@ static void emitOneEnumToConversion(const Record *record, raw_ostream &os) {
       llvmClass, cppClassName, cppNamespace);
   os << "  switch (value) {\n";
 
-  for (const auto &enumerant : enumAttr.getAllCases()) {
+  for (const auto &enumerant : enumInfo.getAllCases()) {
     StringRef llvmEnumerant = enumerant.getLLVMEnumerant();
     StringRef cppEnumerant = enumerant.getSymbol();
     os << formatv("  case {0}::{1}::{2}:\n", cppNamespace, cppClassName,
@@ -429,7 +430,7 @@ static void emitOneEnumToConversion(const Record *record, raw_ostream &os) {
 
   os << "  }\n";
   os << formatv("  llvm_unreachable(\"unknown {0} type\");\n",
-                enumAttr.getEnumClassName());
+                enumInfo.getEnumClassName());
   os << "}\n\n";
 }
 
@@ -437,7 +438,7 @@ static void emitOneEnumToConversion(const Record *record, raw_ostream &os) {
 // switch-based logic to convert from the MLIR LLVM dialect enum attribute case
 // (Enum) to the corresponding LLVM API C-style enumerant
 static void emitOneCEnumToConversion(const Record *record, raw_ostream &os) {
-  LLVMCEnumAttr enumAttr(record);
+  LLVMCEnumInfo enumAttr(record);
   StringRef llvmClass = enumAttr.getLLVMClassName();
   StringRef cppClassName = enumAttr.getEnumClassName();
   StringRef cppNamespace = enumAttr.getCppNamespace();
@@ -467,10 +468,10 @@ static void emitOneCEnumToConversion(const Record *record, raw_ostream &os) {
 // containing switch-based logic to convert from the LLVM API enumerant to MLIR
 // LLVM dialect enum attribute (Enum).
 static void emitOneEnumFromConversion(const Record *record, raw_ostream &os) {
-  LLVMEnumAttr enumAttr(record);
-  StringRef llvmClass = enumAttr.getLLVMClassName();
-  StringRef cppClassName = enumAttr.getEnumClassName();
-  StringRef cppNamespace = enumAttr.getCppNamespace();
+  LLVMEnumInfo enumInfo(record);
+  StringRef llvmClass = enumInfo.getLLVMClassName();
+  StringRef cppClassName = enumInfo.getEnumClassName();
+  StringRef cppNamespace = enumInfo.getCppNamespace();
 
   // Emit the function converting the enum attribute from its LLVM counterpart.
   os << formatv("inline LLVM_ATTRIBUTE_UNUSED {0}::{1} convert{1}FromLLVM({2} "
@@ -478,23 +479,23 @@ static void emitOneEnumFromConversion(const Record *record, raw_ostream &os) {
                 cppNamespace, cppClassName, llvmClass);
   os << "  switch (value) {\n";
 
-  for (const auto &enumerant : enumAttr.getAllCases()) {
+  for (const auto &enumerant : enumInfo.getAllCases()) {
     StringRef llvmEnumerant = enumerant.getLLVMEnumerant();
     StringRef cppEnumerant = enumerant.getSymbol();
     os << formatv("  case {0}::{1}:\n", llvmClass, llvmEnumerant);
     os << formatv("    return {0}::{1}::{2};\n", cppNamespace, cppClassName,
                   cppEnumerant);
   }
-  for (const auto &enumerant : enumAttr.getAllUnsupportedCases()) {
+  for (const auto &enumerant : enumInfo.getAllUnsupportedCases()) {
     StringRef llvmEnumerant = enumerant.getLLVMEnumerant();
     os << formatv("  case {0}::{1}:\n", llvmClass, llvmEnumerant);
     os << formatv("    llvm_unreachable(\"unsupported case {0}::{1}\");\n",
-                  enumAttr.getLLVMClassName(), llvmEnumerant);
+                  enumInfo.getLLVMClassName(), llvmEnumerant);
   }
 
   os << "  }\n";
   os << formatv("  llvm_unreachable(\"unknown {0} type\");",
-                enumAttr.getLLVMClassName());
+                enumInfo.getLLVMClassName());
   os << "}\n\n";
 }
 
@@ -502,10 +503,10 @@ static void emitOneEnumFromConversion(const Record *record, raw_ostream &os) {
 // containing switch-based logic to convert from the LLVM API C-style enumerant
 // to MLIR LLVM dialect enum attribute (Enum).
 static void emitOneCEnumFromConversion(const Record *record, raw_ostream &os) {
-  LLVMCEnumAttr enumAttr(record);
-  StringRef llvmClass = enumAttr.getLLVMClassName();
-  StringRef cppClassName = enumAttr.getEnumClassName();
-  StringRef cppNamespace = enumAttr.getCppNamespace();
+  LLVMCEnumInfo enumInfo(record);
+  StringRef llvmClass = enumInfo.getLLVMClassName();
+  StringRef cppClassName = enumInfo.getEnumClassName();
+  StringRef cppNamespace = enumInfo.getCppNamespace();
 
   // Emit the function converting the enum attribute from its LLVM counterpart.
   os << formatv(
@@ -514,7 +515,7 @@ static void emitOneCEnumFromConversion(const Record *record, raw_ostream &os) {
       cppNamespace, cppClassName);
   os << "  switch (value) {\n";
 
-  for (const auto &enumerant : enumAttr.getAllCases()) {
+  for (const auto &enumerant : enumInfo.getAllCases()) {
     StringRef llvmEnumerant = enumerant.getLLVMEnumerant();
     StringRef cppEnumerant = enumerant.getSymbol();
     os << formatv("  case static_cast<int64_t>({0}::{1}):\n", llvmClass,
@@ -525,7 +526,7 @@ static void emitOneCEnumFromConversion(const Record *record, raw_ostream &os) {
 
   os << "  }\n";
   os << formatv("  llvm_unreachable(\"unknown {0} type\");",
-                enumAttr.getLLVMClassName());
+                enumInfo.getLLVMClassName());
   os << "}\n\n";
 }
 
diff --git a/mlir/tools/mlir-tblgen/OpDocGen.cpp b/mlir/tools/mlir-tblgen/OpDocGen.cpp
index dbaad84cda5d6..f53aebb302dc9 100644
--- a/mlir/tools/mlir-tblgen/OpDocGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDocGen.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Support/IndentedOstream.h"
 #include "mlir/TableGen/AttrOrTypeDef.h"
 #include "mlir/TableGen/Attribute.h"
+#include "mlir/TableGen/EnumInfo.h"
 #include "mlir/TableGen/GenInfo.h"
 #include "mlir/TableGen/Operator.h"
 #include "llvm/ADT/DenseMap.h"
@@ -384,14 +385,14 @@ static void emitAttrOrTypeDefDoc(const RecordKeeper &records, raw_ostream &os,
 // Enum Documentation
 //===----------------------------------------------------------------------===//
 
-static void emitEnumDoc(const EnumAttr &def, raw_ostream &os) {
+static void emitEnumDoc(const EnumInfo &def, raw_ostream &os) {
   os << formatv("\n### {0}\n", def.getEnumClassName());
 
   // Emit the summary if present.
   emitSummary(def.getSummary(), os);
 
   // Emit case documentation.
-  std::vector<EnumAttrCase> cases = def.getAllCases();
+  std::vector<EnumCase> cases = def.getAllCases();
   os << "\n#### Cases:\n\n";
   os << "| Symbol | Value | String |\n"
      << "| :----: | :---: | ------ |";
@@ -406,7 +407,7 @@ static void emitEnumDoc(const EnumAttr &def, raw_ostream &os) {
 static void emitEnumDoc(const RecordKeeper &records, raw_ostream &os) {
   os << "<!-- Autogenerated by mlir-tblgen; don't manually edit -->\n";
   for (const Record *def : records.getAllDerivedDefinitions("EnumAttrInfo"))
-    emitEnumDoc(EnumAttr(def), os);
+    emitEnumDoc(EnumInfo(def), os);
 }
 
 //===----------------------------------------------------------------------===//
@@ -441,7 +442,7 @@ static void maybeNest(bool nest, llvm::function_ref<void(raw_ostream &os)> fn,
 static void emitBlock(ArrayRef<Attribute> attributes, StringRef inputFilename,
                       ArrayRef<AttrDef> attrDefs, ArrayRef<OpDocGroup> ops,
                       ArrayRef<Type> types, ArrayRef<TypeDef> typeDefs,
-                      ArrayRef<EnumAttr> enums, raw_ostream &os) {
+                      ArrayRef<EnumInfo> enums, raw_ostream &os) {
   if (!ops.empty()) {
     os << "\n## Operations\n";
     emitSourceLink(inputFilename, os);
@@ -490,7 +491,7 @@ static void emitBlock(ArrayRef<Attribute> attributes, StringRef inputFilename,
 
   if (!enums.empty()) {
     os << "\n## Enums\n";
-    for (const EnumAttr &def : enums)
+    for (const EnumInfo &def : enums)
       emitEnumDoc(def, os);
   }
 }
@@ -499,7 +500,7 @@ static void emitDialectDoc(const Dialect &dialect, StringRef inputFilename,
                            ArrayRef<Attribute> attributes,
                            ArrayRef<AttrDef> attrDefs, ArrayRef<OpDocGroup> ops,
                            ArrayRef<Type> types, ArrayRef<TypeDef> typeDefs,
-                           ArrayRef<EnumAttr> enums, raw_ostream &os) {
+                           ArrayRef<EnumInfo> enums, raw_ostream &os) {
   os << "\n# '" << dialect.getName() << "' Dialect\n";
   emitSummary(dialect.getSummary(), os);
   emitDescription(dialect.getDescription(), os);
@@ -532,7 +533,7 @@ static bool emitDialectDoc(const RecordKeeper &records, raw_ostream &os) {
   std::vector<OpDocGroup> dialectOps;
   std::vector<Type> dialectTypes;
   std::vector<TypeDef> dialectTypeDefs;
-  std::vector<EnumAttr> dialectEnums;
+  std::vector<EnumInfo> dialectEnums;
 
   SmallDenseSet<const Record *> seen;
   auto addIfNotSeen = [&](const Record *record, const auto &def, auto &vec) {
@@ -576,7 +577,7 @@ static bool emitDialectDoc(const RecordKeeper &records, raw_ostream &os) {
     addIfInDialect(def, Type(def), dialectTypes);
   dialectEnums.reserve(enumDefs.size());
   for (const Record *def : enumDefs)
-    addIfNotSeen(def, EnumAttr(def), dialectEnums);
+    addIfNotSeen(def, EnumInfo(def), dialectEnums);
 
   // Sort alphabetically ignorning dialect for ops and section name for
   // sections.
diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
index fe724e86d6707..3a7a7aaf3a5dd 100644
--- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
@@ -11,6 +11,7 @@
 #include "OpClass.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/TableGen/Class.h"
+#include "mlir/TableGen/EnumInfo.h"
 #include "mlir/TableGen/Format.h"
 #include "mlir/TableGen/Operator.h"
 #include "mlir/TableGen/Trait.h"
@@ -424,17 +425,17 @@ struct OperationFormat {
 //===----------------------------------------------------------------------===//
 // Parser Gen
 
-/// Returns true if we can format the given attribute as an EnumAttr in the
+/// Returns true if we can format the given attribute as an enum in the
 /// parser format.
 static bool canFormatEnumAttr(const NamedAttribute *attr) {
   Attribute baseAttr = attr->attr.getBaseAttr();
-  const EnumAttr *enumAttr = dyn_cast<EnumAttr>(&baseAttr);
-  if (!enumAttr)
+  if (!baseAttr.isEnumAttr())
     return false;
+  EnumInfo enumInfo(&baseAttr.getDef());
 
   // The attribute must have a valid underlying type and a constant builder.
-  return !enumAttr->getUnderlyingType().empty() &&
-         !enumAttr->getConstBuilderTemplate().empty();
+  return !enumInfo.getUnderlyingType().empty() &&
+         !baseAttr.getConstBuilderTemplate().empty();
 }
 
 /// Returns if we should format the given attribute as an SymbolNameAttr.
@@ -1150,21 +1151,21 @@ static void genEnumAttrParser(const NamedAttribute *var, MethodBody &body,
                               FmtContext &attrTypeCtx, bool parseAsOptional,
                               bool useProperties, StringRef opCppClassName) {
   Attribute baseAttr = var->attr.getBaseAttr();
-  const EnumAttr &enumAttr = cast<EnumAttr>(baseAttr);
-  std::vector<EnumAttrCase> cases = enumAttr.getAllCases();
+  EnumInfo enumInfo(&baseAttr.getDef());
+  std::vector<EnumCase> cases = enumInfo.getAllCases();
 
   // Generate the code for building an attribute for this enum.
   std::string attrBuilderStr;
   {
     llvm::raw_string_ostream os(attrBuilderStr);
-    os << tgfmt(enumAttr.getConstBuilderTemplate(), &attrTypeCtx,
+    os << tgfmt(baseAttr.getConstBuilderTemplate(), &attrTypeCtx,
                 "*attrOptional");
   }
 
   // Build a string containing the cases that can be formatted as a keyword.
   std::string validCaseKeywordsStr = "{";
   llvm::raw_string_ostream validCaseKeywordsOS(validCaseKeywordsStr);
-  for (const EnumAttrCase &attrCase : cases)
+  for (const EnumCase &attrCase : cases)
     if (canFormatStringAsKeyword(attrCase.getStr()))
       validCaseKeywordsOS << '"' << attrCase.getStr() << "\",";
   validCaseKeywordsOS.str().back() = '}';
@@ -1194,8 +1195,8 @@ static void genEnumAttrParser(const NamedAttribute *var, MethodBody &body,
         formatv("result.addAttribute(\"{0}\", {0}Attr);", var->name);
   }
 
-  body << formatv(enumAttrParserCode, var->name, enumAttr.getCppNamespace(),
-                  enumAttr.getStringToSymbolFnName(), attrBuilderStr,
+  body << formatv(enumAttrParserCode, var->name, enumInfo.getCppNamespace(),
+                  enumInfo.getStringToSymbolFnName(), attrBuilderStr,
                   validCaseKeywordsStr, errorMessage, attrAssignment);
 }
 
@@ -2264,13 +2265,13 @@ static MethodBody &genTypeOperandPrinter(FormatElement *arg, const Operator &op,
 static void genEnumAttrPrinter(const NamedAttribute *var, const Operator &op,
                                MethodBody &body) {
   Attribute baseAttr = var->attr.getBaseAttr();
-  const EnumAttr &enumAttr = cast<EnumAttr>(baseAttr);
-  std::vector<EnumAttrCase> cases = enumAttr.getAllCases();
+  const EnumInfo enumInfo(&baseAttr.getDef());
+  std::vector<EnumCase> cases = enumInfo.getAllCases();
 
   body << formatv(enumAttrBeginPrinterCode,
                   (var->attr.isOptional() ? "*" : "") +
                       op.getGetterName(var->name),
-                  enumAttr.getSymbolToStringFnName());
+                  enumInfo.getSymbolToStringFnName());
 
   // Get a string containing all of the cases that can't be represented with a
   // keyword.
@@ -2283,7 +2284,7 @@ static void genEnumAttrPrinter(const NamedAttribute *var, const Operator &op,
   // Otherwise if this is a bit enum attribute, don't allow cases that may
   // overlap with other cases. For simplicity sake, only allow cases with a
   // single bit value.
-  if (enumAttr.isBitEnum()) {
+  if (enumInfo.isBitEnum()) {
     for (auto it : llvm::enumerate(cases)) {
       int64_t value = it.value().getValue();
       if (value < 0 || !llvm::isPowerOf2_64(value))
@@ -2295,8 +2296,8 @@ static void genEnumAttrPrinter(const NamedAttribute *var, const Operator &op,
   // case value to determine when to print in the string form.
   if (nonKeywordCases.any()) {
     body << "    switch (caseValue) {\n";
-    StringRef cppNamespace = enumAttr.getCppNamespace();
-    StringRef enumName = enumAttr.getEnumClassName();
+    StringRef cppNamespace = enumInfo.getCppNamespace();
+    StringRef enumName = enumInfo.getEnumClassName();
     for (auto it : llvm::enumerate(cases)) {
       if (nonKeywordCases.test(it.index()))
         continue;
diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
index c74cb9943671e..c8a12b9e21b90 100644
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -1377,12 +1377,12 @@ std::string PatternEmitter::handleOpArgument(DagLeaf leaf,
     return handleConstantAttr(constAttr.getAttribute(),
                               constAttr.getConstantValue());
   }
-  if (leaf.isEnumAttrCase()) {
-    auto enumCase = leaf.getAsEnumAttrCase();
+  if (leaf.isEnumCase()) {
+    auto enumCase = leaf.getAsEnumCase();
     // This is an enum case backed by an IntegerAttr. We need to get its value
     // to build the constant.
     std::string val = std::to_string(enumCase.getValue());
-    return handleConstantAttr(enumCase, val);
+    return handleConstantAttr(Attribute(&enumCase.getDef()), val);
   }
 
   LLVM_DEBUG(llvm::dbgs() << "handle argument '" << patArgName << "'\n");
@@ -1782,7 +1782,7 @@ void PatternEmitter::supplyValuesForOpArgs(
       auto leaf = node.getArgAsLeaf(argIndex);
       // The argument in the result DAG pattern.
       auto patArgName = node.getArgName(argIndex);
-      if (leaf.isConstantAttr() || leaf.isEnumAttrCase()) {
+      if (leaf.isConstantAttr() || leaf.isEnumCase()) {
         // TODO: Refactor out into map to avoid recomputing these.
         if (!isa<NamedAttribute *>(opArg))
           PrintFatalError(loc, Twine("expected attribute ") + Twine(argIndex));
diff --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
index 75b8829be4da5..7a6189c09f426 100644
--- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
+++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
@@ -13,6 +13,7 @@
 
 #include "mlir/TableGen/Attribute.h"
 #include "mlir/TableGen/CodeGenHelpers.h"
+#include "mlir/TableGen/EnumInfo.h"
 #include "mlir/TableGen/Format.h"
 #include "mlir/TableGen/GenInfo.h"
 #include "mlir/TableGen/Operator.h"
@@ -45,8 +46,8 @@ using llvm::SMLoc;
 using llvm::StringMap;
 using llvm::StringRef;
 using mlir::tblgen::Attribute;
-using mlir::tblgen::EnumAttr;
-using mlir::tblgen::EnumAttrCase;
+using mlir::tblgen::EnumCase;
+using mlir::tblgen::EnumInfo;
 using mlir::tblgen::NamedAttribute;
 using mlir::tblgen::NamedTypeConstraint;
 using mlir::tblgen::NamespaceEmitter;
@@ -335,18 +336,18 @@ static mlir::GenRegistration
 
 static void emitAvailabilityQueryForIntEnum(const Record &enumDef,
                                             raw_ostream &os) {
-  EnumAttr enumAttr(enumDef);
-  StringRef enumName = enumAttr.getEnumClassName();
-  std::vector<EnumAttrCase> enumerants = enumAttr.getAllCases();
+  EnumInfo enumInfo(enumDef);
+  StringRef enumName = enumInfo.getEnumClassName();
+  std::vector<EnumCase> enumerants = enumInfo.getAllCases();
 
   // Mapping from availability class name to (enumerant, availability
   // specification) pairs.
-  llvm::StringMap<llvm::SmallVector<std::pair<EnumAttrCase, Availability>, 1>>
+  llvm::StringMap<llvm::SmallVector<std::pair<EnumCase, Availability>, 1>>
       classCaseMap;
 
   // Place all availability specifications to their corresponding
   // availability classes.
-  for (const EnumAttrCase &enumerant : enumerants)
+  for (const EnumCase &enumerant : enumerants)
     for (const Availability &avail : getAvailabilities(enumerant.getDef()))
       classCaseMap[avail.getClass()].push_back({enumerant, avail});
 
@@ -359,14 +360,14 @@ static void emitAvailabilityQueryForIntEnum(const Record &enumDef,
 
     os << "  switch (value) {\n";
     for (const auto &caseSpecPair : classCasePair.getValue()) {
-      EnumAttrCase enumerant = caseSpecPair.first;
+      EnumCase enumerant = caseSpecPair.first;
       Availability avail = caseSpecPair.second;
       os << formatv("  case {0}::{1}: { {2} return {3}({4}); }\n", enumName,
                     enumerant.getSymbol(), avail.getMergeInstancePreparation(),
                     avail.getMergeInstanceType(), avail.getMergeInstance());
     }
     // Only emit default if uncovered cases.
-    if (classCasePair.getValue().size() < enumAttr.getAllCases().size())
+    if (classCasePair.getValue().size() < enumInfo.getAllCases().size())
       os << "  default: break;\n";
     os << "  }\n"
        << "  return std::nullopt;\n"
@@ -376,19 +377,19 @@ static void emitAvailabilityQueryForIntEnum(const Record &enumDef,
 
 static void emitAvailabilityQueryForBitEnum(const Record &enumDef,
                                             raw_ostream &os) {
-  EnumAttr enumAttr(enumDef);
-  StringRef enumName = enumAttr.getEnumClassName();
-  std::string underlyingType = std::string(enumAttr.getUnderlyingType());
-  std::vector<EnumAttrCase> enumerants = enumAttr.getAllCases();
+  EnumInfo enumInfo(enumDef);
+  StringRef enumName = enumInfo.getEnumClassName();
+  std::string underlyingType = std::string(enumInfo.getUnderlyingType());
+  std::vector<EnumCase> enumerants = enumInfo.getAllCases();
 
   // Mapping from availability class name to (enumerant, availability
   // specification) pairs.
-  llvm::StringMap<llvm::SmallVector<std::pair<EnumAttrCase, Availability>, 1>>
+  llvm::StringMap<llvm::SmallVector<std::pair<EnumCase, Availability>, 1>>
       classCaseMap;
 
   // Place all availability specifications to their corresponding
   // availability classes.
-  for (const EnumAttrCase &enumerant : enumerants)
+  for (const EnumCase &enumerant : enumerants)
     for (const Availability &avail : getAvailabilities(enumerant.getDef()))
       classCaseMap[avail.getClass()].push_back({enumerant, avail});
 
@@ -406,7 +407,7 @@ static void emitAvailabilityQueryForBitEnum(const Record &enumDef,
 
     os << "  switch (value) {\n";
     for (const auto &caseSpecPair : classCasePair.getValue()) {
-      EnumAttrCase enumerant = caseSpecPair.first;
+      EnumCase enumerant = caseSpecPair.first;
       Availability avail = caseSpecPair.second;
       os << formatv("  case {0}::{1}: { {2} return {3}({4}); }\n", enumName,
                     enumerant.getSymbol(), avail.getMergeInstancePreparation(),
@@ -420,10 +421,10 @@ static void emitAvailabilityQueryForBitEnum(const Record &enumDef,
 }
 
 static void emitEnumDecl(const Record &enumDef, raw_ostream &os) {
-  EnumAttr enumAttr(enumDef);
-  StringRef enumName = enumAttr.getEnumClassName();
-  StringRef cppNamespace = enumAttr.getCppNamespace();
-  auto enumerants = enumAttr.getAllCases();
+  EnumInfo enumInfo(enumDef);
+  StringRef enumName = enumInfo.getEnumClassName();
+  StringRef cppNamespace = enumInfo.getCppNamespace();
+  auto enumerants = enumInfo.getAllCases();
 
   llvm::SmallVector<StringRef, 2> namespaces;
   llvm::SplitString(cppNamespace, namespaces, "::");
@@ -435,7 +436,7 @@ static void emitEnumDecl(const Record &enumDef, raw_ostream &os) {
 
   // Place all availability specifications to their corresponding
   // availability classes.
-  for (const EnumAttrCase &enumerant : enumerants)
+  for (const EnumCase &enumerant : enumerants)
     for (const Availability &avail : getAvailabilities(enumerant.getDef())) {
       StringRef className = avail.getClass();
       if (handledClasses.count(className))
@@ -462,8 +463,8 @@ static bool emitEnumDecls(const RecordKeeper &records, raw_ostream &os) {
 }
 
 static void emitEnumDef(const Record &enumDef, raw_ostream &os) {
-  EnumAttr enumAttr(enumDef);
-  StringRef cppNamespace = enumAttr.getCppNamespace();
+  EnumInfo enumInfo(enumDef);
+  StringRef cppNamespace = enumInfo.getCppNamespace();
 
   llvm::SmallVector<StringRef, 2> namespaces;
   llvm::SplitString(cppNamespace, namespaces, "::");
@@ -471,7 +472,7 @@ static void emitEnumDef(const Record &enumDef, raw_ostream &os) {
   for (auto ns : namespaces)
     os << "namespace " << ns << " {\n";
 
-  if (enumAttr.isBitEnum()) {
+  if (enumInfo.isBitEnum()) {
     emitAvailabilityQueryForBitEnum(enumDef, os);
   } else {
     emitAvailabilityQueryForIntEnum(enumDef, os);
@@ -535,7 +536,7 @@ static void emitAttributeSerialization(const Attribute &attr,
   os << tabs
      << formatv("if (auto attr = {0}->getAttr(\"{1}\")) {{\n", opVar, attrName);
   if (llvm::is_contained(constantIdEnumAttrs, attr.getAttrDefName())) {
-    EnumAttr baseEnum(attr.getDef().getValueAsDef("enum"));
+    EnumInfo baseEnum(attr.getDef().getValueAsDef("enum"));
     os << tabs
        << formatv("  {0}.push_back(prepareConstantInt({1}.getLoc(), "
                   "Builder({1}).getI32IntegerAttr(static_cast<uint32_t>("
@@ -544,7 +545,7 @@ static void emitAttributeSerialization(const Attribute &attr,
                   baseEnum.getEnumClassName());
   } else if (attr.isSubClassOf("SPIRV_BitEnumAttr") ||
              attr.isSubClassOf("SPIRV_I32EnumAttr")) {
-    EnumAttr baseEnum(attr.getDef().getValueAsDef("enum"));
+    EnumInfo baseEnum(attr.getDef().getValueAsDef("enum"));
     os << tabs
        << formatv("  {0}.push_back(static_cast<uint32_t>("
                   "::llvm::cast<{1}::{2}Attr>(attr).getValue()));\n",
@@ -831,7 +832,7 @@ static void emitAttributeDeserialization(const Attribute &attr,
                                          StringRef words, StringRef wordIndex,
                                          raw_ostream &os) {
   if (llvm::is_contained(constantIdEnumAttrs, attr.getAttrDefName())) {
-    EnumAttr baseEnum(attr.getDef().getValueAsDef("enum"));
+    EnumInfo baseEnum(attr.getDef().getValueAsDef("enum"));
     os << tabs
        << formatv("{0}.push_back(opBuilder.getNamedAttr(\"{1}\", "
                   "opBuilder.getAttr<{2}::{3}Attr>(static_cast<{2}::{3}>("
@@ -840,7 +841,7 @@ static void emitAttributeDeserialization(const Attribute &attr,
                   baseEnum.getEnumClassName(), words, wordIndex);
   } else if (attr.isSubClassOf("SPIRV_BitEnumAttr") ||
              attr.isSubClassOf("SPIRV_I32EnumAttr")) {
-    EnumAttr baseEnum(attr.getDef().getValueAsDef("enum"));
+    EnumInfo baseEnum(attr.getDef().getValueAsDef("enum"));
     os << tabs
        << formatv("  {0}.push_back(opBuilder.getNamedAttr(\"{1}\", "
                   "opBuilder.getAttr<{2}::{3}Attr>("
@@ -1246,9 +1247,9 @@ static void emitEnumGetAttrNameFnDecl(raw_ostream &os) {
                 "attributeName();\n");
 }
 
-static void emitEnumGetAttrNameFnDefn(const EnumAttr &enumAttr,
+static void emitEnumGetAttrNameFnDefn(const EnumInfo &enumInfo,
                                       raw_ostream &os) {
-  auto enumName = enumAttr.getEnumClassName();
+  auto enumName = enumInfo.getEnumClassName();
   os << formatv("template <> inline StringRef attributeName<{0}>() {{\n",
                 enumName);
   os << "  "
@@ -1266,8 +1267,8 @@ static bool emitAttrUtils(const RecordKeeper &records, raw_ostream &os) {
   os << "#define MLIR_DIALECT_SPIRV_IR_ATTR_UTILS_H_\n";
   emitEnumGetAttrNameFnDecl(os);
   for (const auto *def : defs) {
-    EnumAttr enumAttr(*def);
-    emitEnumGetAttrNameFnDefn(enumAttr, os);
+    EnumInfo enumInfo(*def);
+    emitEnumGetAttrNameFnDefn(enumInfo, os);
   }
   os << "#endif // MLIR_DIALECT_SPIRV_IR_ATTR_UTILS_H\n";
   return false;
@@ -1306,9 +1307,9 @@ static void emitAvailabilityImpl(const Operator &srcOp, raw_ostream &os) {
     if (!namedAttr.attr.isSubClassOf("SPIRV_BitEnumAttr") &&
         !namedAttr.attr.isSubClassOf("SPIRV_I32EnumAttr"))
       continue;
-    EnumAttr enumAttr(namedAttr.attr.getDef().getValueAsDef("enum"));
+    EnumInfo enumInfo(namedAttr.attr.getDef().getValueAsDef("enum"));
 
-    for (const EnumAttrCase &enumerant : enumAttr.getAllCases())
+    for (const EnumCase &enumerant : enumInfo.getAllCases())
       for (const Availability &caseAvail :
            getAvailabilities(enumerant.getDef()))
         availClasses.try_emplace(caseAvail.getClass(), caseAvail);
@@ -1348,14 +1349,14 @@ static void emitAvailabilityImpl(const Operator &srcOp, raw_ostream &os) {
       if (!namedAttr.attr.isSubClassOf("SPIRV_BitEnumAttr") &&
           !namedAttr.attr.isSubClassOf("SPIRV_I32EnumAttr"))
         continue;
-      EnumAttr enumAttr(namedAttr.attr.getDef().getValueAsDef("enum"));
+      EnumInfo enumInfo(namedAttr.attr.getDef().getValueAsDef("enum"));
 
       // (enumerant, availability specification) pairs for this availability
       // class.
-      SmallVector<std::pair<EnumAttrCase, Availability>, 1> caseSpecs;
+      SmallVector<std::pair<EnumCase, Availability>, 1> caseSpecs;
 
       // Collect all cases' availability specs.
-      for (const EnumAttrCase &enumerant : enumAttr.getAllCases())
+      for (const EnumCase &enumerant : enumInfo.getAllCases())
         for (const Availability &caseAvail :
              getAvailabilities(enumerant.getDef()))
           if (availClassName == caseAvail.getClass())
@@ -1366,19 +1367,19 @@ static void emitAvailabilityImpl(const Operator &srcOp, raw_ostream &os) {
       if (caseSpecs.empty())
         continue;
 
-      if (enumAttr.isBitEnum()) {
+      if (enumInfo.isBitEnum()) {
         // For BitEnumAttr, we need to iterate over each bit to query its
         // availability spec.
         os << formatv("  for (unsigned i = 0; "
                       "i < std::numeric_limits<{0}>::digits; ++i) {{\n",
-                      enumAttr.getUnderlyingType());
+                      enumInfo.getUnderlyingType());
         os << formatv("    {0}::{1} tblgen_attrVal = this->{2}() & "
                       "static_cast<{0}::{1}>(1 << i);\n",
-                      enumAttr.getCppNamespace(), enumAttr.getEnumClassName(),
+                      enumInfo.getCppNamespace(), enumInfo.getEnumClassName(),
                       srcOp.getGetterName(namedAttr.name));
         os << formatv(
             "    if (static_cast<{0}>(tblgen_attrVal) == 0) continue;\n",
-            enumAttr.getUnderlyingType());
+            enumInfo.getUnderlyingType());
       } else {
         // For IntEnumAttr, we just need to query the value as a whole.
         os << "  {\n";
@@ -1386,7 +1387,7 @@ static void emitAvailabilityImpl(const Operator &srcOp, raw_ostream &os) {
                       srcOp.getGetterName(namedAttr.name));
       }
       os << formatv("    auto tblgen_instance = {0}::{1}(tblgen_attrVal);\n",
-                    enumAttr.getCppNamespace(), avail.getQueryFnName());
+                    enumInfo.getCppNamespace(), avail.getQueryFnName());
       os << "    if (tblgen_instance) "
          // TODO` here once ODS supports
          // dialect-specific contents so that we can use not implementing the
@@ -1434,14 +1435,14 @@ static bool emitCapabilityImplication(const RecordKeeper &records,
                                       raw_ostream &os) {
   llvm::emitSourceFileHeader("SPIR-V Capability Implication", os, records);
 
-  EnumAttr enumAttr(
+  EnumInfo enumInfo(
       records.getDef("SPIRV_CapabilityAttr")->getValueAsDef("enum"));
 
   os << "ArrayRef<spirv::Capability> "
         "spirv::getDirectImpliedCapabilities(spirv::Capability cap) {\n"
      << "  switch (cap) {\n"
      << "  default: return {};\n";
-  for (const EnumAttrCase &enumerant : enumAttr.getAllCases()) {
+  for (const EnumCase &enumerant : enumInfo.getAllCases()) {
     const Record &def = enumerant.getDef();
     if (!def.getValue("implies"))
       continue;
@@ -1452,7 +1453,7 @@ static bool emitCapabilityImplication(const RecordKeeper &records,
        << ": {static const spirv::Capability implies[" << impliedCapsDefs.size()
        << "] = {";
     llvm::interleaveComma(impliedCapsDefs, os, [&](const Record *capDef) {
-      os << "spirv::Capability::" << EnumAttrCase(capDef).getSymbol();
+      os << "spirv::Capability::" << EnumCase(capDef).getSymbol();
     });
     os << "}; return ArrayRef<spirv::Capability>(implies, "
        << impliedCapsDefs.size() << "); }\n";
diff --git a/mlir/tools/mlir-tblgen/TosaUtilsGen.cpp b/mlir/tools/mlir-tblgen/TosaUtilsGen.cpp
index 491f9143edb02..ddc149810ebd8 100644
--- a/mlir/tools/mlir-tblgen/TosaUtilsGen.cpp
+++ b/mlir/tools/mlir-tblgen/TosaUtilsGen.cpp
@@ -12,6 +12,7 @@
 
 #include "mlir/TableGen/Attribute.h"
 #include "mlir/TableGen/CodeGenHelpers.h"
+#include "mlir/TableGen/EnumInfo.h"
 #include "mlir/TableGen/Format.h"
 #include "mlir/TableGen/GenInfo.h"
 #include "mlir/TableGen/Operator.h"
@@ -42,8 +43,8 @@ using llvm::SMLoc;
 using llvm::StringMap;
 using llvm::StringRef;
 using mlir::tblgen::Attribute;
-using mlir::tblgen::EnumAttr;
-using mlir::tblgen::EnumAttrCase;
+using mlir::tblgen::EnumCase;
+using mlir::tblgen::EnumInfo;
 using mlir::tblgen::NamedAttribute;
 using mlir::tblgen::NamedTypeConstraint;
 using mlir::tblgen::NamespaceEmitter;

>From 509fe555aabf58cca9fefc8d13ef34af9e81a627 Mon Sep 17 00:00:00 2001
From: Krzysztof Drewniak <krzysdrewniak at gmail.com>
Date: Thu, 20 Mar 2025 23:06:17 -0700
Subject: [PATCH 2/2] [mlir] Decouple enum generation from attributes, adding
 EnumInfo and EnumCase

This commit pulls apart the inherent attribute dependence of classes
like EnumAttrInfo and EnumAttrCase, factoring them out into simpler
EnumCase and EnumInfo variants. This allows specifying the cases of an
enum without needing to make the cases, or the EnumInfo itself, a
subclass of SignlessIntegerAttrBase.

The existing classes are retained as subclasses of the new ones, both
for backwards compatibility and to allow attribute-specific
information.

In addition, the new BitEnum class changes its default printer/parser
behavior: cases when multiple keywords appear, like having both nuw
and nsw in overflow flags, will no longer be quoted by the operator<<,
and the FieldParser instance will now expect multiple keywords. All
instances of BitEnumAttr retain the old behavior.
---
 mlir/include/mlir/IR/EnumAttr.td              | 271 +++++++++++++-----
 mlir/include/mlir/TableGen/EnumInfo.h         |   4 +
 mlir/lib/TableGen/EnumInfo.cpp                |  16 +-
 mlir/lib/TableGen/Pattern.cpp                 |   2 +-
 mlir/test/Dialect/LLVMIR/func.mlir            |   2 +-
 mlir/test/IR/attribute.mlir                   |   2 +-
 mlir/test/lib/Dialect/Test/TestEnumDefs.td    |  25 +-
 mlir/test/mlir-tblgen/enums-gen.td            |  43 ++-
 .../mlir-tblgen/EnumPythonBindingGen.cpp      |  20 +-
 mlir/tools/mlir-tblgen/EnumsGen.cpp           | 124 +++++++-
 mlir/tools/mlir-tblgen/OpDocGen.cpp           |   4 +-
 mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp      |   6 +-
 mlir/utils/spirv/gen_spirv_dialect.py         |   4 +-
 13 files changed, 389 insertions(+), 134 deletions(-)

diff --git a/mlir/include/mlir/IR/EnumAttr.td b/mlir/include/mlir/IR/EnumAttr.td
index 9fec28f03ec28..e5406546b1950 100644
--- a/mlir/include/mlir/IR/EnumAttr.td
+++ b/mlir/include/mlir/IR/EnumAttr.td
@@ -14,8 +14,8 @@ include "mlir/IR/AttrTypeBase.td"
 //===----------------------------------------------------------------------===//
 // Enum attribute kinds
 
-// Additional information for an enum attribute case.
-class EnumAttrCaseInfo<string sym, int intVal, string strVal> {
+// Additional information for an enum case.
+class EnumCase<string sym, int intVal, string strVal, int widthVal> {
   // The C++ enumerant symbol.
   string symbol = sym;
 
@@ -26,29 +26,56 @@ class EnumAttrCaseInfo<string sym, int intVal, string strVal> {
 
   // The string representation of the enumerant. May be the same as symbol.
   string str = strVal;
+
+  // The bitwidth of the enum.
+  int width = widthVal;
 }
 
 // An enum attribute case stored with IntegerAttr, which has an integer value,
 // its representation as a string and a C++ symbol name which may be different.
+// Not needed when using the newer `EnumCase` form for defining enum cases.
 class IntEnumAttrCaseBase<I intType, string sym, string strVal, int intVal> :
-    EnumAttrCaseInfo<sym, intVal, strVal>,
+    EnumCase<sym, intVal, strVal, intType.bitwidth>,
     SignlessIntegerAttrBase<intType, "case " # strVal> {
   let predicate =
     CPred<"::llvm::cast<::mlir::IntegerAttr>($_self).getInt() == " # intVal>;
 }
 
-// Cases of integer enum attributes with a specific type. By default, the string
+// Cases of integer enums with a specific type. By default, the string
 // representation is the same as the C++ symbol name.
+class I32EnumCase<string sym, int val, string str = sym>
+  : EnumCase<sym, val, str, 32>;
+class I64EnumCase<string sym, int val, string str = sym>
+  : EnumCase<sym, val, str, 64>;
+
+// Cases of integer enum attributes with a specific type. By default, the string
+// representation is the same as the C++ symbol name. These forms
+// are not needed when using the newer `EnumCase` form.
 class I32EnumAttrCase<string sym, int val, string str = sym>
     : IntEnumAttrCaseBase<I32, sym, str, val>;
 class I64EnumAttrCase<string sym, int val, string str = sym>
     : IntEnumAttrCaseBase<I64, sym, str, val>;
 
-// A bit enum case stored with an IntegerAttr. `val` here is *not* the ordinal
-// number of a bit that is set. It is an integer value with bits set to match
-// the case.
+// A bit enum case. `val` here is *not* the ordinal number of a bit
+// that is set. It is an integer value with bits set to match the case.
+class BitEnumCaseBase<string sym, int val, string str, int width> :
+    EnumCase<sym, val, str, width>;
+// Bit enum cases. The string representation is the same as the C++ symbol
+// name unless otherwise specified.
+class I8BitEnumCase<string sym, int val, string str = sym>
+  : BitEnumCaseBase<sym, val, str, 8>;
+class I16BitEnumCase<string sym, int val, string str = sym>
+  : BitEnumCaseBase<sym, val, str, 16>;
+class I32BitEnumCase<string sym, int val, string str = sym>
+  : BitEnumCaseBase<sym, val, str, 32>;
+class I64BitEnumCase<string sym, int val, string str = sym>
+  : BitEnumCaseBase<sym, val, str, 64>;
+
+// A form of `BitEnumCaseBase` that also inherits from `Attr` and encodes
+// the width of the enum, which was defined when enums were always
+// stored in attributes.
 class BitEnumAttrCaseBase<I intType, string sym, int val, string str = sym> :
-    EnumAttrCaseInfo<sym, val, str>,
+    BitEnumCaseBase<sym, val, str, intType.bitwidth>,
     SignlessIntegerAttrBase<intType, "case " #str>;
 
 class I8BitEnumAttrCase<string sym, int val, string str = sym>
@@ -61,6 +88,19 @@ class I64BitEnumAttrCase<string sym, int val, string str = sym>
     : BitEnumAttrCaseBase<I64, sym, val, str>;
 
 // The special bit enum case with no bits set (i.e. value = 0).
+class BitEnumCaseNone<string sym, string str, int width>
+    : BitEnumCaseBase<sym, 0, str, width>;
+
+class I8BitEnumCaseNone<string sym, string str = sym>
+  : BitEnumCaseNone<sym, str, 8>;
+class I16BitEnumCaseNone<string sym, string str = sym>
+  : BitEnumCaseNone<sym, str, 16>;
+class I32BitEnumCaseNone<string sym, string str = sym>
+  : BitEnumCaseNone<sym, str, 32>;
+class I64BitEnumCaseNone<string sym, string str = sym>
+  : BitEnumCaseNone<sym, str, 64>;
+
+// Older forms, used when enums were necessarily attributes.
 class I8BitEnumAttrCaseNone<string sym, string str = sym>
     : I8BitEnumAttrCase<sym, 0, str>;
 class I16BitEnumAttrCaseNone<string sym, string str = sym>
@@ -70,6 +110,24 @@ class I32BitEnumAttrCaseNone<string sym, string str = sym>
 class I64BitEnumAttrCaseNone<string sym, string str = sym>
     : I64BitEnumAttrCase<sym, 0, str>;
 
+// A bit enum case for a single bit, specified by a bit position `pos`.
+// The `pos` argument refers to the index of the bit, and is limited
+// to be in the range [0, width).
+class BitEnumCaseBit<string sym, int pos, string str, int width>
+    : BitEnumCaseBase<sym, !shl(1, pos), str, width> {
+  assert !and(!ge(pos, 0), !lt(pos, width)),
+      "bit position larger than underlying storage";
+}
+
+class I8BitEnumCaseBit<string sym, int pos, string str = sym>
+    : BitEnumCaseBit<sym, pos, str, 8>;
+class I16BitEnumCaseBit<string sym, int pos, string str = sym>
+    : BitEnumCaseBit<sym, pos, str, 16>;
+class I32BitEnumCaseBit<string sym, int pos, string str = sym>
+    : BitEnumCaseBit<sym, pos, str, 32>;
+class I64BitEnumCaseBit<string sym, int pos, string str = sym>
+    : BitEnumCaseBit<sym, pos, str, 64>;
+
 // A bit enum case for a single bit, specified by a bit position.
 // The pos argument refers to the index of the bit, and is limited
 // to be in the range [0, bitwidth).
@@ -90,12 +148,17 @@ class I64BitEnumAttrCaseBit<string sym, int pos, string str = sym>
 
 // A bit enum case for a group/list of previously declared cases, providing
 // a convenient alias for that group.
+class BitEnumCaseGroup<string sym, list<BitEnumCaseBase> cases, string str = sym>
+    : BitEnumCaseBase<sym,
+      !foldl(0, cases, value, bitcase, !or(value, bitcase.value)),
+      str, !head(cases).width>;
+
+// The attribute-only form of `BitEnumCaseGroup`.
 class BitEnumAttrCaseGroup<I intType, string sym,
-                           list<BitEnumAttrCaseBase> cases, string str = sym>
+                           list<BitEnumCaseBase> cases, string str = sym>
     : BitEnumAttrCaseBase<intType, sym,
           !foldl(0, cases, value, bitcase, !or(value, bitcase.value)),
           str>;
-
 class I8BitEnumAttrCaseGroup<string sym, list<BitEnumAttrCaseBase> cases,
                               string str = sym>
     : BitEnumAttrCaseGroup<I8, sym, cases, str>;
@@ -109,29 +172,36 @@ class I64BitEnumAttrCaseGroup<string sym, list<BitEnumAttrCaseBase> cases,
                               string str = sym>
     : BitEnumAttrCaseGroup<I64, sym, cases, str>;
 
-// Additional information for an enum attribute.
-class EnumAttrInfo<
-    string name, list<EnumAttrCaseInfo> cases, Attr baseClass> :
-      Attr<baseClass.predicate, baseClass.summary> {
-
+// Information describing an enum and the functions that should be generated for it.
+class EnumInfo<string name, string summaryValue, list<EnumCase> cases, int width> {
+  string summary = summaryValue;
   // Generate a description of this enums members for the MLIR docs.
-  let description =
+  string description =
         "Enum cases:\n" # !interleave(
           !foreach(case, cases,
               "* " # case.str  # " (`" # case.symbol # "`)"), "\n");
 
+  // The C++ namespace for this enum
+  string cppNamespace = "";
+
   // The C++ enum class name
   string className = name;
 
+  // C++ type wrapped by attribute
+  string cppType = cppNamespace # "::" # className;
+
   // List of all accepted cases
-  list<EnumAttrCaseInfo> enumerants = cases;
+  list<EnumCase> enumerants = cases;
 
   // The following fields are only used by the EnumsGen backend to generate
   // an enum class definition and conversion utility functions.
 
+  // The bitwidth underlying the class
+  int bitwidth = width;
+
   // The underlying type for the C++ enum class. An empty string mean the
   // underlying type is not explicitly specified.
-  string underlyingType = "";
+  string underlyingType = "uint" # width # "_t";
 
   // The name of the utility function that converts a value of the underlying
   // type to the corresponding symbol. It will have the following signature:
@@ -165,6 +235,15 @@ class EnumAttrInfo<
   // static constexpr unsigned <fn-name>();
   // ```
   string maxEnumValFnName = "getMaxEnumValFor" # name;
+}
+
+// A wrapper around `EnumInfo` that also makes the Enum an attribute
+// if `genSpecializedAttr` is 1 (though `EnumAttr` is the preferred means
+// to accomplish this) or declares that the enum will be stored in an attribute.
+class EnumAttrInfo<
+    string name, list<EnumCase> cases, SignlessIntegerAttrBase baseClass> :
+      EnumInfo<name, baseClass.summary, cases, !cast<I>(baseClass.valueType).bitwidth>,
+      Attr<baseClass.predicate, baseClass.summary> {
 
   // Generate specialized Attribute class
   bit genSpecializedAttr = 1;
@@ -188,15 +267,25 @@ class EnumAttrInfo<
     baseAttrClass.constBuilderCall);
   let valueType = baseAttrClass.valueType;
 
-  // C++ type wrapped by attribute
-  string cppType = cppNamespace # "::" # className;
-
   // Parser and printer code used by the EnumParameter class, to be provided by
   // derived classes
   string parameterParser = ?;
   string parameterPrinter = ?;
 }
 
+// An enum holding a single integer value.
+class IntEnum<string name, string summary, list<EnumCase> cases, int width>
+    : EnumInfo<name,
+      !if(!empty(summary), "allowed i" # width # " cases: " #
+          !interleave(!foreach(case, cases, case.value), ", "),
+          summary),
+      cases, width>;
+
+class I32Enum<string name, string summary, list<EnumCase> cases>
+    : IntEnum<name, summary, cases, 32>;
+class I64Enum<string name, string summary, list<EnumCase> cases>
+    : IntEnum<name, summary, cases, 64>;
+
 // An enum attribute backed by IntegerAttr.
 //
 // Op attributes of this kind are stored as IntegerAttr. Extra verification will
@@ -245,13 +334,73 @@ class I64EnumAttr<string name, string summary, list<I64EnumAttrCase> cases> :
   let underlyingType = "uint64_t";
 }
 
+// The base mixin for bit enums that are stored as an integer.
+// This is used by both BitEnum and BitEnumAttr, which need to have a set of
+// extra properties that bit enums have which normal enums don't. However,
+// we can't just use BitEnum as a base class of BitEnumAttr, since BitEnumAttr
+// also inherits from EnumAttrInfo, causing double inheritance of EnumInfo.
+class BitEnumBase<list<BitEnumCaseBase> cases> {
+  // Determine "valid" bits from enum cases for error checking
+  int validBits = !foldl(0, cases, value, bitcase, !or(value, bitcase.value));
+
+  // The delimiter used to separate bit enum cases in strings. Only "|" and
+  // "," (along with optional spaces) are supported due to the use of the
+  // parseSeparatorFn in parameterParser below.
+  // Spaces in the separator string are used for printing, but will be optional
+  // for parsing.
+  string separator = "|";
+  assert !or(!ge(!find(separator, "|"), 0), !ge(!find(separator, ","), 0)),
+      "separator must contain '|' or ',' for parameter parsing";
+
+  // Print the "primary group" only for bits that are members of case groups
+  // that have all bits present. When the value is 0, printing will display both
+  // individual bit case names AND the names for all groups that the bit is
+  // contained in. When the value is 1, for each bit that is set AND is a member
+  // of a group with all bits set, only the "primary group" (i.e. the first
+  // group with all bits set in reverse declaration order) will be printed (for
+  // conciseness).
+  bit printBitEnumPrimaryGroups = 0;
+
+  // 1 if the operator<< for this enum should put quotes around values with
+  // multiple entries. Off by default in the general case but on for BitEnumAttrs
+  // since that was the original behavior.
+  bit printBitEnumQuoted = 0;
+}
+
+// A bit enum stored as an integer.
+//
+// Enums of this kind are stored as an integer. Attributes or properties deriving
+// from this enum will have additional verification generated on them to make sure
+// only allowed bits are set. Helper methods are generated to parse a string of enum
+// values separated by the specified separator to a symbol and vice versa.
+class BitEnum<string name, string summary, list<BitEnumCaseBase> cases, int width>
+    : EnumInfo<name, summary, cases, width>, BitEnumBase<cases> {
+  // We need to return a string because we may concatenate symbols for multiple
+  // bits together.
+  let symbolToStringFnRetType = "std::string";
+}
+
+class I8BitEnum<string name, string summary,
+                     list<BitEnumCaseBase> cases>
+    : BitEnum<name, summary, cases, 8>;
+class I16BitEnum<string name, string summary,
+                     list<BitEnumCaseBase> cases>
+    : BitEnum<name, summary, cases, 16>;
+class I32BitEnum<string name, string summary,
+                     list<BitEnumCaseBase> cases>
+    : BitEnum<name, summary, cases, 32>;
+
+class I64BitEnum<string name, string summary,
+                     list<BitEnumCaseBase> cases>
+    : BitEnum<name, summary, cases, 64>;
+
 // A bit enum stored with an IntegerAttr.
 //
 // Op attributes of this kind are stored as IntegerAttr. Extra verification will
 // be generated on the integer to make sure only allowed bits are set. Besides,
 // helper methods are generated to parse a string separated with a specified
 // delimiter to a symbol and vice versa.
-class BitEnumAttrBase<I intType, list<BitEnumAttrCaseBase> cases,
+class BitEnumAttrBase<I intType, list<BitEnumCaseBase> cases,
                       string summary>
     : SignlessIntegerAttrBase<intType, summary> {
   let predicate = And<[
@@ -264,24 +413,13 @@ class BitEnumAttrBase<I intType, list<BitEnumAttrCaseBase> cases,
 }
 
 class BitEnumAttr<I intType, string name, string summary,
-                  list<BitEnumAttrCaseBase> cases>
-    : EnumAttrInfo<name, cases, BitEnumAttrBase<intType, cases, summary>> {
-  // Determine "valid" bits from enum cases for error checking
-  int validBits = !foldl(0, cases, value, bitcase, !or(value, bitcase.value));
-
+                  list<BitEnumCaseBase> cases>
+    : EnumAttrInfo<name, cases, BitEnumAttrBase<intType, cases, summary>>,
+      BitEnumBase<cases> {
   // We need to return a string because we may concatenate symbols for multiple
   // bits together.
   let symbolToStringFnRetType = "std::string";
 
-  // The delimiter used to separate bit enum cases in strings. Only "|" and
-  // "," (along with optional spaces) are supported due to the use of the
-  // parseSeparatorFn in parameterParser below.
-  // Spaces in the separator string are used for printing, but will be optional
-  // for parsing.
-  string separator = "|";
-  assert !or(!ge(!find(separator, "|"), 0), !ge(!find(separator, ","), 0)),
-      "separator must contain '|' or ',' for parameter parsing";
-
   // Parsing function that corresponds to the enum separator. Only
   // "," and "|" are supported by this definition.
   string parseSeparatorFn = !if(!ge(!find(separator, "|"), 0),
@@ -312,36 +450,30 @@ class BitEnumAttr<I intType, string name, string summary,
   // Print the enum by calling `symbolToString`.
   let parameterPrinter = "$_printer << " # symbolToStringFnName # "($_self)";
 
-  // Print the "primary group" only for bits that are members of case groups
-  // that have all bits present. When the value is 0, printing will display both
-  // both individual bit case names AND the names for all groups that the bit is
-  // contained in. When the value is 1, for each bit that is set AND is a member
-  // of a group with all bits set, only the "primary group" (i.e. the first
-  // group with all bits set in reverse declaration order) will be printed (for
-  // conciseness).
-  bit printBitEnumPrimaryGroups = 0;
+  // Use old-style operator<< and FieldParser for compatibility
+  let printBitEnumQuoted = 1;
 }
 
 class I8BitEnumAttr<string name, string summary,
-                     list<BitEnumAttrCaseBase> cases>
+                     list<BitEnumCaseBase> cases>
     : BitEnumAttr<I8, name, summary, cases> {
   let underlyingType = "uint8_t";
 }
 
 class I16BitEnumAttr<string name, string summary,
-                     list<BitEnumAttrCaseBase> cases>
+                     list<BitEnumCaseBase> cases>
     : BitEnumAttr<I16, name, summary, cases> {
   let underlyingType = "uint16_t";
 }
 
 class I32BitEnumAttr<string name, string summary,
-                     list<BitEnumAttrCaseBase> cases>
+                     list<BitEnumCaseBase> cases>
     : BitEnumAttr<I32, name, summary, cases> {
   let underlyingType = "uint32_t";
 }
 
 class I64BitEnumAttr<string name, string summary,
-                     list<BitEnumAttrCaseBase> cases>
+                     list<BitEnumCaseBase> cases>
     : BitEnumAttr<I64, name, summary, cases> {
   let underlyingType = "uint64_t";
 }
@@ -349,11 +481,13 @@ class I64BitEnumAttr<string name, string summary,
 // A C++ enum as an attribute parameter. The parameter implements a parser and
 // printer for the enum by dispatching calls to `stringToSymbol` and
 // `symbolToString`.
-class EnumParameter<EnumAttrInfo enumInfo>
+class EnumParameter<EnumInfo enumInfo>
     : AttrParameter<enumInfo.cppNamespace # "::" # enumInfo.className,
                     "an enum of type " # enumInfo.className> {
-  let parser = enumInfo.parameterParser;
-  let printer = enumInfo.parameterPrinter;
+  let parser = !if(!isa<EnumAttrInfo>(enumInfo),
+    !cast<EnumAttrInfo>(enumInfo).parameterParser, ?);
+  let printer = !if(!isa<EnumAttrInfo>(enumInfo),
+    !cast<EnumAttrInfo>(enumInfo).parameterPrinter, ?);
 }
 
 // An attribute backed by a C++ enum. The attribute contains a single
@@ -384,14 +518,14 @@ class EnumParameter<EnumAttrInfo enumInfo>
 // The op will appear in the IR as `my_dialect.my_op first`. However, the
 // generic format of the attribute will be `#my_dialect<"enum first">`. Override
 // the attribute's assembly format as required.
-class EnumAttr<Dialect dialect, EnumAttrInfo enumInfo, string name = "",
+class EnumAttr<Dialect dialect, EnumInfo enumInfo, string name = "",
                list <Trait> traits = []>
     : AttrDef<dialect, enumInfo.className, traits> {
   let summary = enumInfo.summary;
   let description = enumInfo.description;
 
   // The backing enumeration.
-  EnumAttrInfo enum = enumInfo;
+  EnumInfo enum = enumInfo;
 
   // Inherit the C++ namespace from the enum.
   let cppNamespace = enumInfo.cppNamespace;
@@ -417,41 +551,42 @@ class EnumAttr<Dialect dialect, EnumAttrInfo enumInfo, string name = "",
   let assemblyFormat = "$value";
 }
 
-class _symbolToValue<EnumAttrInfo enumAttrInfo, string case> {
+class _symbolToValue<EnumInfo enumInfo, string case> {
   defvar cases =
-    !filter(iter, enumAttrInfo.enumerants, !eq(iter.str, case));
+    !filter(iter, enumInfo.enumerants, !eq(iter.str, case));
 
   assert !not(!empty(cases)), "failed to find enum-case '" # case # "'";
 
   // `!empty` check to not cause an error if the cases are empty.
   // The assertion catches the issue later and emits a proper error message.
-  string value = enumAttrInfo.cppType # "::"
+  string value = enumInfo.cppType # "::"
     # !if(!empty(cases), "", !head(cases).symbol);
 }
 
-class _bitSymbolsToValue<BitEnumAttr bitEnumAttr, string case> {
+class _bitSymbolsToValue<EnumInfo bitEnum, string case> {
+  assert !isa<BitEnumBase>(bitEnum), "_bitSymbolsToValue not given a bit enum";
   defvar pos = !find(case, "|");
 
   // Recursive instantiation looking up the symbol before the `|` in
   // enum cases.
   string value = !if(
-    !eq(pos, -1), /*baseCase=*/_symbolToValue<bitEnumAttr, case>.value,
-    /*rec=*/_symbolToValue<bitEnumAttr, !substr(case, 0, pos)>.value # "|"
-    # _bitSymbolsToValue<bitEnumAttr, !substr(case, !add(pos, 1))>.value
+    !eq(pos, -1), /*baseCase=*/_symbolToValue<bitEnum, case>.value,
+    /*rec=*/_symbolToValue<bitEnum, !substr(case, 0, pos)>.value # "|"
+    # _bitSymbolsToValue<bitEnum, !substr(case, !add(pos, 1))>.value
   );
 }
 
 class ConstantEnumCaseBase<Attr attribute,
-    EnumAttrInfo enumAttrInfo, string case>
+    EnumInfo enumInfo, string case>
   : ConstantAttr<attribute,
-  !if(!isa<BitEnumAttr>(enumAttrInfo),
-    _bitSymbolsToValue<!cast<BitEnumAttr>(enumAttrInfo), case>.value,
-    _symbolToValue<enumAttrInfo, case>.value
+  !if(!isa<BitEnumBase>(enumInfo),
+    _bitSymbolsToValue<enumInfo, case>.value,
+    _symbolToValue<enumInfo, case>.value
   )
 >;
 
 /// Attribute constraint matching a constant enum case. `attribute` should be
-/// one of `EnumAttrInfo` or `EnumAttr` and `symbol` the string representation
+/// one of `EnumInfo` or `EnumAttr` and `symbol` the string representation
 /// of an enum case. Multiple enum values of a bit-enum can be combined using
 /// `|` as a separator. Note that there mustn't be any whitespace around the
 /// separator.
@@ -463,10 +598,10 @@ class ConstantEnumCaseBase<Attr attribute,
 /// * ConstantEnumCase<Arith_CmpIPredicateAttr, "slt">
 class ConstantEnumCase<Attr attribute, string case>
   : ConstantEnumCaseBase<attribute,
-    !if(!isa<EnumAttrInfo>(attribute), !cast<EnumAttrInfo>(attribute),
+    !if(!isa<EnumInfo>(attribute), !cast<EnumInfo>(attribute),
           !cast<EnumAttr>(attribute).enum), case> {
-  assert !or(!isa<EnumAttr>(attribute), !isa<EnumAttrInfo>(attribute)),
-    "attribute must be one of 'EnumAttr' or 'EnumAttrInfo'";
+  assert !or(!isa<EnumAttr>(attribute), !isa<EnumInfo>(attribute)),
+    "attribute must be one of 'EnumAttr' or 'EnumInfo'";
 }
 
 #endif // ENUMATTR_TD
diff --git a/mlir/include/mlir/TableGen/EnumInfo.h b/mlir/include/mlir/TableGen/EnumInfo.h
index 196267864f325..ece5154c0a285 100644
--- a/mlir/include/mlir/TableGen/EnumInfo.h
+++ b/mlir/include/mlir/TableGen/EnumInfo.h
@@ -85,6 +85,9 @@ class EnumInfo {
   // Returns the description of the enum.
   StringRef getDescription() const;
 
+  // Returns the bitwidth of the enum.
+  int64_t getBitwidth() const;
+
   // Returns the underlying type.
   StringRef getUnderlyingType() const;
 
@@ -120,6 +123,7 @@ class EnumInfo {
   // Only applicable for bit enums.
 
   bool printBitEnumPrimaryGroups() const;
+  bool printBitEnumQuoted() const;
 
   // Returns the TableGen definition this EnumAttrCase was constructed from.
   const llvm::Record &getDef() const;
diff --git a/mlir/lib/TableGen/EnumInfo.cpp b/mlir/lib/TableGen/EnumInfo.cpp
index 9f491d30f0e7f..6128c53557cc4 100644
--- a/mlir/lib/TableGen/EnumInfo.cpp
+++ b/mlir/lib/TableGen/EnumInfo.cpp
@@ -18,8 +18,8 @@ using llvm::Init;
 using llvm::Record;
 
 EnumCase::EnumCase(const Record *record) : def(record) {
-  assert(def->isSubClassOf("EnumAttrCaseInfo") &&
-         "must be subclass of TableGen 'EnumAttrCaseInfo' class");
+  assert(def->isSubClassOf("EnumCase") &&
+         "must be subclass of TableGen 'EnumCase' class");
 }
 
 EnumCase::EnumCase(const DefInit *init) : EnumCase(init->getDef()) {}
@@ -35,8 +35,8 @@ int64_t EnumCase::getValue() const { return def->getValueAsInt("value"); }
 const Record &EnumCase::getDef() const { return *def; }
 
 EnumInfo::EnumInfo(const Record *record) : def(record) {
-  assert(isSubClassOf("EnumAttrInfo") &&
-         "must be subclass of TableGen 'EnumAttrInfo' class");
+  assert(isSubClassOf("EnumInfo") &&
+         "must be subclass of TableGen 'EnumInfo' class");
 }
 
 EnumInfo::EnumInfo(const Record &record) : EnumInfo(&record) {}
@@ -55,7 +55,7 @@ std::optional<Attribute> EnumInfo::asEnumAttr() const {
   return std::nullopt;
 }
 
-bool EnumInfo::isBitEnum() const { return isSubClassOf("BitEnumAttr"); }
+bool EnumInfo::isBitEnum() const { return isSubClassOf("BitEnumBase"); }
 
 StringRef EnumInfo::getEnumClassName() const {
   return def->getValueAsString("className");
@@ -73,6 +73,8 @@ StringRef EnumInfo::getCppNamespace() const {
   return def->getValueAsString("cppNamespace");
 }
 
+int64_t EnumInfo::getBitwidth() const { return def->getValueAsInt("bitwidth"); }
+
 StringRef EnumInfo::getUnderlyingType() const {
   return def->getValueAsString("underlyingType");
 }
@@ -127,4 +129,8 @@ bool EnumInfo::printBitEnumPrimaryGroups() const {
   return def->getValueAsBit("printBitEnumPrimaryGroups");
 }
 
+bool EnumInfo::printBitEnumQuoted() const {
+  return def->getValueAsBit("printBitEnumQuoted");
+}
+
 const Record &EnumInfo::getDef() const { return *def; }
diff --git a/mlir/lib/TableGen/Pattern.cpp b/mlir/lib/TableGen/Pattern.cpp
index 73e2803c21dae..d83df3e415c36 100644
--- a/mlir/lib/TableGen/Pattern.cpp
+++ b/mlir/lib/TableGen/Pattern.cpp
@@ -57,7 +57,7 @@ bool DagLeaf::isNativeCodeCall() const {
 
 bool DagLeaf::isConstantAttr() const { return isSubClassOf("ConstantAttr"); }
 
-bool DagLeaf::isEnumCase() const { return isSubClassOf("EnumAttrCaseInfo"); }
+bool DagLeaf::isEnumCase() const { return isSubClassOf("EnumCase"); }
 
 bool DagLeaf::isStringAttr() const { return isa<llvm::StringInit>(def); }
 
diff --git a/mlir/test/Dialect/LLVMIR/func.mlir b/mlir/test/Dialect/LLVMIR/func.mlir
index 74dd862ce8fb2..7caea3920255a 100644
--- a/mlir/test/Dialect/LLVMIR/func.mlir
+++ b/mlir/test/Dialect/LLVMIR/func.mlir
@@ -428,7 +428,7 @@ module {
 
 module {
   "llvm.func"() ({
-  // expected-error @below {{invalid Calling Conventions specification: cc_12}}
+  // expected-error @below {{expected one of [ccc, fastcc, coldcc, cc_10, cc_11, anyregcc, preserve_mostcc, preserve_allcc, swiftcc, cxx_fast_tlscc, tailcc, cfguard_checkcc, swifttailcc, x86_stdcallcc, x86_fastcallcc, arm_apcscc, arm_aapcscc, arm_aapcs_vfpcc, msp430_intrcc, x86_thiscallcc, ptx_kernelcc, ptx_devicecc, spir_funccc, spir_kernelcc, intel_ocl_bicc, x86_64_sysvcc, win64cc, x86_vectorcallcc, hhvmcc, hhvm_ccc, x86_intrcc, avr_intrcc, avr_builtincc, amdgpu_vscc, amdgpu_gscc, amdgpu_cscc, amdgpu_kernelcc, x86_regcallcc, amdgpu_hscc, msp430_builtincc, amdgpu_lscc, amdgpu_escc, aarch64_vectorcallcc, aarch64_sve_vectorcallcc, wasm_emscripten_invokecc, amdgpu_gfxcc, m68k_intrcc] for Calling Conventions, got: cc_12}}
   // expected-error @below {{failed to parse CConvAttr parameter 'CallingConv' which is to be a `CConv`}}
   }) {sym_name = "generic_unknown_calling_convention", CConv = #llvm.cconv<cc_12>, function_type = !llvm.func<i64 (i64, i64)>} : () -> ()
 }
diff --git a/mlir/test/IR/attribute.mlir b/mlir/test/IR/attribute.mlir
index 5a005a393d8ac..4f280bde1aecc 100644
--- a/mlir/test/IR/attribute.mlir
+++ b/mlir/test/IR/attribute.mlir
@@ -535,7 +535,7 @@ func.func @allowed_cases_pass() {
 // -----
 
 func.func @disallowed_case_sticky_fail() {
-  // expected-error@+2 {{expected test::TestBitEnum to be one of: read, write, execute}}
+  // expected-error@+2 {{expected one of [read, write, execute] for a test bit enum, got: sticky}}
   // expected-error@+1 {{failed to parse TestBitEnumAttr}}
   "test.op_with_bit_enum"() {value = #test.bit_enum<sticky>} : () -> ()
 }
diff --git a/mlir/test/lib/Dialect/Test/TestEnumDefs.td b/mlir/test/lib/Dialect/Test/TestEnumDefs.td
index 1ddfca0b22315..7441ea5a9726b 100644
--- a/mlir/test/lib/Dialect/Test/TestEnumDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestEnumDefs.td
@@ -42,11 +42,10 @@ def TestEnum
   let cppNamespace = "test";
 }
 
-def TestSimpleEnum : I32EnumAttr<"SimpleEnum", "", [
-    I32EnumAttrCase<"a", 0>,
-    I32EnumAttrCase<"b", 1>
+def TestSimpleEnum : I32Enum<"SimpleEnum", "", [
+    I32EnumCase<"a", 0>,
+    I32EnumCase<"b", 1>
   ]> {
-  let genSpecializedAttr = 0;
   let cppNamespace = "::test";
 }
 
@@ -56,24 +55,22 @@ def TestSimpleEnum : I32EnumAttr<"SimpleEnum", "", [
 
 // Define the C++ enum.
 def TestBitEnum
-    : I32BitEnumAttr<"TestBitEnum", "a test bit enum", [
-        I32BitEnumAttrCaseBit<"Read", 0, "read">,
-        I32BitEnumAttrCaseBit<"Write", 1, "write">,
-        I32BitEnumAttrCaseBit<"Execute", 2, "execute">,
+    : I32BitEnum<"TestBitEnum", "a test bit enum", [
+        I32BitEnumCaseBit<"Read", 0, "read">,
+        I32BitEnumCaseBit<"Write", 1, "write">,
+        I32BitEnumCaseBit<"Execute", 2, "execute">,
       ]> {
-  let genSpecializedAttr = 0;
   let cppNamespace = "test";
   let separator = ", ";
 }
 
 // Define an enum with a different separator
 def TestBitEnumVerticalBar
-    : I32BitEnumAttr<"TestBitEnumVerticalBar", "another test bit enum", [
-        I32BitEnumAttrCaseBit<"User", 0, "user">,
-        I32BitEnumAttrCaseBit<"Group", 1, "group">,
-        I32BitEnumAttrCaseBit<"Other", 2, "other">,
+    : I32BitEnum<"TestBitEnumVerticalBar", "another test bit enum", [
+        I32BitEnumCaseBit<"User", 0, "user">,
+        I32BitEnumCaseBit<"Group", 1, "group">,
+        I32BitEnumCaseBit<"Other", 2, "other">,
       ]> {
-  let genSpecializedAttr = 0;
   let cppNamespace = "test";
   let separator = " | ";
 }
diff --git a/mlir/test/mlir-tblgen/enums-gen.td b/mlir/test/mlir-tblgen/enums-gen.td
index c3a768e42236c..8489cff7c429d 100644
--- a/mlir/test/mlir-tblgen/enums-gen.td
+++ b/mlir/test/mlir-tblgen/enums-gen.td
@@ -5,12 +5,12 @@ include "mlir/IR/EnumAttr.td"
 include "mlir/IR/OpBase.td"
 
 // Test bit enums
-def None: I32BitEnumAttrCaseNone<"None", "none">;
-def Bit0: I32BitEnumAttrCaseBit<"Bit0", 0, "tagged">;
-def Bit1: I32BitEnumAttrCaseBit<"Bit1", 1>;
-def Bit2: I32BitEnumAttrCaseBit<"Bit2", 2>;
+def None: I32BitEnumCaseNone<"None", "none">;
+def Bit0: I32BitEnumCaseBit<"Bit0", 0, "tagged">;
+def Bit1: I32BitEnumCaseBit<"Bit1", 1>;
+def Bit2: I32BitEnumCaseBit<"Bit2", 2>;
 def Bit3: I32BitEnumAttrCaseBit<"Bit3", 3>;
-def BitGroup: I32BitEnumAttrCaseGroup<"BitGroup", [
+def BitGroup: BitEnumCaseGroup<"BitGroup", [
   Bit0, Bit1
 ]>;
 
@@ -42,7 +42,7 @@ def MyBitEnum: I32BitEnumAttr<"MyBitEnum", "An example bit enum",
 // DECL:     // Symbolize the keyword.
 // DECL:     if (::std::optional<::MyBitEnum> attr = ::symbolizeEnum<::MyBitEnum>(enumKeyword))
 // DECL:       return *attr;
-// DECL:     return parser.emitError(loc, "invalid An example bit enum specification: ") << enumKeyword;
+// DECL:     return parser.emitError(loc, "expected one of [none, tagged, Bit1, Bit2, Bit3, BitGroup] for An example bit enum, got: ") << enumKeyword;
 // DECL:   }
 
 // DECL: inline ::llvm::raw_ostream &operator<<(::llvm::raw_ostream &p, ::MyBitEnum value) {
@@ -73,7 +73,7 @@ def MyBitEnum: I32BitEnumAttr<"MyBitEnum", "An example bit enum",
 
 // Test enum printer generation for non non-keyword enums.
 
-def NonKeywordBit: I32BitEnumAttrCaseBit<"Bit0", 0, "tag-ged">;
+def NonKeywordBit: I32BitEnumCaseBit<"Bit0", 0, "tag-ged">;
 def MyMixedNonKeywordBitEnum: I32BitEnumAttr<"MyMixedNonKeywordBitEnum", "An example bit enum", [
     NonKeywordBit,
     Bit1
@@ -101,3 +101,32 @@ def MyNonKeywordBitEnum: I32BitEnumAttr<"MyNonKeywordBitEnum", "An example bit e
 // DECL: inline ::llvm::raw_ostream &operator<<(::llvm::raw_ostream &p, ::MyNonKeywordBitEnum value) {
 // DECL:   auto valueStr = stringifyEnum(value);
 // DECL:   return p << '"' << valueStr << '"';
+
+def MyNonQuotedPrintBitEnum
+  : I32BitEnum<"MyNonQuotedPrintBitEnum", "Example new-style bit enum",
+    [None, Bit0, Bit1, Bit2, Bit3, BitGroup]>;
+
+// DECL: struct FieldParser<::MyNonQuotedPrintBitEnum, ::MyNonQuotedPrintBitEnum> {
+// DECL:   template <typename ParserT>
+// DECL:   static FailureOr<::MyNonQuotedPrintBitEnum> parse(ParserT &parser) {
+// DECL:     ::MyNonQuotedPrintBitEnum flags = {};
+// DECL:     do {
+// DECL:       // Parse the keyword containing a part of the enum.
+// DECL:       ::llvm::StringRef enumKeyword;
+// DECL:       auto loc = parser.getCurrentLocation();
+// DECL:       if (failed(parser.parseOptionalKeyword(&enumKeyword))) {
+// DECL:         return parser.emitError(loc, "expected keyword for Example new-style bit enum");
+// DECL:       }
+// DECL:       // Symbolize the keyword.
+// DECL:       if (::std::optional<::MyNonQuotedPrintBitEnum> flag = ::symbolizeEnum<::MyNonQuotedPrintBitEnum>(enumKeyword)) {
+// DECL:         flags = flags | *flag;
+// DECL:       } else {
+// DECL:         return parser.emitError(loc, "expected one of [none, tagged, Bit1, Bit2, Bit3, BitGroup] for Example new-style bit enum, got: ") << enumKeyword;
+// DECL:       }
+// DECL:     } while (::mlir::succeeded(parser.parseOptionalVerticalBar()));
+// DECL:     return flags;
+// DECL:   }
+
+// DECL: inline ::llvm::raw_ostream &operator<<(::llvm::raw_ostream &p, ::MyNonQuotedPrintBitEnum value) {
+// DECL:   auto valueStr = stringifyEnum(value);
+// DECL-NEXT:   return p << valueStr;
diff --git a/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp
index 5d4d9e90fff67..8e2d6114e48eb 100644
--- a/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp
@@ -85,17 +85,6 @@ static void emitEnumClass(EnumInfo enumInfo, raw_ostream &os) {
   os << "\n";
 }
 
-/// Attempts to extract the bitwidth B from string "uintB_t" describing the
-/// type. This bitwidth information is not readily available in ODS. Returns
-/// `false` on success, `true` on failure.
-static bool extractUIntBitwidth(StringRef uintType, int64_t &bitwidth) {
-  if (!uintType.consume_front("uint"))
-    return true;
-  if (!uintType.consume_back("_t"))
-    return true;
-  return uintType.getAsInteger(/*Radix=*/10, bitwidth);
-}
-
 /// Emits an attribute builder for the given enum attribute to support automatic
 /// conversion between enum values and attributes in Python. Returns
 /// `false` on success, `true` on failure.
@@ -104,12 +93,7 @@ static bool emitAttributeBuilder(const EnumInfo &enumInfo, raw_ostream &os) {
   if (!enumAttrInfo)
     return false;
 
-  int64_t bitwidth;
-  if (extractUIntBitwidth(enumInfo.getUnderlyingType(), bitwidth)) {
-    llvm::errs() << "failed to identify bitwidth of "
-                 << enumInfo.getUnderlyingType();
-    return true;
-  }
+  int64_t bitwidth = enumInfo.getBitwidth();
   os << formatv("@register_attribute_builder(\"{0}\")\n",
                 enumAttrInfo->getAttrDefName());
   os << formatv("def _{0}(x, context):\n",
@@ -140,7 +124,7 @@ static bool emitDialectEnumAttributeBuilder(StringRef attrDefName,
 static bool emitPythonEnums(const RecordKeeper &records, raw_ostream &os) {
   os << fileHeader;
   for (const Record *it :
-       records.getAllDerivedDefinitionsIfDefined("EnumAttrInfo")) {
+       records.getAllDerivedDefinitionsIfDefined("EnumInfo")) {
     EnumInfo enumInfo(*it);
     emitEnumClass(enumInfo, os);
     emitAttributeBuilder(enumInfo, os);
diff --git a/mlir/tools/mlir-tblgen/EnumsGen.cpp b/mlir/tools/mlir-tblgen/EnumsGen.cpp
index fa6fad156b747..9941a203bc5cb 100644
--- a/mlir/tools/mlir-tblgen/EnumsGen.cpp
+++ b/mlir/tools/mlir-tblgen/EnumsGen.cpp
@@ -77,11 +77,22 @@ static void emitParserPrinter(const EnumInfo &enumInfo, StringRef qualName,
 
   // Check which cases shouldn't be printed using a keyword.
   llvm::BitVector nonKeywordCases(cases.size());
-  for (auto [index, caseVal] : llvm::enumerate(cases))
-    if (!mlir::tblgen::canFormatStringAsKeyword(caseVal.getStr()))
-      nonKeywordCases.set(index);
-
-  // Generate the parser and the start of the printer for the enum.
+  // Build "[a, b, ...]" for diagnostics, quoting each non-keyword case once.
+  std::string casesList;
+  llvm::raw_string_ostream caseListOs(casesList);
+  caseListOs << "[";
+  llvm::interleaveComma(llvm::enumerate(cases), caseListOs, [&](auto it) {
+    StringRef name = it.value().getStr();
+    bool needsQuotes = !mlir::tblgen::canFormatStringAsKeyword(name);
+    if (needsQuotes)
+      nonKeywordCases.set(it.index());
+    StringRef quote = needsQuotes ? "\\\"" : "";
+    caseListOs << quote << name << quote;
+  });
+  caseListOs << "]";
+
+  // Generate the parser and the start of the printer for the enum, excluding
+  // non-quoted bit enums.
   const char *parsedAndPrinterStart = R"(
 namespace mlir {
 template <typename T, typename>
@@ -100,7 +111,7 @@ struct FieldParser<{0}, {0}> {{
     // Symbolize the keyword.
     if (::std::optional<{0}> attr = {1}::symbolizeEnum<{0}>(enumKeyword))
       return *attr;
-    return parser.emitError(loc, "invalid {2} specification: ") << enumKeyword;
+    return parser.emitError(loc, "expected one of {3} for {2}, got: ") << enumKeyword;
   }
 };
 
@@ -121,7 +132,7 @@ struct FieldParser<std::optional<{0}>, std::optional<{0}>> {{
     // Symbolize the keyword.
     if (::std::optional<{0}> attr = {1}::symbolizeEnum<{0}>(enumKeyword))
       return attr;
-    return parser.emitError(loc, "invalid {2} specification: ") << enumKeyword;
+    return parser.emitError(loc, "expected one of {3} for {2}, got: ") << enumKeyword;
   }
 };
 } // namespace mlir
@@ -131,8 +142,94 @@ inline ::llvm::raw_ostream &operator<<(::llvm::raw_ostream &p, {0} value) {{
   auto valueStr = stringifyEnum(value);
 )";
 
-  os << formatv(parsedAndPrinterStart, qualName, cppNamespace,
-                enumInfo.getSummary());
+  const char *parsedAndPrinterStartUnquotedBitEnum = R"(
+  namespace mlir {
+  template <typename T, typename>
+  struct FieldParser;
+
+  template<>
+  struct FieldParser<{0}, {0}> {{
+    template <typename ParserT>
+    static FailureOr<{0}> parse(ParserT &parser) {{
+      {0} flags = {{};
+      do {{
+        // Parse the keyword containing a part of the enum.
+        ::llvm::StringRef enumKeyword;
+        auto loc = parser.getCurrentLocation();
+        if (failed(parser.parseOptionalKeyword(&enumKeyword))) {{
+          return parser.emitError(loc, "expected keyword for {2}");
+        }
+
+        // Symbolize the keyword.
+        if (::std::optional<{0}> flag = {1}::symbolizeEnum<{0}>(enumKeyword)) {{
+          flags = flags | *flag;
+        } else {{
+          return parser.emitError(loc, "expected one of {3} for {2}, got: ") << enumKeyword;
+        }
+      } while (::mlir::succeeded(parser.{5}()));
+      return flags;
+    }
+  };
+
+  /// Support for std::optional, useful in attribute/type definition where the enum is
+  /// used as:
+  ///
+  ///    let parameters = (ins OptionalParameter<"std::optional<TheEnumName>">:$value);
+  template<>
+  struct FieldParser<std::optional<{0}>, std::optional<{0}>> {{
+    template <typename ParserT>
+    static FailureOr<std::optional<{0}>> parse(ParserT &parser) {{
+      {0} flags = {{};
+      bool firstIter = true;
+      do {{
+        // Parse the keyword containing a part of the enum.
+        ::llvm::StringRef enumKeyword;
+        auto loc = parser.getCurrentLocation();
+        if (failed(parser.parseOptionalKeyword(&enumKeyword))) {{
+          if (firstIter)
+            return std::optional<{0}>{{};
+          return parser.emitError(loc, "expected keyword for {2} after '{4}'");
+        }
+        firstIter = false;
+
+        // Symbolize the keyword.
+        if (::std::optional<{0}> flag = {1}::symbolizeEnum<{0}>(enumKeyword)) {{
+          flags = flags | *flag;
+        } else {{
+          return parser.emitError(loc, "expected one of {3} for {2}, got: ") << enumKeyword;
+        }
+      } while(::mlir::succeeded(parser.{5}()));
+      return std::optional<{0}>{{flags};
+    }
+  };
+  } // namespace mlir
+
+  namespace llvm {
+  inline ::llvm::raw_ostream &operator<<(::llvm::raw_ostream &p, {0} value) {{
+    auto valueStr = stringifyEnum(value);
+  )";
+
+  bool isNewStyleBitEnum =
+      enumInfo.isBitEnum() && !enumInfo.printBitEnumQuoted();
+
+  if (isNewStyleBitEnum) {
+    if (nonKeywordCases.any())
+      return PrintFatalError(
+          "bit enum " + qualName +
+          " cannot be printed unquoted with cases that cannot be keywords");
+    // Map the trimmed separator to its parser hook; others yield invalid C++.
+    StringRef separator = enumInfo.getDef().getValueAsString("separator");
+    StringRef parseSeparatorFn = llvm::StringSwitch<StringRef>(separator.trim())
+            .Case("|", "parseOptionalVerticalBar")
+            .Case(",", "parseOptionalComma")
+            .Default("error, enum separator must be '|' or ','");
+    os << formatv(parsedAndPrinterStartUnquotedBitEnum, qualName, cppNamespace,
+                  enumInfo.getSummary(), casesList, separator,
+                  parseSeparatorFn);
+  } else {
+    os << formatv(parsedAndPrinterStart, qualName, cppNamespace,
+                  enumInfo.getSummary(), casesList);
+  }
 
   // If all cases require a string, always wrap.
   if (nonKeywordCases.all()) {
@@ -160,7 +257,10 @@ inline ::llvm::raw_ostream &operator<<(::llvm::raw_ostream &p, {0} value) {{
 
     // If this is a bit enum, conservatively print the string form if the value
     // is not a power of two (i.e. not a single bit case) and not a known case.
-  } else if (enumInfo.isBitEnum()) {
+    // Only do this if we're using the old-style parser that parses the enum as
+    // one keyword, as opposed to the new form, where we can print the value
+    // as-is.
+  } else if (enumInfo.isBitEnum() && !isNewStyleBitEnum) {
     // Process the known multi-bit cases that use valid keywords.
     SmallVector<EnumCase *> validMultiBitCases;
     for (auto [index, caseVal] : llvm::enumerate(cases)) {
@@ -670,7 +770,7 @@ static bool emitEnumDecls(const RecordKeeper &records, raw_ostream &os) {
   llvm::emitSourceFileHeader("Enum Utility Declarations", os, records);
 
   for (const Record *def :
-       records.getAllDerivedDefinitionsIfDefined("EnumAttrInfo"))
+       records.getAllDerivedDefinitionsIfDefined("EnumInfo"))
     emitEnumDecl(*def, os);
 
   return false;
@@ -708,7 +808,7 @@ static bool emitEnumDefs(const RecordKeeper &records, raw_ostream &os) {
   llvm::emitSourceFileHeader("Enum Utility Definitions", os, records);
 
   for (const Record *def :
-       records.getAllDerivedDefinitionsIfDefined("EnumAttrInfo"))
+       records.getAllDerivedDefinitionsIfDefined("EnumInfo"))
     emitEnumDef(*def, os);
 
   return false;
diff --git a/mlir/tools/mlir-tblgen/OpDocGen.cpp b/mlir/tools/mlir-tblgen/OpDocGen.cpp
index f53aebb302dc9..077f9d1ea2b13 100644
--- a/mlir/tools/mlir-tblgen/OpDocGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDocGen.cpp
@@ -406,7 +406,7 @@ static void emitEnumDoc(const EnumInfo &def, raw_ostream &os) {
 
 static void emitEnumDoc(const RecordKeeper &records, raw_ostream &os) {
   os << "<!-- Autogenerated by mlir-tblgen; don't manually edit -->\n";
-  for (const Record *def : records.getAllDerivedDefinitions("EnumAttrInfo"))
+  for (const Record *def : records.getAllDerivedDefinitions("EnumInfo"))
     emitEnumDoc(EnumInfo(def), os);
 }
 
@@ -526,7 +526,7 @@ static bool emitDialectDoc(const RecordKeeper &records, raw_ostream &os) {
   auto typeDefs = records.getAllDerivedDefinitionsIfDefined("DialectType");
   auto typeDefDefs = records.getAllDerivedDefinitionsIfDefined("TypeDef");
   auto attrDefDefs = records.getAllDerivedDefinitionsIfDefined("AttrDef");
-  auto enumDefs = records.getAllDerivedDefinitionsIfDefined("EnumAttrInfo");
+  auto enumDefs = records.getAllDerivedDefinitionsIfDefined("EnumInfo");
 
   std::vector<Attribute> dialectAttrs;
   std::vector<AttrDef> dialectAttrDefs;
diff --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
index 7a6189c09f426..f94ed17aeb4e0 100644
--- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
+++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
@@ -455,7 +455,7 @@ static bool emitEnumDecls(const RecordKeeper &records, raw_ostream &os) {
   llvm::emitSourceFileHeader("SPIR-V Enum Availability Declarations", os,
                              records);
 
-  auto defs = records.getAllDerivedDefinitions("EnumAttrInfo");
+  auto defs = records.getAllDerivedDefinitions("EnumInfo");
   for (const auto *def : defs)
     emitEnumDecl(*def, os);
 
@@ -487,7 +487,7 @@ static bool emitEnumDefs(const RecordKeeper &records, raw_ostream &os) {
   llvm::emitSourceFileHeader("SPIR-V Enum Availability Definitions", os,
                              records);
 
-  auto defs = records.getAllDerivedDefinitions("EnumAttrInfo");
+  auto defs = records.getAllDerivedDefinitions("EnumInfo");
   for (const auto *def : defs)
     emitEnumDef(*def, os);
 
@@ -1262,7 +1262,7 @@ static void emitEnumGetAttrNameFnDefn(const EnumInfo &enumInfo,
 static bool emitAttrUtils(const RecordKeeper &records, raw_ostream &os) {
   llvm::emitSourceFileHeader("SPIR-V Attribute Utilities", os, records);
 
-  auto defs = records.getAllDerivedDefinitions("EnumAttrInfo");
+  auto defs = records.getAllDerivedDefinitions("EnumInfo");
   os << "#ifndef MLIR_DIALECT_SPIRV_IR_ATTR_UTILS_H_\n";
   os << "#define MLIR_DIALECT_SPIRV_IR_ATTR_UTILS_H_\n";
   emitEnumGetAttrNameFnDecl(os);
diff --git a/mlir/utils/spirv/gen_spirv_dialect.py b/mlir/utils/spirv/gen_spirv_dialect.py
index 99ed3489b4cbd..d2d0b410f52df 100755
--- a/mlir/utils/spirv/gen_spirv_dialect.py
+++ b/mlir/utils/spirv/gen_spirv_dialect.py
@@ -288,11 +288,11 @@ def get_availability_spec(enum_case, for_op, for_cap):
 
 
 def gen_operand_kind_enum_attr(operand_kind):
-    """Generates the TableGen EnumAttr definition for the given operand kind.
+    """Generates the TableGen EnumInfo definition for the given operand kind.
 
     Returns:
       - The operand kind's name
-      - A string containing the TableGen EnumAttr definition
+      - A string containing the TableGen EnumInfo definition
     """
     if "enumerants" not in operand_kind:
         return "", ""



More information about the Mlir-commits mailing list