[Mlir-commits] [mlir] [mlir] Decouple enum generation from attributes, adding EnumInfo and EnumCase (PR #132148)

Krzysztof Drewniak llvmlistbot at llvm.org
Wed Mar 26 22:19:34 PDT 2025


https://github.com/krzysz00 updated https://github.com/llvm/llvm-project/pull/132148

>From 57a1e0841847694a14462c4bb16f748176a386e4 Mon Sep 17 00:00:00 2001
From: Krzysztof Drewniak <krzysdrewniak at gmail.com>
Date: Thu, 20 Mar 2025 23:06:17 -0700
Subject: [PATCH] [mlir] Decouple enum generation from attributes, adding
 EnumInfo and EnumCase

This commit pulls apart the inherent attribute dependence of classes
like EnumAttrInfo and EnumAttrCase, factoring them out into simpler
EnumCase and EnumInfo variants. This allows specifying the cases of an
enum without needing to make the cases, or the EnumInfo itself, a
subclass of SignlessIntegerAttrBase.

The existing classes are retained as subclasses of the new ones, both
for backwards compatibility and to allow attribute-specific
information.

In addition, the new BitEnum class changes its default printer/parser
behavior: values made up of multiple keywords, like having both nuw
and nsw in overflow flags, will no longer be quoted by operator<<,
and the FieldParser instance will now expect multiple keywords. All
instances of BitEnumAttr retain the old behavior.
---
 mlir/docs/DefiningDialects/Operations.md      |  66 +++--
 mlir/include/mlir/IR/EnumAttr.td              | 271 +++++++++++++-----
 mlir/include/mlir/TableGen/EnumInfo.h         |   4 +
 mlir/lib/TableGen/EnumInfo.cpp                |  16 +-
 mlir/lib/TableGen/Pattern.cpp                 |   2 +-
 mlir/test/Dialect/LLVMIR/func.mlir            |   2 +-
 mlir/test/IR/attribute.mlir                   |   2 +-
 mlir/test/lib/Dialect/Test/TestEnumDefs.td    |  25 +-
 mlir/test/mlir-tblgen/enums-gen.td            |  43 ++-
 .../mlir-tblgen/EnumPythonBindingGen.cpp      |  20 +-
 mlir/tools/mlir-tblgen/EnumsGen.cpp           | 124 +++++++-
 mlir/tools/mlir-tblgen/OpDocGen.cpp           |   4 +-
 mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp      |   6 +-
 mlir/utils/spirv/gen_spirv_dialect.py         |   4 +-
 14 files changed, 431 insertions(+), 158 deletions(-)
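The commit message says that `BitEnum` now prints multi-keyword values without quotes, while `BitEnumAttr` (which sets `printBitEnumQuoted = 1`) keeps the old quoting. The following is a minimal, hand-written C++ sketch of that difference. It is not the code EnumsGen emits. `MyBitEnum` mirrors the example in Operations.md, and `stringifyMyBitEnum` and `printMyBitEnum` are hypothetical stand-ins for the generated stringify function and `operator<<`. The sketch assumes the default `"|"` separator.

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <utility>

// Hypothetical stand-in for a generated bit enum, mirroring MyBitEnum in
// Operations.md (Bit0 has the string form "tagged").
enum class MyBitEnum : uint32_t {
  None = 0,
  Bit0 = 1u << 0,
  Bit1 = 1u << 1,
  Bit2 = 1u << 2,
  Bit3 = 1u << 3,
};

// Joins the string forms of all set bits with `separator`, in declaration
// order. This only approximates the generated stringify function.
std::string stringifyMyBitEnum(MyBitEnum symbol,
                               const std::string &separator = "|") {
  uint32_t val = static_cast<uint32_t>(symbol);
  assert((val & ~0xFu) == 0 && "invalid bits set in bit enum");
  if (val == 0)
    return "None";
  static const std::pair<uint32_t, const char *> cases[] = {
      {1u << 0, "tagged"}, {1u << 1, "Bit1"}, {1u << 2, "Bit2"}, {1u << 3, "Bit3"}};
  std::string result;
  for (const auto &[bit, name] : cases) {
    if ((val & bit) == bit) {
      if (!result.empty())
        result += separator;
      result += name;
    }
  }
  return result;
}

// Models the printBitEnumQuoted switch: when it is set (the BitEnumAttr
// behavior), a value that names more than one case is wrapped in quotes;
// otherwise the keywords are printed bare.
std::string printMyBitEnum(MyBitEnum symbol, bool printQuoted,
                           const std::string &separator = "|") {
  std::string str = stringifyMyBitEnum(symbol, separator);
  bool multipleKeywords = str.find(separator) != std::string::npos;
  return (printQuoted && multipleKeywords) ? "\"" + str + "\"" : str;
}
```

For example, `printMyBitEnum(static_cast<MyBitEnum>(0x3), /*printQuoted=*/true)` yields `"tagged|Bit1"` including the quotes. Passing `false` yields `tagged|Bit1`. A single-keyword value such as `Bit2` is never quoted.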

diff --git a/mlir/docs/DefiningDialects/Operations.md b/mlir/docs/DefiningDialects/Operations.md
index 528070cd3ebff..fafda816a3881 100644
--- a/mlir/docs/DefiningDialects/Operations.md
+++ b/mlir/docs/DefiningDialects/Operations.md
@@ -1498,22 +1498,17 @@ optionality, default values, etc.:
 *   `AllAttrOf`: adapts an attribute with
     [multiple constraints](#combining-constraints).
 
-### Enum attributes
+## Enum definition
 
-Some attributes can only take values from a predefined enum, e.g., the
-comparison kind of a comparison op. To define such attributes, ODS provides
-several mechanisms: `IntEnumAttr`, and `BitEnumAttr`.
+MLIR is capable of generating C++ enums, both those that represent a single
+value drawn from a list and those that can hold a combination of flags,
+using the `IntEnum` and `BitEnum` classes, respectively.
 
-*   `IntEnumAttr`: each enum case is an integer, the attribute is stored as a
-    [`IntegerAttr`][IntegerAttr] in the op.
-*   `BitEnumAttr`: each enum case is a either the empty case, a single bit,
-    or a group of single bits, and the attribute is stored as a
-    [`IntegerAttr`][IntegerAttr] in the op.
-
-All these `*EnumAttr` attributes require fully specifying all of the allowed
-cases via their corresponding `*EnumAttrCase`. With this, ODS is able to
+Both the `IntEnum` and `BitEnum` classes require fully specifying all of the allowed
+cases via an `EnumCase` or `BitEnumCaseBase` subclass, respectively. With this, ODS is able to
 generate additional verification to only accept allowed cases. To facilitate the
-interaction between `*EnumAttr`s and their C++ consumers, the
+interaction between tablegen enums and the attributes or properties that wrap them and
+to make them easier to use in C++, the
 [`EnumsGen`][EnumsGen] TableGen backend can generate a few common utilities: a
 C++ enum class, `llvm::DenseMapInfo` for the enum class, conversion functions
 from/to strings. This is controlled via the `-gen-enum-decls` and
@@ -1522,10 +1517,10 @@ from/to strings. This is controlled via the `-gen-enum-decls` and
 For example, given the following `EnumAttr`:
 
 ```tablegen
-def Case15: I32EnumAttrCase<"Case15", 15>;
-def Case20: I32EnumAttrCase<"Case20", 20>;
+def Case15: I32EnumCase<"Case15", 15>;
+def Case20: I32EnumCase<"Case20", 20>;
 
-def MyIntEnum: I32EnumAttr<"MyIntEnum", "An example int enum",
+def MyIntEnum: I32Enum<"MyIntEnum", "An example int enum",
                            [Case15, Case20]> {
   let cppNamespace = "Outer::Inner";
   let stringToSymbolFnName = "ConvertToEnum";
@@ -1611,14 +1606,17 @@ std::optional<MyIntEnum> symbolizeMyIntEnum(uint32_t value) {
 Similarly for the following `BitEnumAttr` definition:
 
 ```tablegen
-def None: I32BitEnumAttrCaseNone<"None">;
-def Bit0: I32BitEnumAttrCaseBit<"Bit0", 0, "tagged">;
-def Bit1: I32BitEnumAttrCaseBit<"Bit1", 1>;
-def Bit2: I32BitEnumAttrCaseBit<"Bit2", 2>;
-def Bit3: I32BitEnumAttrCaseBit<"Bit3", 3>;
-
-def MyBitEnum: BitEnumAttr<"MyBitEnum", "An example bit enum",
-                           [None, Bit0, Bit1, Bit2, Bit3]>;
+def None: I32BitEnumCaseNone<"None">;
+def Bit0: I32BitEnumCaseBit<"Bit0", 0, "tagged">;
+def Bit1: I32BitEnumCaseBit<"Bit1", 1>;
+def Bit2: I32BitEnumCaseBit<"Bit2", 2>;
+def Bit3: I32BitEnumCaseBit<"Bit3", 3>;
+
+def MyBitEnum: I32BitEnum<"MyBitEnum", "An example bit enum",
+                           [None, Bit0, Bit1, Bit2, Bit3]> {
+  // Note: this is the default value, and is listed for illustrative purposes.
+  let separator = "|";
+}
 ```
 
 We can have:
@@ -1738,6 +1736,26 @@ std::optional<MyBitEnum> symbolizeMyBitEnum(uint32_t value) {
 }
 ```
 
+### Wrapping enums in attributes
+
+There are several mechanisms for creating an `Attribute` whose values are
+taken from a `*Enum`.
+
+The most common of these is to use the `EnumAttr` class, which takes
+an `EnumInfo` (either an `IntEnum` or a `BitEnum`) as a parameter and constructs
+an attribute that holds one parameter: a value of the enum. This attribute
+is defined within a dialect and can have its assembly format customized to,
+for example, print angle brackets around the enum value or assign a mnemonic.
+
+An older form involves using the `IntEnumAttr` and `BitEnumAttr` classes (and
+their width-specific variants, such as `I32EnumAttr`) with their corresponding
+`*EnumAttrCase` classes (which can be used anywhere a `*EnumCase` is needed).
+These classes store their values as a `SignlessIntegerAttr` of their bitwidth,
+constraining the stored integer to the enum's valid values. If their
+`genSpecializedAttr` field is set, they will also generate a
+wrapper attribute instead of using a bare signless integer attribute
+for storage.
+
 ## Debugging Tips
 
 ### Run `mlir-tblgen` to see the generated content
diff --git a/mlir/include/mlir/IR/EnumAttr.td b/mlir/include/mlir/IR/EnumAttr.td
index 9fec28f03ec28..e5406546b1950 100644
--- a/mlir/include/mlir/IR/EnumAttr.td
+++ b/mlir/include/mlir/IR/EnumAttr.td
@@ -14,8 +14,8 @@ include "mlir/IR/AttrTypeBase.td"
 //===----------------------------------------------------------------------===//
 // Enum attribute kinds
 
-// Additional information for an enum attribute case.
-class EnumAttrCaseInfo<string sym, int intVal, string strVal> {
+// Additional information for an enum case.
+class EnumCase<string sym, int intVal, string strVal, int widthVal> {
   // The C++ enumerant symbol.
   string symbol = sym;
 
@@ -26,29 +26,56 @@ class EnumAttrCaseInfo<string sym, int intVal, string strVal> {
 
   // The string representation of the enumerant. May be the same as symbol.
   string str = strVal;
+
+  // The bitwidth of the enum.
+  int width = widthVal;
 }
 
 // An enum attribute case stored with IntegerAttr, which has an integer value,
 // its representation as a string and a C++ symbol name which may be different.
+// Not needed when using the newer `EnumCase` form for defining enum cases.
 class IntEnumAttrCaseBase<I intType, string sym, string strVal, int intVal> :
-    EnumAttrCaseInfo<sym, intVal, strVal>,
+    EnumCase<sym, intVal, strVal, intType.bitwidth>,
     SignlessIntegerAttrBase<intType, "case " # strVal> {
   let predicate =
     CPred<"::llvm::cast<::mlir::IntegerAttr>($_self).getInt() == " # intVal>;
 }
 
-// Cases of integer enum attributes with a specific type. By default, the string
+// Cases of integer enums with a specific type. By default, the string
 // representation is the same as the C++ symbol name.
+class I32EnumCase<string sym, int val, string str = sym>
+  : EnumCase<sym, val, str, 32>;
+class I64EnumCase<string sym, int val, string str = sym>
+  : EnumCase<sym, val, str, 64>;
+
+// Cases of integer enum attributes with a specific type. By default, the string
+// representation is the same as the C++ symbol name. These forms
+// are not needed when using the newer `EnumCase` form.
 class I32EnumAttrCase<string sym, int val, string str = sym>
     : IntEnumAttrCaseBase<I32, sym, str, val>;
 class I64EnumAttrCase<string sym, int val, string str = sym>
     : IntEnumAttrCaseBase<I64, sym, str, val>;
 
-// A bit enum case stored with an IntegerAttr. `val` here is *not* the ordinal
-// number of a bit that is set. It is an integer value with bits set to match
-// the case.
+// A bit enum case. `val` here is *not* the ordinal number of a bit
+// that is set. It is an integer value with bits set to match the case.
+class BitEnumCaseBase<string sym, int val, string str, int width> :
+    EnumCase<sym, val, str, width>;
+// Bit enum cases. The string representation is the same as the C++ symbol
+// name unless otherwise specified.
+class I8BitEnumCase<string sym, int val, string str = sym>
+  : BitEnumCaseBase<sym, val, str, 8>;
+class I16BitEnumCase<string sym, int val, string str = sym>
+  : BitEnumCaseBase<sym, val, str, 16>;
+class I32BitEnumCase<string sym, int val, string str = sym>
+  : BitEnumCaseBase<sym, val, str, 32>;
+class I64BitEnumCase<string sym, int val, string str = sym>
+  : BitEnumCaseBase<sym, val, str, 64>;
+
+// A form of `BitEnumCaseBase` that also inherits from `Attr` and derives the
+// enum's width from an integer type. It dates from when enums were always
+// stored in attributes.
 class BitEnumAttrCaseBase<I intType, string sym, int val, string str = sym> :
-    EnumAttrCaseInfo<sym, val, str>,
+    BitEnumCaseBase<sym, val, str, intType.bitwidth>,
     SignlessIntegerAttrBase<intType, "case " #str>;
 
 class I8BitEnumAttrCase<string sym, int val, string str = sym>
@@ -61,6 +88,19 @@ class I64BitEnumAttrCase<string sym, int val, string str = sym>
     : BitEnumAttrCaseBase<I64, sym, val, str>;
 
 // The special bit enum case with no bits set (i.e. value = 0).
+class BitEnumCaseNone<string sym, string str, int width>
+    : BitEnumCaseBase<sym, 0, str, width>;
+
+class I8BitEnumCaseNone<string sym, string str = sym>
+  : BitEnumCaseNone<sym, str, 8>;
+class I16BitEnumCaseNone<string sym, string str = sym>
+  : BitEnumCaseNone<sym, str, 16>;
+class I32BitEnumCaseNone<string sym, string str = sym>
+  : BitEnumCaseNone<sym, str, 32>;
+class I64BitEnumCaseNone<string sym, string str = sym>
+  : BitEnumCaseNone<sym, str, 64>;
+
+// Older forms, used when enums were necessarily attributes.
 class I8BitEnumAttrCaseNone<string sym, string str = sym>
     : I8BitEnumAttrCase<sym, 0, str>;
 class I16BitEnumAttrCaseNone<string sym, string str = sym>
@@ -70,6 +110,24 @@ class I32BitEnumAttrCaseNone<string sym, string str = sym>
 class I64BitEnumAttrCaseNone<string sym, string str = sym>
     : I64BitEnumAttrCase<sym, 0, str>;
 
+// A bit enum case for a single bit, specified by a bit position `pos`.
+// The `pos` argument refers to the index of the bit, and is limited
+// to be in the range [0, width).
+class BitEnumCaseBit<string sym, int pos, string str, int width>
+    : BitEnumCaseBase<sym, !shl(1, pos), str, width> {
+  assert !and(!ge(pos, 0), !lt(pos, width)),
+      "bit position larger than underlying storage";
+}
+
+class I8BitEnumCaseBit<string sym, int pos, string str = sym>
+    : BitEnumCaseBit<sym, pos, str, 8>;
+class I16BitEnumCaseBit<string sym, int pos, string str = sym>
+    : BitEnumCaseBit<sym, pos, str, 16>;
+class I32BitEnumCaseBit<string sym, int pos, string str = sym>
+    : BitEnumCaseBit<sym, pos, str, 32>;
+class I64BitEnumCaseBit<string sym, int pos, string str = sym>
+    : BitEnumCaseBit<sym, pos, str, 64>;
+
 // A bit enum case for a single bit, specified by a bit position.
 // The pos argument refers to the index of the bit, and is limited
 // to be in the range [0, bitwidth).
@@ -90,12 +148,17 @@ class I64BitEnumAttrCaseBit<string sym, int pos, string str = sym>
 
 // A bit enum case for a group/list of previously declared cases, providing
 // a convenient alias for that group.
+class BitEnumCaseGroup<string sym, list<BitEnumCaseBase> cases, string str = sym>
+    : BitEnumCaseBase<sym,
+      !foldl(0, cases, value, bitcase, !or(value, bitcase.value)),
+      str, !head(cases).width>;
+
+// The attribute-only form of `BitEnumCaseGroup`.
 class BitEnumAttrCaseGroup<I intType, string sym,
-                           list<BitEnumAttrCaseBase> cases, string str = sym>
+                           list<BitEnumCaseBase> cases, string str = sym>
     : BitEnumAttrCaseBase<intType, sym,
           !foldl(0, cases, value, bitcase, !or(value, bitcase.value)),
           str>;
-
 class I8BitEnumAttrCaseGroup<string sym, list<BitEnumAttrCaseBase> cases,
                               string str = sym>
     : BitEnumAttrCaseGroup<I8, sym, cases, str>;
@@ -109,29 +172,36 @@ class I64BitEnumAttrCaseGroup<string sym, list<BitEnumAttrCaseBase> cases,
                               string str = sym>
     : BitEnumAttrCaseGroup<I64, sym, cases, str>;
 
-// Additional information for an enum attribute.
-class EnumAttrInfo<
-    string name, list<EnumAttrCaseInfo> cases, Attr baseClass> :
-      Attr<baseClass.predicate, baseClass.summary> {
-
+// Information describing an enum and the functions that should be generated for it.
+class EnumInfo<string name, string summaryValue, list<EnumCase> cases, int width> {
+  string summary = summaryValue;
   // Generate a description of this enums members for the MLIR docs.
-  let description =
+  string description =
         "Enum cases:\n" # !interleave(
           !foreach(case, cases,
               "* " # case.str  # " (`" # case.symbol # "`)"), "\n");
 
+  // The C++ namespace for this enum
+  string cppNamespace = "";
+
   // The C++ enum class name
   string className = name;
 
+  // The namespace-qualified C++ type of the enum
+  string cppType = cppNamespace # "::" # className;
+
   // List of all accepted cases
-  list<EnumAttrCaseInfo> enumerants = cases;
+  list<EnumCase> enumerants = cases;
 
   // The following fields are only used by the EnumsGen backend to generate
   // an enum class definition and conversion utility functions.
 
+  // The bitwidth underlying the class
+  int bitwidth = width;
+
   // The underlying type for the C++ enum class. An empty string mean the
   // underlying type is not explicitly specified.
-  string underlyingType = "";
+  string underlyingType = "uint" # width # "_t";
 
   // The name of the utility function that converts a value of the underlying
   // type to the corresponding symbol. It will have the following signature:
@@ -165,6 +235,15 @@ class EnumAttrInfo<
   // static constexpr unsigned <fn-name>();
   // ```
   string maxEnumValFnName = "getMaxEnumValFor" # name;
+}
+
+// A wrapper around `EnumInfo` that also makes the enum an attribute
+// if `genSpecializedAttr` is 1 (though `EnumAttr` is the preferred means
+// to accomplish this), or declares that the enum will be stored in an attribute.
+class EnumAttrInfo<
+    string name, list<EnumCase> cases, SignlessIntegerAttrBase baseClass> :
+      EnumInfo<name, baseClass.summary, cases, !cast<I>(baseClass.valueType).bitwidth>,
+      Attr<baseClass.predicate, baseClass.summary> {
 
   // Generate specialized Attribute class
   bit genSpecializedAttr = 1;
@@ -188,15 +267,25 @@ class EnumAttrInfo<
     baseAttrClass.constBuilderCall);
   let valueType = baseAttrClass.valueType;
 
-  // C++ type wrapped by attribute
-  string cppType = cppNamespace # "::" # className;
-
   // Parser and printer code used by the EnumParameter class, to be provided by
   // derived classes
   string parameterParser = ?;
   string parameterPrinter = ?;
 }
 
+// An enum that holds exactly one of its integer-valued cases.
+class IntEnum<string name, string summary, list<EnumCase> cases, int width>
+    : EnumInfo<name,
+      !if(!empty(summary), "allowed i" # width # " cases: " #
+          !interleave(!foreach(case, cases, case.value), ", "),
+          summary),
+      cases, width>;
+
+class I32Enum<string name, string summary, list<EnumCase> cases>
+    : IntEnum<name, summary, cases, 32>;
+class I64Enum<string name, string summary, list<EnumCase> cases>
+    : IntEnum<name, summary, cases, 64>;
+
 // An enum attribute backed by IntegerAttr.
 //
 // Op attributes of this kind are stored as IntegerAttr. Extra verification will
@@ -245,13 +334,73 @@ class I64EnumAttr<string name, string summary, list<I64EnumAttrCase> cases> :
   let underlyingType = "uint64_t";
 }
 
+// The base mixin for bit enums that are stored as an integer.
+// This is used by both BitEnum and BitEnumAttr, which need to have a set of
+// extra properties that bit enums have which normal enums don't. However,
+// we can't just use BitEnum as a base class of BitEnumAttr, since BitEnumAttr
+// also inherits from EnumAttrInfo, causing double inheritance of EnumInfo.
+class BitEnumBase<list<BitEnumCaseBase> cases> {
+  // Determine "valid" bits from enum cases for error checking
+  int validBits = !foldl(0, cases, value, bitcase, !or(value, bitcase.value));
+
+  // The delimiter used to separate bit enum cases in strings. Only "|" and
+  // "," (along with optional spaces) are supported due to the use of the
+  // parseSeparatorFn in parameterParser below.
+  // Spaces in the separator string are used for printing, but will be optional
+  // for parsing.
+  string separator = "|";
+  assert !or(!ge(!find(separator, "|"), 0), !ge(!find(separator, ","), 0)),
+      "separator must contain '|' or ',' for parameter parsing";
+
+  // Print the "primary group" only for bits that are members of case groups
+  // that have all bits present. When the value is 0, printing will display
+  // both individual bit case names AND the names for all groups that the bit is
+  // contained in. When the value is 1, for each bit that is set AND is a member
+  // of a group with all bits set, only the "primary group" (i.e. the first
+  // group with all bits set in reverse declaration order) will be printed (for
+  // conciseness).
+  bit printBitEnumPrimaryGroups = 0;
+
+  // 1 if the operator<< for this enum should put quotes around values with
+  // multiple entries. Off by default in the general case but on for BitEnumAttrs
+  // since that was the original behavior.
+  bit printBitEnumQuoted = 0;
+}
+
+// A bit enum stored as an integer.
+//
+// Enums of this kind are stored as an integer. Attributes or properties deriving
+// from this enum will have additional verification generated on them to make sure
+// only allowed bits are set. Helper methods are generated to parse a string of enum
+// values separated by the specified separator to a symbol and vice versa.
+class BitEnum<string name, string summary, list<BitEnumCaseBase> cases, int width>
+    : EnumInfo<name, summary, cases, width>, BitEnumBase<cases> {
+  // We need to return a string because we may concatenate symbols for multiple
+  // bits together.
+  let symbolToStringFnRetType = "std::string";
+}
+
+class I8BitEnum<string name, string summary,
+                     list<BitEnumCaseBase> cases>
+    : BitEnum<name, summary, cases, 8>;
+class I16BitEnum<string name, string summary,
+                     list<BitEnumCaseBase> cases>
+    : BitEnum<name, summary, cases, 16>;
+class I32BitEnum<string name, string summary,
+                     list<BitEnumCaseBase> cases>
+    : BitEnum<name, summary, cases, 32>;
+
+class I64BitEnum<string name, string summary,
+                     list<BitEnumCaseBase> cases>
+    : BitEnum<name, summary, cases, 64>;
+
 // A bit enum stored with an IntegerAttr.
 //
 // Op attributes of this kind are stored as IntegerAttr. Extra verification will
 // be generated on the integer to make sure only allowed bits are set. Besides,
 // helper methods are generated to parse a string separated with a specified
 // delimiter to a symbol and vice versa.
-class BitEnumAttrBase<I intType, list<BitEnumAttrCaseBase> cases,
+class BitEnumAttrBase<I intType, list<BitEnumCaseBase> cases,
                       string summary>
     : SignlessIntegerAttrBase<intType, summary> {
   let predicate = And<[
@@ -264,24 +413,13 @@ class BitEnumAttrBase<I intType, list<BitEnumAttrCaseBase> cases,
 }
 
 class BitEnumAttr<I intType, string name, string summary,
-                  list<BitEnumAttrCaseBase> cases>
-    : EnumAttrInfo<name, cases, BitEnumAttrBase<intType, cases, summary>> {
-  // Determine "valid" bits from enum cases for error checking
-  int validBits = !foldl(0, cases, value, bitcase, !or(value, bitcase.value));
-
+                  list<BitEnumCaseBase> cases>
+    : EnumAttrInfo<name, cases, BitEnumAttrBase<intType, cases, summary>>,
+      BitEnumBase<cases> {
   // We need to return a string because we may concatenate symbols for multiple
   // bits together.
   let symbolToStringFnRetType = "std::string";
 
-  // The delimiter used to separate bit enum cases in strings. Only "|" and
-  // "," (along with optional spaces) are supported due to the use of the
-  // parseSeparatorFn in parameterParser below.
-  // Spaces in the separator string are used for printing, but will be optional
-  // for parsing.
-  string separator = "|";
-  assert !or(!ge(!find(separator, "|"), 0), !ge(!find(separator, ","), 0)),
-      "separator must contain '|' or ',' for parameter parsing";
-
   // Parsing function that corresponds to the enum separator. Only
   // "," and "|" are supported by this definition.
   string parseSeparatorFn = !if(!ge(!find(separator, "|"), 0),
@@ -312,36 +450,30 @@ class BitEnumAttr<I intType, string name, string summary,
   // Print the enum by calling `symbolToString`.
   let parameterPrinter = "$_printer << " # symbolToStringFnName # "($_self)";
 
-  // Print the "primary group" only for bits that are members of case groups
-  // that have all bits present. When the value is 0, printing will display both
-  // both individual bit case names AND the names for all groups that the bit is
-  // contained in. When the value is 1, for each bit that is set AND is a member
-  // of a group with all bits set, only the "primary group" (i.e. the first
-  // group with all bits set in reverse declaration order) will be printed (for
-  // conciseness).
-  bit printBitEnumPrimaryGroups = 0;
+  // Use old-style operator<< and FieldParser for compatibility
+  let printBitEnumQuoted = 1;
 }
 
 class I8BitEnumAttr<string name, string summary,
-                     list<BitEnumAttrCaseBase> cases>
+                     list<BitEnumCaseBase> cases>
     : BitEnumAttr<I8, name, summary, cases> {
   let underlyingType = "uint8_t";
 }
 
 class I16BitEnumAttr<string name, string summary,
-                     list<BitEnumAttrCaseBase> cases>
+                     list<BitEnumCaseBase> cases>
     : BitEnumAttr<I16, name, summary, cases> {
   let underlyingType = "uint16_t";
 }
 
 class I32BitEnumAttr<string name, string summary,
-                     list<BitEnumAttrCaseBase> cases>
+                     list<BitEnumCaseBase> cases>
     : BitEnumAttr<I32, name, summary, cases> {
   let underlyingType = "uint32_t";
 }
 
 class I64BitEnumAttr<string name, string summary,
-                     list<BitEnumAttrCaseBase> cases>
+                     list<BitEnumCaseBase> cases>
     : BitEnumAttr<I64, name, summary, cases> {
   let underlyingType = "uint64_t";
 }
@@ -349,11 +481,13 @@ class I64BitEnumAttr<string name, string summary,
 // A C++ enum as an attribute parameter. The parameter implements a parser and
 // printer for the enum by dispatching calls to `stringToSymbol` and
 // `symbolToString`.
-class EnumParameter<EnumAttrInfo enumInfo>
+class EnumParameter<EnumInfo enumInfo>
     : AttrParameter<enumInfo.cppNamespace # "::" # enumInfo.className,
                     "an enum of type " # enumInfo.className> {
-  let parser = enumInfo.parameterParser;
-  let printer = enumInfo.parameterPrinter;
+  let parser = !if(!isa<EnumAttrInfo>(enumInfo),
+    !cast<EnumAttrInfo>(enumInfo).parameterParser, ?);
+  let printer = !if(!isa<EnumAttrInfo>(enumInfo),
+    !cast<EnumAttrInfo>(enumInfo).parameterPrinter, ?);
 }
 
 // An attribute backed by a C++ enum. The attribute contains a single
@@ -384,14 +518,14 @@ class EnumParameter<EnumAttrInfo enumInfo>
 // The op will appear in the IR as `my_dialect.my_op first`. However, the
 // generic format of the attribute will be `#my_dialect<"enum first">`. Override
 // the attribute's assembly format as required.
-class EnumAttr<Dialect dialect, EnumAttrInfo enumInfo, string name = "",
+class EnumAttr<Dialect dialect, EnumInfo enumInfo, string name = "",
                list <Trait> traits = []>
     : AttrDef<dialect, enumInfo.className, traits> {
   let summary = enumInfo.summary;
   let description = enumInfo.description;
 
   // The backing enumeration.
-  EnumAttrInfo enum = enumInfo;
+  EnumInfo enum = enumInfo;
 
   // Inherit the C++ namespace from the enum.
   let cppNamespace = enumInfo.cppNamespace;
@@ -417,41 +551,42 @@ class EnumAttr<Dialect dialect, EnumAttrInfo enumInfo, string name = "",
   let assemblyFormat = "$value";
 }
 
-class _symbolToValue<EnumAttrInfo enumAttrInfo, string case> {
+class _symbolToValue<EnumInfo enumInfo, string case> {
   defvar cases =
-    !filter(iter, enumAttrInfo.enumerants, !eq(iter.str, case));
+    !filter(iter, enumInfo.enumerants, !eq(iter.str, case));
 
   assert !not(!empty(cases)), "failed to find enum-case '" # case # "'";
 
   // `!empty` check to not cause an error if the cases are empty.
   // The assertion catches the issue later and emits a proper error message.
-  string value = enumAttrInfo.cppType # "::"
+  string value = enumInfo.cppType # "::"
     # !if(!empty(cases), "", !head(cases).symbol);
 }
 
-class _bitSymbolsToValue<BitEnumAttr bitEnumAttr, string case> {
+class _bitSymbolsToValue<EnumInfo bitEnum, string case> {
+  assert !isa<BitEnumBase>(bitEnum), "_bitSymbolsToValue not given a bit enum";
   defvar pos = !find(case, "|");
 
   // Recursive instantiation looking up the symbol before the `|` in
   // enum cases.
   string value = !if(
-    !eq(pos, -1), /*baseCase=*/_symbolToValue<bitEnumAttr, case>.value,
-    /*rec=*/_symbolToValue<bitEnumAttr, !substr(case, 0, pos)>.value # "|"
-    # _bitSymbolsToValue<bitEnumAttr, !substr(case, !add(pos, 1))>.value
+    !eq(pos, -1), /*baseCase=*/_symbolToValue<bitEnum, case>.value,
+    /*rec=*/_symbolToValue<bitEnum, !substr(case, 0, pos)>.value # "|"
+    # _bitSymbolsToValue<bitEnum, !substr(case, !add(pos, 1))>.value
   );
 }
 
 class ConstantEnumCaseBase<Attr attribute,
-    EnumAttrInfo enumAttrInfo, string case>
+    EnumInfo enumInfo, string case>
   : ConstantAttr<attribute,
-  !if(!isa<BitEnumAttr>(enumAttrInfo),
-    _bitSymbolsToValue<!cast<BitEnumAttr>(enumAttrInfo), case>.value,
-    _symbolToValue<enumAttrInfo, case>.value
+  !if(!isa<BitEnumBase>(enumInfo),
+    _bitSymbolsToValue<enumInfo, case>.value,
+    _symbolToValue<enumInfo, case>.value
   )
 >;
 
 /// Attribute constraint matching a constant enum case. `attribute` should be
-/// one of `EnumAttrInfo` or `EnumAttr` and `symbol` the string representation
+/// one of `EnumInfo` or `EnumAttr` and `symbol` the string representation
 /// of an enum case. Multiple enum values of a bit-enum can be combined using
 /// `|` as a separator. Note that there mustn't be any whitespace around the
 /// separator.
@@ -463,10 +598,10 @@ class ConstantEnumCaseBase<Attr attribute,
 /// * ConstantEnumCase<Arith_CmpIPredicateAttr, "slt">
 class ConstantEnumCase<Attr attribute, string case>
   : ConstantEnumCaseBase<attribute,
-    !if(!isa<EnumAttrInfo>(attribute), !cast<EnumAttrInfo>(attribute),
+    !if(!isa<EnumInfo>(attribute), !cast<EnumInfo>(attribute),
           !cast<EnumAttr>(attribute).enum), case> {
-  assert !or(!isa<EnumAttr>(attribute), !isa<EnumAttrInfo>(attribute)),
-    "attribute must be one of 'EnumAttr' or 'EnumAttrInfo'";
+  assert !or(!isa<EnumAttr>(attribute), !isa<EnumInfo>(attribute)),
+    "attribute must be one of 'EnumAttr' or 'EnumInfo'";
 }
 
 #endif // ENUMATTR_TD
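The TableGen helpers above compute bit-case values in two ways. `BitEnumCaseBit` uses `!shl(1, pos)` and asserts that `pos` lies in `[0, width)`. `BitEnumCaseGroup` and `BitEnumBase.validBits` combine case values with `!foldl(... !or(...))`. A rough C++ analogue of that arithmetic follows, together with the kind of "only allowed bits are set" check that the generated verification performs. The function names are hypothetical, and nothing here is part of the patch.

```cpp
#include <cstdint>
#include <initializer_list>
#include <stdexcept>

// Analogue of BitEnumCaseBit<sym, pos, str, width>: the case value is
// 1 << pos, and pos must lie in [0, width). Widths above 64 are rejected
// because the result is a uint64_t.
constexpr uint64_t caseBit(unsigned pos, unsigned width) {
  if (width > 64 || pos >= width)
    throw std::out_of_range("bit position larger than underlying storage");
  return uint64_t{1} << pos;
}

// Analogue of !foldl(0, cases, value, bitcase, !or(value, bitcase.value)),
// which BitEnumCaseGroup uses for a group's value and BitEnumBase uses for
// validBits.
constexpr uint64_t foldOr(std::initializer_list<uint64_t> values) {
  uint64_t result = 0;
  for (uint64_t v : values)
    result |= v;
  return result;
}

// True when `value` sets no bits outside `validBits`.
constexpr bool hasOnlyValidBits(uint64_t value, uint64_t validBits) {
  return (value & ~validBits) == 0;
}

// Compile-time usage mirroring MyBitEnum from Operations.md: the zero-valued
// None case plus Bit0..Bit3 give validBits == 0xF.
constexpr uint64_t kValidBits =
    foldOr({0, caseBit(0, 32), caseBit(1, 32), caseBit(2, 32), caseBit(3, 32)});
static_assert(kValidBits == 0xF, "Bit0..Bit3 cover the low four bits");
static_assert(!hasOnlyValidBits(0x10, kValidBits), "bit 4 is not a case");
```

In TableGen, an out-of-range `pos` surfaces as an `assert` failure when the record is instantiated. The exception in this sketch only stands in for that diagnostic.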
diff --git a/mlir/include/mlir/TableGen/EnumInfo.h b/mlir/include/mlir/TableGen/EnumInfo.h
index 5bc7ffb6a8a35..aea76c01e7e7a 100644
--- a/mlir/include/mlir/TableGen/EnumInfo.h
+++ b/mlir/include/mlir/TableGen/EnumInfo.h
@@ -84,6 +84,9 @@ class EnumInfo {
   // Returns the description of the enum.
   StringRef getDescription() const;
 
+  // Returns the bitwidth of the enum.
+  int64_t getBitwidth() const;
+
   // Returns the underlying type.
   StringRef getUnderlyingType() const;
 
@@ -119,6 +122,7 @@ class EnumInfo {
   // Only applicable for bit enums.
 
   bool printBitEnumPrimaryGroups() const;
+  bool printBitEnumQuoted() const;
 
   // Returns the TableGen definition this EnumAttrCase was constructed from.
   const llvm::Record &getDef() const;
diff --git a/mlir/lib/TableGen/EnumInfo.cpp b/mlir/lib/TableGen/EnumInfo.cpp
index 9f491d30f0e7f..6128c53557cc4 100644
--- a/mlir/lib/TableGen/EnumInfo.cpp
+++ b/mlir/lib/TableGen/EnumInfo.cpp
@@ -18,8 +18,8 @@ using llvm::Init;
 using llvm::Record;
 
 EnumCase::EnumCase(const Record *record) : def(record) {
-  assert(def->isSubClassOf("EnumAttrCaseInfo") &&
-         "must be subclass of TableGen 'EnumAttrCaseInfo' class");
+  assert(def->isSubClassOf("EnumCase") &&
+         "must be subclass of TableGen 'EnumCase' class");
 }
 
 EnumCase::EnumCase(const DefInit *init) : EnumCase(init->getDef()) {}
@@ -35,8 +35,8 @@ int64_t EnumCase::getValue() const { return def->getValueAsInt("value"); }
 const Record &EnumCase::getDef() const { return *def; }
 
 EnumInfo::EnumInfo(const Record *record) : def(record) {
-  assert(isSubClassOf("EnumAttrInfo") &&
-         "must be subclass of TableGen 'EnumAttrInfo' class");
+  assert(isSubClassOf("EnumInfo") &&
+         "must be subclass of TableGen 'EnumInfo' class");
 }
 
 EnumInfo::EnumInfo(const Record &record) : EnumInfo(&record) {}
@@ -55,7 +55,7 @@ std::optional<Attribute> EnumInfo::asEnumAttr() const {
   return std::nullopt;
 }
 
-bool EnumInfo::isBitEnum() const { return isSubClassOf("BitEnumAttr"); }
+bool EnumInfo::isBitEnum() const { return isSubClassOf("BitEnumBase"); }
 
 StringRef EnumInfo::getEnumClassName() const {
   return def->getValueAsString("className");
@@ -73,6 +73,8 @@ StringRef EnumInfo::getCppNamespace() const {
   return def->getValueAsString("cppNamespace");
 }
 
+int64_t EnumInfo::getBitwidth() const { return def->getValueAsInt("bitwidth"); }
+
 StringRef EnumInfo::getUnderlyingType() const {
   return def->getValueAsString("underlyingType");
 }
@@ -127,4 +129,8 @@ bool EnumInfo::printBitEnumPrimaryGroups() const {
   return def->getValueAsBit("printBitEnumPrimaryGroups");
 }
 
+bool EnumInfo::printBitEnumQuoted() const {
+  return def->getValueAsBit("printBitEnumQuoted");
+}
+
 const Record &EnumInfo::getDef() const { return *def; }
diff --git a/mlir/lib/TableGen/Pattern.cpp b/mlir/lib/TableGen/Pattern.cpp
index 73e2803c21dae..d83df3e415c36 100644
--- a/mlir/lib/TableGen/Pattern.cpp
+++ b/mlir/lib/TableGen/Pattern.cpp
@@ -57,7 +57,7 @@ bool DagLeaf::isNativeCodeCall() const {
 
 bool DagLeaf::isConstantAttr() const { return isSubClassOf("ConstantAttr"); }
 
-bool DagLeaf::isEnumCase() const { return isSubClassOf("EnumAttrCaseInfo"); }
+bool DagLeaf::isEnumCase() const { return isSubClassOf("EnumCase"); }
 
 bool DagLeaf::isStringAttr() const { return isa<llvm::StringInit>(def); }
 
diff --git a/mlir/test/Dialect/LLVMIR/func.mlir b/mlir/test/Dialect/LLVMIR/func.mlir
index 74dd862ce8fb2..7caea3920255a 100644
--- a/mlir/test/Dialect/LLVMIR/func.mlir
+++ b/mlir/test/Dialect/LLVMIR/func.mlir
@@ -428,7 +428,7 @@ module {
 
 module {
   "llvm.func"() ({
-  // expected-error @below {{invalid Calling Conventions specification: cc_12}}
+  // expected-error @below {{expected one of [ccc, fastcc, coldcc, cc_10, cc_11, anyregcc, preserve_mostcc, preserve_allcc, swiftcc, cxx_fast_tlscc, tailcc, cfguard_checkcc, swifttailcc, x86_stdcallcc, x86_fastcallcc, arm_apcscc, arm_aapcscc, arm_aapcs_vfpcc, msp430_intrcc, x86_thiscallcc, ptx_kernelcc, ptx_devicecc, spir_funccc, spir_kernelcc, intel_ocl_bicc, x86_64_sysvcc, win64cc, x86_vectorcallcc, hhvmcc, hhvm_ccc, x86_intrcc, avr_intrcc, avr_builtincc, amdgpu_vscc, amdgpu_gscc, amdgpu_cscc, amdgpu_kernelcc, x86_regcallcc, amdgpu_hscc, msp430_builtincc, amdgpu_lscc, amdgpu_escc, aarch64_vectorcallcc, aarch64_sve_vectorcallcc, wasm_emscripten_invokecc, amdgpu_gfxcc, m68k_intrcc] for Calling Conventions, got: cc_12}}
   // expected-error @below {{failed to parse CConvAttr parameter 'CallingConv' which is to be a `CConv`}}
   }) {sym_name = "generic_unknown_calling_convention", CConv = #llvm.cconv<cc_12>, function_type = !llvm.func<i64 (i64, i64)>} : () -> ()
 }
diff --git a/mlir/test/IR/attribute.mlir b/mlir/test/IR/attribute.mlir
index 5a005a393d8ac..4f280bde1aecc 100644
--- a/mlir/test/IR/attribute.mlir
+++ b/mlir/test/IR/attribute.mlir
@@ -535,7 +535,7 @@ func.func @allowed_cases_pass() {
 // -----
 
 func.func @disallowed_case_sticky_fail() {
-  // expected-error at +2 {{expected test::TestBitEnum to be one of: read, write, execute}}
+  // expected-error at +2 {{expected one of [read, write, execute] for a test bit enum, got: sticky}}
   // expected-error at +1 {{failed to parse TestBitEnumAttr}}
   "test.op_with_bit_enum"() {value = #test.bit_enum<sticky>} : () -> ()
 }
diff --git a/mlir/test/lib/Dialect/Test/TestEnumDefs.td b/mlir/test/lib/Dialect/Test/TestEnumDefs.td
index 1ddfca0b22315..7441ea5a9726b 100644
--- a/mlir/test/lib/Dialect/Test/TestEnumDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestEnumDefs.td
@@ -42,11 +42,10 @@ def TestEnum
   let cppNamespace = "test";
 }
 
-def TestSimpleEnum : I32EnumAttr<"SimpleEnum", "", [
-    I32EnumAttrCase<"a", 0>,
-    I32EnumAttrCase<"b", 1>
+def TestSimpleEnum : I32Enum<"SimpleEnum", "", [
+    I32EnumCase<"a", 0>,
+    I32EnumCase<"b", 1>
   ]> {
-  let genSpecializedAttr = 0;
   let cppNamespace = "::test";
 }
 
@@ -56,24 +55,22 @@ def TestSimpleEnum : I32EnumAttr<"SimpleEnum", "", [
 
 // Define the C++ enum.
 def TestBitEnum
-    : I32BitEnumAttr<"TestBitEnum", "a test bit enum", [
-        I32BitEnumAttrCaseBit<"Read", 0, "read">,
-        I32BitEnumAttrCaseBit<"Write", 1, "write">,
-        I32BitEnumAttrCaseBit<"Execute", 2, "execute">,
+    : I32BitEnum<"TestBitEnum", "a test bit enum", [
+        I32BitEnumCaseBit<"Read", 0, "read">,
+        I32BitEnumCaseBit<"Write", 1, "write">,
+        I32BitEnumCaseBit<"Execute", 2, "execute">,
       ]> {
-  let genSpecializedAttr = 0;
   let cppNamespace = "test";
   let separator = ", ";
 }
 
 // Define an enum with a different separator
 def TestBitEnumVerticalBar
-    : I32BitEnumAttr<"TestBitEnumVerticalBar", "another test bit enum", [
-        I32BitEnumAttrCaseBit<"User", 0, "user">,
-        I32BitEnumAttrCaseBit<"Group", 1, "group">,
-        I32BitEnumAttrCaseBit<"Other", 2, "other">,
+    : I32BitEnum<"TestBitEnumVerticalBar", "another test bit enum", [
+        I32BitEnumCaseBit<"User", 0, "user">,
+        I32BitEnumCaseBit<"Group", 1, "group">,
+        I32BitEnumCaseBit<"Other", 2, "other">,
       ]> {
-  let genSpecializedAttr = 0;
   let cppNamespace = "test";
   let separator = " | ";
 }
diff --git a/mlir/test/mlir-tblgen/enums-gen.td b/mlir/test/mlir-tblgen/enums-gen.td
index c3a768e42236c..8489cff7c429d 100644
--- a/mlir/test/mlir-tblgen/enums-gen.td
+++ b/mlir/test/mlir-tblgen/enums-gen.td
@@ -5,12 +5,12 @@ include "mlir/IR/EnumAttr.td"
 include "mlir/IR/OpBase.td"
 
 // Test bit enums
-def None: I32BitEnumAttrCaseNone<"None", "none">;
-def Bit0: I32BitEnumAttrCaseBit<"Bit0", 0, "tagged">;
-def Bit1: I32BitEnumAttrCaseBit<"Bit1", 1>;
-def Bit2: I32BitEnumAttrCaseBit<"Bit2", 2>;
+def None: I32BitEnumCaseNone<"None", "none">;
+def Bit0: I32BitEnumCaseBit<"Bit0", 0, "tagged">;
+def Bit1: I32BitEnumCaseBit<"Bit1", 1>;
+def Bit2: I32BitEnumCaseBit<"Bit2", 2>;
 def Bit3: I32BitEnumAttrCaseBit<"Bit3", 3>;
-def BitGroup: I32BitEnumAttrCaseGroup<"BitGroup", [
+def BitGroup: BitEnumCaseGroup<"BitGroup", [
   Bit0, Bit1
 ]>;
 
@@ -42,7 +42,7 @@ def MyBitEnum: I32BitEnumAttr<"MyBitEnum", "An example bit enum",
 // DECL:     // Symbolize the keyword.
 // DECL:     if (::std::optional<::MyBitEnum> attr = ::symbolizeEnum<::MyBitEnum>(enumKeyword))
 // DECL:       return *attr;
-// DECL:     return parser.emitError(loc, "invalid An example bit enum specification: ") << enumKeyword;
+// DECL:     return parser.emitError(loc, "expected one of [none, tagged, Bit1, Bit2, Bit3, BitGroup] for An example bit enum, got: ") << enumKeyword;
 // DECL:   }
 
 // DECL: inline ::llvm::raw_ostream &operator<<(::llvm::raw_ostream &p, ::MyBitEnum value) {
@@ -73,7 +73,7 @@ def MyBitEnum: I32BitEnumAttr<"MyBitEnum", "An example bit enum",
 
 // Test enum printer generation for non non-keyword enums.
 
-def NonKeywordBit: I32BitEnumAttrCaseBit<"Bit0", 0, "tag-ged">;
+def NonKeywordBit: I32BitEnumCaseBit<"Bit0", 0, "tag-ged">;
 def MyMixedNonKeywordBitEnum: I32BitEnumAttr<"MyMixedNonKeywordBitEnum", "An example bit enum", [
     NonKeywordBit,
     Bit1
@@ -101,3 +101,32 @@ def MyNonKeywordBitEnum: I32BitEnumAttr<"MyNonKeywordBitEnum", "An example bit e
 // DECL: inline ::llvm::raw_ostream &operator<<(::llvm::raw_ostream &p, ::MyNonKeywordBitEnum value) {
 // DECL:   auto valueStr = stringifyEnum(value);
 // DECL:   return p << '"' << valueStr << '"';
+
+def MyNonQuotedPrintBitEnum
+  : I32BitEnum<"MyNonQuotedPrintBitEnum", "Example new-style bit enum",
+    [None, Bit0, Bit1, Bit2, Bit3, BitGroup]>;
+
+// DECL: struct FieldParser<::MyNonQuotedPrintBitEnum, ::MyNonQuotedPrintBitEnum> {
+// DECL:   template <typename ParserT>
+// DECL:   static FailureOr<::MyNonQuotedPrintBitEnum> parse(ParserT &parser) {
+// DECL:     ::MyNonQuotedPrintBitEnum flags = {};
+// DECL:     do {
+// DECL:     // Parse the keyword containing a part of the enum.
+// DECL:       ::llvm::StringRef enumKeyword;
+// DECL:       auto loc = parser.getCurrentLocation();
+// DECL:       if (failed(parser.parseOptionalKeyword(&enumKeyword))) {
+// DECL:         return parser.emitError(loc, "expected keyword for Example new-style bit enum");
+// DECL:       }
+// DECL:       // Symbolize the keyword.
+// DECL:       if (::std::optional<::MyNonQuotedPrintBitEnum> flag = ::symbolizeEnum<::MyNonQuotedPrintBitEnum>(enumKeyword))
+// DECL:         flags = flags | *flag;
+// DECL:       } else {
+// DECL:         return parser.emitError(loc, "expected one of [none, tagged, Bit1, Bit2, Bit3, BitGroup] for Example new-style bit enum, got: ") << enumKeyword;
+// DECL:       }
+// DECL:     } while (::mlir::succeeded(parser.parseOptionalVerticalBar()));
+// DECL:     return flags;
+// DECL:   }
+
+// DECL: inline ::llvm::raw_ostream &operator<<(::llvm::raw_ostream &p, ::MyNonQuotedPrintBitEnum value) {
+// DECL:   auto valueStr = stringifyEnum(value);
+// DECL-NEXT:   return p << valueStr;
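The generated FieldParser for an unquoted bit enum follows a simple loop. It parses one keyword, symbolizes it, ORs the resulting flag into an accumulator, and repeats while an optional separator is present. The sketch below is a minimal standalone version of that loop. It rests on these assumptions: a hypothetical `Perm` enum with `none`/`read`/`write`/`execute` cases, a tiny string `Cursor` in place of MLIR's `AsmParser`, and `std::nullopt` in place of `emitError`. It also treats keywords as runs of alphanumerics and `_`, which is a simplification of MLIR's lexer.

```cpp
#include <cassert>
#include <cctype>
#include <cstdint>
#include <optional>
#include <string_view>

// Hypothetical bit enum standing in for a generated one such as test::TestBitEnum.
enum class Perm : uint32_t { None = 0, Read = 1, Write = 2, Execute = 4 };

inline Perm operator|(Perm a, Perm b) {
  return static_cast<Perm>(static_cast<uint32_t>(a) | static_cast<uint32_t>(b));
}

// Stand-in for the generated symbolizeEnum<Perm>().
std::optional<Perm> symbolizePerm(std::string_view kw) {
  if (kw == "none") return Perm::None;
  if (kw == "read") return Perm::Read;
  if (kw == "write") return Perm::Write;
  if (kw == "execute") return Perm::Execute;
  return std::nullopt;
}

// Minimal cursor mimicking parseOptionalKeyword / parseOptionalVerticalBar.
struct Cursor {
  std::string_view rest;
  void skipSpace() {
    while (!rest.empty() && std::isspace(static_cast<unsigned char>(rest.front())))
      rest.remove_prefix(1);
  }
  bool parseOptionalKeyword(std::string_view &kw) {
    skipSpace();
    size_t n = 0;
    while (n < rest.size() &&
           (std::isalnum(static_cast<unsigned char>(rest[n])) || rest[n] == '_'))
      ++n;
    if (n == 0) return false;
    kw = rest.substr(0, n);
    rest.remove_prefix(n);
    return true;
  }
  bool parseOptionalVerticalBar() {
    skipSpace();
    if (rest.empty() || rest.front() != '|') return false;
    rest.remove_prefix(1);
    return true;
  }
};

// Parse `kw (| kw)*`, OR-ing each flag into the result; nullopt on error.
// Trailing input is left for the caller to check, as in the generated parser.
std::optional<Perm> parsePermFlags(std::string_view input) {
  Cursor c{input};
  Perm flags = Perm::None;
  do {
    std::string_view kw;
    if (!c.parseOptionalKeyword(kw))
      return std::nullopt; // generated code: "expected keyword for ..."
    std::optional<Perm> flag = symbolizePerm(kw);
    if (!flag)
      return std::nullopt; // generated code: "expected one of [...] for ..., got: <kw>"
    flags = flags | *flag;
  } while (c.parseOptionalVerticalBar());
  return flags;
}
```

Under these assumptions, `read | write` yields `Read | Write`. Both `sticky` and a dangling `read |` are rejected.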
diff --git a/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp
index 5d4d9e90fff67..8e2d6114e48eb 100644
--- a/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp
@@ -85,17 +85,6 @@ static void emitEnumClass(EnumInfo enumInfo, raw_ostream &os) {
   os << "\n";
 }
 
-/// Attempts to extract the bitwidth B from string "uintB_t" describing the
-/// type. This bitwidth information is not readily available in ODS. Returns
-/// `false` on success, `true` on failure.
-static bool extractUIntBitwidth(StringRef uintType, int64_t &bitwidth) {
-  if (!uintType.consume_front("uint"))
-    return true;
-  if (!uintType.consume_back("_t"))
-    return true;
-  return uintType.getAsInteger(/*Radix=*/10, bitwidth);
-}
-
 /// Emits an attribute builder for the given enum attribute to support automatic
 /// conversion between enum values and attributes in Python. Returns
 /// `false` on success, `true` on failure.
@@ -104,12 +93,7 @@ static bool emitAttributeBuilder(const EnumInfo &enumInfo, raw_ostream &os) {
   if (!enumAttrInfo)
     return false;
 
-  int64_t bitwidth;
-  if (extractUIntBitwidth(enumInfo.getUnderlyingType(), bitwidth)) {
-    llvm::errs() << "failed to identify bitwidth of "
-                 << enumInfo.getUnderlyingType();
-    return true;
-  }
+  int64_t bitwidth = enumInfo.getBitwidth();
   os << formatv("@register_attribute_builder(\"{0}\")\n",
                 enumAttrInfo->getAttrDefName());
   os << formatv("def _{0}(x, context):\n",
@@ -140,7 +124,7 @@ static bool emitDialectEnumAttributeBuilder(StringRef attrDefName,
 static bool emitPythonEnums(const RecordKeeper &records, raw_ostream &os) {
   os << fileHeader;
   for (const Record *it :
-       records.getAllDerivedDefinitionsIfDefined("EnumAttrInfo")) {
+       records.getAllDerivedDefinitionsIfDefined("EnumInfo")) {
     EnumInfo enumInfo(*it);
     emitEnumClass(enumInfo, os);
     emitAttributeBuilder(enumInfo, os);
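The removed `extractUIntBitwidth` recovered the bitwidth by pattern-matching the `underlyingType` string. It therefore only worked for spellings of the form `uintB_t`. The sketch below is a standalone re-creation that uses `std::string_view` and `std::from_chars` instead of `llvm::StringRef`, so it is an illustration rather than the LLVM code. It returns `std::nullopt` on failure instead of `true`. It shows why any other spelling of the underlying type defeats the string match. The new `EnumInfo::getBitwidth()` sidesteps this by reading an explicit `bitwidth` field.

```cpp
#include <cassert>
#include <charconv>
#include <cstdint>
#include <optional>
#include <string_view>
#include <system_error>

// Re-creation of the removed helper: "uint32_t" -> 32, anything else -> nullopt.
std::optional<int64_t> extractUIntBitwidth(std::string_view type) {
  constexpr std::string_view prefix = "uint", suffix = "_t";
  if (type.substr(0, prefix.size()) != prefix)
    return std::nullopt;
  type.remove_prefix(prefix.size());
  if (type.size() < suffix.size() ||
      type.substr(type.size() - suffix.size()) != suffix)
    return std::nullopt;
  type.remove_suffix(suffix.size());
  if (type.empty())
    return std::nullopt;
  int64_t bitwidth = 0;
  auto [ptr, ec] =
      std::from_chars(type.data(), type.data() + type.size(), bitwidth);
  if (ec != std::errc() || ptr != type.data() + type.size())
    return std::nullopt;
  return bitwidth;
}
```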
diff --git a/mlir/tools/mlir-tblgen/EnumsGen.cpp b/mlir/tools/mlir-tblgen/EnumsGen.cpp
index fa6fad156b747..9941a203bc5cb 100644
--- a/mlir/tools/mlir-tblgen/EnumsGen.cpp
+++ b/mlir/tools/mlir-tblgen/EnumsGen.cpp
@@ -77,11 +77,22 @@ static void emitParserPrinter(const EnumInfo &enumInfo, StringRef qualName,
 
   // Check which cases shouldn't be printed using a keyword.
   llvm::BitVector nonKeywordCases(cases.size());
-  for (auto [index, caseVal] : llvm::enumerate(cases))
-    if (!mlir::tblgen::canFormatStringAsKeyword(caseVal.getStr()))
-      nonKeywordCases.set(index);
-
-  // Generate the parser and the start of the printer for the enum.
+  std::string casesList;
+  llvm::raw_string_ostream caseListOs(casesList);
+  caseListOs << "[";
+  llvm::interleaveComma(llvm::enumerate(cases), caseListOs,
+                        [&](auto enumerant) {
+                          StringRef name = enumerant.value().getStr();
+                          if (!mlir::tblgen::canFormatStringAsKeyword(name)) {
+                            nonKeywordCases.set(enumerant.index());
+                            caseListOs << "\\\"" << name << "\\\"";
+                          } else {
+                            caseListOs << name;
+                          }
+                        });
+  caseListOs << "]";
+
+  // Generate the parser and the start of the printer for the enum, excluding
+  // non-quoted bit enums.
   const char *parsedAndPrinterStart = R"(
 namespace mlir {
 template <typename T, typename>
@@ -100,7 +111,7 @@ struct FieldParser<{0}, {0}> {{
     // Symbolize the keyword.
     if (::std::optional<{0}> attr = {1}::symbolizeEnum<{0}>(enumKeyword))
       return *attr;
-    return parser.emitError(loc, "invalid {2} specification: ") << enumKeyword;
+    return parser.emitError(loc, "expected one of {3} for {2}, got: ") << enumKeyword;
   }
 };
 
@@ -121,7 +132,7 @@ struct FieldParser<std::optional<{0}>, std::optional<{0}>> {{
     // Symbolize the keyword.
     if (::std::optional<{0}> attr = {1}::symbolizeEnum<{0}>(enumKeyword))
       return attr;
-    return parser.emitError(loc, "invalid {2} specification: ") << enumKeyword;
+    return parser.emitError(loc, "expected one of {3} for {2}, got: ") << enumKeyword;
   }
 };
 } // namespace mlir
@@ -131,8 +142,94 @@ inline ::llvm::raw_ostream &operator<<(::llvm::raw_ostream &p, {0} value) {{
   auto valueStr = stringifyEnum(value);
 )";
 
-  os << formatv(parsedAndPrinterStart, qualName, cppNamespace,
-                enumInfo.getSummary());
+  const char *parsedAndPrinterStartUnquotedBitEnum = R"(
+  namespace mlir {
+  template <typename T, typename>
+  struct FieldParser;
+
+  template<>
+  struct FieldParser<{0}, {0}> {{
+    template <typename ParserT>
+    static FailureOr<{0}> parse(ParserT &parser) {{
+      {0} flags = {{};
+      do {{
+        // Parse the keyword containing a part of the enum.
+        ::llvm::StringRef enumKeyword;
+        auto loc = parser.getCurrentLocation();
+        if (failed(parser.parseOptionalKeyword(&enumKeyword))) {{
+          return parser.emitError(loc, "expected keyword for {2}");
+        }
+
+        // Symbolize the keyword.
+        if (::std::optional<{0}> flag = {1}::symbolizeEnum<{0}>(enumKeyword)) {{
+          flags = flags | *flag;
+        } else {{
+          return parser.emitError(loc, "expected one of {3} for {2}, got: ") << enumKeyword;
+        }
+      } while (::mlir::succeeded(parser.{5}()));
+      return flags;
+    }
+  };
+
+  /// Support for std::optional, useful in attribute/type definitions where the enum is
+  /// used as:
+  ///
+  ///    let parameters = (ins OptionalParameter<"std::optional<TheEnumName>">:$value);
+  template<>
+  struct FieldParser<std::optional<{0}>, std::optional<{0}>> {{
+    template <typename ParserT>
+    static FailureOr<std::optional<{0}>> parse(ParserT &parser) {{
+      {0} flags = {{};
+      bool firstIter = true;
+      do {{
+        // Parse the keyword containing a part of the enum.
+        ::llvm::StringRef enumKeyword;
+        auto loc = parser.getCurrentLocation();
+        if (failed(parser.parseOptionalKeyword(&enumKeyword))) {{
+          if (firstIter)
+            return std::optional<{0}>{{};
+          return parser.emitError(loc, "expected keyword for {2} after '{4}'");
+        }
+        firstIter = false;
+
+        // Symbolize the keyword.
+        if (::std::optional<{0}> flag = {1}::symbolizeEnum<{0}>(enumKeyword)) {{
+          flags = flags | *flag;
+        } else {{
+          return parser.emitError(loc, "expected one of {3} for {2}, got: ") << enumKeyword;
+        }
+      } while (::mlir::succeeded(parser.{5}()));
+      return std::optional<{0}>{{flags};
+    }
+  };
+  } // namespace mlir
+
+  namespace llvm {
+  inline ::llvm::raw_ostream &operator<<(::llvm::raw_ostream &p, {0} value) {{
+    auto valueStr = stringifyEnum(value);
+  )";
+
+  bool isNewStyleBitEnum =
+      enumInfo.isBitEnum() && !enumInfo.printBitEnumQuoted();
+
+  if (isNewStyleBitEnum) {
+    if (nonKeywordCases.any())
+      return PrintFatalError(
+          "bit enum " + qualName +
+          " cannot be printed unquoted with cases that cannot be keywords");
+    StringRef separator = enumInfo.getDef().getValueAsString("separator");
+    StringRef parseSeparatorFn =
+        llvm::StringSwitch<StringRef>(separator.trim())
+            .Case("|", "parseOptionalVerticalBar")
+            .Case(",", "parseOptionalComma")
+            .Default("error, enum separator must be '|' or ','");
+    os << formatv(parsedAndPrinterStartUnquotedBitEnum, qualName, cppNamespace,
+                  enumInfo.getSummary(), casesList, separator,
+                  parseSeparatorFn);
+  } else {
+    os << formatv(parsedAndPrinterStart, qualName, cppNamespace,
+                  enumInfo.getSummary(), casesList);
+  }
 
   // If all cases require a string, always wrap.
   if (nonKeywordCases.all()) {
@@ -160,7 +257,10 @@ inline ::llvm::raw_ostream &operator<<(::llvm::raw_ostream &p, {0} value) {{
 
     // If this is a bit enum, conservatively print the string form if the value
     // is not a power of two (i.e. not a single bit case) and not a known case.
-  } else if (enumInfo.isBitEnum()) {
+    // Only do this if we're using the old-style parser that parses the enum as
+    // one keyword, as opposed to the new form, where we can print the value
+    // as-is.
+  } else if (enumInfo.isBitEnum() && !isNewStyleBitEnum) {
     // Process the known multi-bit cases that use valid keywords.
     SmallVector<EnumCase *> validMultiBitCases;
     for (auto [index, caseVal] : llvm::enumerate(cases)) {
@@ -670,7 +770,7 @@ static bool emitEnumDecls(const RecordKeeper &records, raw_ostream &os) {
   llvm::emitSourceFileHeader("Enum Utility Declarations", os, records);
 
   for (const Record *def :
-       records.getAllDerivedDefinitionsIfDefined("EnumAttrInfo"))
+       records.getAllDerivedDefinitionsIfDefined("EnumInfo"))
     emitEnumDecl(*def, os);
 
   return false;
@@ -708,7 +808,7 @@ static bool emitEnumDefs(const RecordKeeper &records, raw_ostream &os) {
   llvm::emitSourceFileHeader("Enum Utility Definitions", os, records);
 
   for (const Record *def :
-       records.getAllDerivedDefinitionsIfDefined("EnumAttrInfo"))
+       records.getAllDerivedDefinitionsIfDefined("EnumInfo"))
     emitEnumDef(*def, os);
 
   return false;
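`emitParserPrinter` now does two things for the new-style bit enums. It builds a bracketed case list for the error message, quoting cases that cannot be printed as keywords. It also picks the optional-separator parse hook from the enum's `separator`. The sketch below shows both steps standalone, under these stated assumptions:

- `canFormatAsKeyword` approximates `mlir::tblgen::canFormatStringAsKeyword`: a letter or `_`, followed by alphanumerics, `_`, `$`, or `.`.
- The quoting emits a plain `"` rather than the C++-escaped form the generator writes into a string literal.
- Each case is written exactly once, either quoted or bare.
- An unsupported separator yields an empty string here. This is a design alternative to embedding an error message into the generated code.

```cpp
#include <cassert>
#include <cctype>
#include <string>
#include <string_view>
#include <vector>

// Approximation of mlir::tblgen::canFormatStringAsKeyword (assumption).
bool canFormatAsKeyword(std::string_view s) {
  if (s.empty() ||
      !(std::isalpha(static_cast<unsigned char>(s.front())) || s.front() == '_'))
    return false;
  for (char ch : s.substr(1))
    if (!(std::isalnum(static_cast<unsigned char>(ch)) || ch == '_' ||
          ch == '$' || ch == '.'))
      return false;
  return true;
}

// Build "[a, b, \"c-d\"]" and record which cases need quoting.
std::string buildCasesList(const std::vector<std::string> &cases,
                           std::vector<bool> &nonKeyword) {
  nonKeyword.assign(cases.size(), false);
  std::string out = "[";
  for (size_t i = 0; i < cases.size(); ++i) {
    if (i != 0)
      out += ", ";
    if (!canFormatAsKeyword(cases[i])) {
      nonKeyword[i] = true;
      out += '"' + cases[i] + '"';
    } else {
      out += cases[i];
    }
  }
  return out + "]";
}

// Map a separator such as " | " to the optional-parse hook; "" if unsupported.
std::string separatorParseFn(std::string_view sep) {
  while (!sep.empty() && std::isspace(static_cast<unsigned char>(sep.front())))
    sep.remove_prefix(1);
  while (!sep.empty() && std::isspace(static_cast<unsigned char>(sep.back())))
    sep.remove_suffix(1);
  if (sep == "|")
    return "parseOptionalVerticalBar";
  if (sep == ",")
    return "parseOptionalComma";
  return "";
}
```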
diff --git a/mlir/tools/mlir-tblgen/OpDocGen.cpp b/mlir/tools/mlir-tblgen/OpDocGen.cpp
index f53aebb302dc9..077f9d1ea2b13 100644
--- a/mlir/tools/mlir-tblgen/OpDocGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDocGen.cpp
@@ -406,7 +406,7 @@ static void emitEnumDoc(const EnumInfo &def, raw_ostream &os) {
 
 static void emitEnumDoc(const RecordKeeper &records, raw_ostream &os) {
   os << "<!-- Autogenerated by mlir-tblgen; don't manually edit -->\n";
-  for (const Record *def : records.getAllDerivedDefinitions("EnumAttrInfo"))
+  for (const Record *def : records.getAllDerivedDefinitions("EnumInfo"))
     emitEnumDoc(EnumInfo(def), os);
 }
 
@@ -526,7 +526,7 @@ static bool emitDialectDoc(const RecordKeeper &records, raw_ostream &os) {
   auto typeDefs = records.getAllDerivedDefinitionsIfDefined("DialectType");
   auto typeDefDefs = records.getAllDerivedDefinitionsIfDefined("TypeDef");
   auto attrDefDefs = records.getAllDerivedDefinitionsIfDefined("AttrDef");
-  auto enumDefs = records.getAllDerivedDefinitionsIfDefined("EnumAttrInfo");
+  auto enumDefs = records.getAllDerivedDefinitionsIfDefined("EnumInfo");
 
   std::vector<Attribute> dialectAttrs;
   std::vector<AttrDef> dialectAttrDefs;
diff --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
index 7a6189c09f426..f94ed17aeb4e0 100644
--- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
+++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
@@ -455,7 +455,7 @@ static bool emitEnumDecls(const RecordKeeper &records, raw_ostream &os) {
   llvm::emitSourceFileHeader("SPIR-V Enum Availability Declarations", os,
                              records);
 
-  auto defs = records.getAllDerivedDefinitions("EnumAttrInfo");
+  auto defs = records.getAllDerivedDefinitions("EnumInfo");
   for (const auto *def : defs)
     emitEnumDecl(*def, os);
 
@@ -487,7 +487,7 @@ static bool emitEnumDefs(const RecordKeeper &records, raw_ostream &os) {
   llvm::emitSourceFileHeader("SPIR-V Enum Availability Definitions", os,
                              records);
 
-  auto defs = records.getAllDerivedDefinitions("EnumAttrInfo");
+  auto defs = records.getAllDerivedDefinitions("EnumInfo");
   for (const auto *def : defs)
     emitEnumDef(*def, os);
 
@@ -1262,7 +1262,7 @@ static void emitEnumGetAttrNameFnDefn(const EnumInfo &enumInfo,
 static bool emitAttrUtils(const RecordKeeper &records, raw_ostream &os) {
   llvm::emitSourceFileHeader("SPIR-V Attribute Utilities", os, records);
 
-  auto defs = records.getAllDerivedDefinitions("EnumAttrInfo");
+  auto defs = records.getAllDerivedDefinitions("EnumInfo");
   os << "#ifndef MLIR_DIALECT_SPIRV_IR_ATTR_UTILS_H_\n";
   os << "#define MLIR_DIALECT_SPIRV_IR_ATTR_UTILS_H_\n";
   emitEnumGetAttrNameFnDecl(os);
diff --git a/mlir/utils/spirv/gen_spirv_dialect.py b/mlir/utils/spirv/gen_spirv_dialect.py
index 99ed3489b4cbd..d2d0b410f52df 100755
--- a/mlir/utils/spirv/gen_spirv_dialect.py
+++ b/mlir/utils/spirv/gen_spirv_dialect.py
@@ -288,11 +288,11 @@ def get_availability_spec(enum_case, for_op, for_cap):
 
 
 def gen_operand_kind_enum_attr(operand_kind):
-    """Generates the TableGen EnumAttr definition for the given operand kind.
+    """Generates the TableGen EnumInfo definition for the given operand kind.
 
     Returns:
       - The operand kind's name
-      - A string containing the TableGen EnumAttr definition
+      - A string containing the TableGen EnumInfo definition
     """
     if "enumerants" not in operand_kind:
         return "", ""


