[Mlir-commits] [mlir] 33185e6 - [mlir] Add ODS support for enum attributes with grouped bit cases

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Jan 26 13:01:06 PST 2022


Author: Jeremy Furtek
Date: 2022-01-26T21:01:01Z
New Revision: 33185e66f24187e9573154090556967f71265b7c

URL: https://github.com/llvm/llvm-project/commit/33185e66f24187e9573154090556967f71265b7c
DIFF: https://github.com/llvm/llvm-project/commit/33185e66f24187e9573154090556967f71265b7c.diff

LOG: [mlir] Add ODS support for enum attributes with grouped bit cases

This diff modifies the tablegen specification and code generation for
BitEnumAttr attributes in MLIR Operation Definition Specification (ODS) files.
Specifically:

- there is a new tablegen class for "none" values (i.e. no bits set)
- single-bit enum cases are specified via bit index (i.e. [0, 31]) instead of
  the resulting enum integer value
- there is a new tablegen class to represent a "grouped" bitwise OR of other
  enum values

This diff is intended as an initial step towards improving "fastmath"
optimization support in MLIR, to allow more precise control of whether certain
floating point optimizations are applied in MLIR passes. "Fast" math options
for floating point MLIR operations would (following subsequent RFC and
discussion) be specified by using the improved enum bit support in this diff.
For example, a "fast" enum value would act as an alias for a group of other
cases (e.g. finite-math-only, no-signed-zeros, etc.), in a way that is similar
to support in C/C++ compilers (clang, gcc).

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D117029

Added: 
    

Modified: 
    mlir/docs/OpDefinitions.md
    mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
    mlir/include/mlir/Dialect/Vector/VectorOps.td
    mlir/include/mlir/IR/OpBase.td
    mlir/tools/mlir-tblgen/EnumsGen.cpp
    mlir/unittests/TableGen/EnumsGenTest.cpp
    mlir/unittests/TableGen/enums.td

Removed: 
    


################################################################################
diff  --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md
index 9de4671727028..f26afa5666e23 100644
--- a/mlir/docs/OpDefinitions.md
+++ b/mlir/docs/OpDefinitions.md
@@ -1245,7 +1245,8 @@ several mechanisms: `StrEnumAttr`, `IntEnumAttr`, and `BitEnumAttr`.
     [`StringAttr`][StringAttr] in the op.
 *   `IntEnumAttr`: each enum case is an integer, the attribute is stored as a
     [`IntegerAttr`][IntegerAttr] in the op.
-*   `BitEnumAttr`: each enum case is a bit, the attribute is stored as a
+*   `BitEnumAttr`: each enum case is either the empty case, a single bit,
+    or a group of single bits, and the attribute is stored as a
     [`IntegerAttr`][IntegerAttr] in the op.
 
 All these `*EnumAttr` attributes require fully specifying all of the allowed
@@ -1349,13 +1350,14 @@ llvm::Optional<MyIntEnum> symbolizeMyIntEnum(uint32_t value) {
 Similarly for the following `BitEnumAttr` definition:
 
 ```tablegen
-def None: BitEnumAttrCase<"None", 0x0000>;
-def Bit1: BitEnumAttrCase<"Bit1", 0x0001>;
-def Bit2: BitEnumAttrCase<"Bit2", 0x0002>;
-def Bit3: BitEnumAttrCase<"Bit3", 0x0004>;
+def None: BitEnumAttrCaseNone<"None">;
+def Bit0: BitEnumAttrCaseBit<"Bit0", 0>;
+def Bit1: BitEnumAttrCaseBit<"Bit1", 1>;
+def Bit2: BitEnumAttrCaseBit<"Bit2", 2>;
+def Bit3: BitEnumAttrCaseBit<"Bit3", 3>;
 
 def MyBitEnum: BitEnumAttr<"MyBitEnum", "An example bit enum",
-                           [None, Bit1, Bit2, Bit3]>;
+                           [None, Bit0, Bit1, Bit2, Bit3]>;
 ```
 
 We can have:
@@ -1364,9 +1366,10 @@ We can have:
 // An example bit enum
 enum class MyBitEnum : uint32_t {
   None = 0,
-  Bit1 = 1,
-  Bit2 = 2,
-  Bit3 = 4,
+  Bit0 = 1,
+  Bit1 = 2,
+  Bit2 = 4,
+  Bit3 = 8,
 };
 
 llvm::Optional<MyBitEnum> symbolizeMyBitEnum(uint32_t);
@@ -1407,15 +1410,15 @@ template<> struct DenseMapInfo<::MyBitEnum> {
 ```c++
 std::string stringifyMyBitEnum(MyBitEnum symbol) {
   auto val = static_cast<uint32_t>(symbol);
+  assert(15u == (15u | val) && "invalid bits set in bit enum");
   // Special case for all bits unset.
   if (val == 0) return "None";
-
   llvm::SmallVector<llvm::StringRef, 2> strs;
-  if (1u & val) { strs.push_back("Bit1"); val &= ~1u; }
-  if (2u & val) { strs.push_back("Bit2"); val &= ~2u; }
-  if (4u & val) { strs.push_back("Bit3"); val &= ~4u; }
-
-  if (val) return "";
+  if (1u == (1u & val)) { strs.push_back("Bit0"); }
+  if (2u == (2u & val)) { strs.push_back("Bit1"); }
+  if (4u == (4u & val)) { strs.push_back("Bit2"); }
+  if (8u == (8u & val)) { strs.push_back("Bit3"); }
+  
   return llvm::join(strs, "|");
 }
 
@@ -1429,9 +1432,10 @@ llvm::Optional<MyBitEnum> symbolizeMyBitEnum(llvm::StringRef str) {
   uint32_t val = 0;
   for (auto symbol : symbols) {
     auto bit = llvm::StringSwitch<llvm::Optional<uint32_t>>(symbol)
-      .Case("Bit1", 1)
-      .Case("Bit2", 2)
-      .Case("Bit3", 4)
+      .Case("Bit0", 1)
+      .Case("Bit1", 2)
+      .Case("Bit2", 4)
+      .Case("Bit3", 8)
       .Default(llvm::None);
     if (bit) { val |= *bit; } else { return llvm::None; }
   }
@@ -1442,7 +1446,7 @@ llvm::Optional<MyBitEnum> symbolizeMyBitEnum(uint32_t value) {
   // Special case for all bits unset.
   if (value == 0) return MyBitEnum::None;
 
-  if (value & ~(1u | 2u | 4u)) return llvm::None;
+  if (value & ~(1u | 2u | 4u | 8u)) return llvm::None;
   return static_cast<MyBitEnum>(value);
 }
 ```

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 0a7f4c9ff82bc..90cd333de468b 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -21,14 +21,14 @@ include "mlir/Interfaces/ControlFlowInterfaces.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 
-def FMFnnan     : BitEnumAttrCase<"nnan", 0x1>;
-def FMFninf     : BitEnumAttrCase<"ninf", 0x2>;
-def FMFnsz      : BitEnumAttrCase<"nsz", 0x4>;
-def FMFarcp     : BitEnumAttrCase<"arcp", 0x8>;
-def FMFcontract : BitEnumAttrCase<"contract", 0x10>;
-def FMFafn      : BitEnumAttrCase<"afn", 0x20>;
-def FMFreassoc  : BitEnumAttrCase<"reassoc", 0x40>;
-def FMFfast     : BitEnumAttrCase<"fast", 0x80>;
+def FMFnnan     : BitEnumAttrCaseBit<"nnan", 0>;
+def FMFninf     : BitEnumAttrCaseBit<"ninf", 1>;
+def FMFnsz      : BitEnumAttrCaseBit<"nsz", 2>;
+def FMFarcp     : BitEnumAttrCaseBit<"arcp", 3>;
+def FMFcontract : BitEnumAttrCaseBit<"contract", 4>;
+def FMFafn      : BitEnumAttrCaseBit<"afn", 5>;
+def FMFreassoc  : BitEnumAttrCaseBit<"reassoc", 6>;
+def FMFfast     : BitEnumAttrCaseBit<"fast", 7>;
 
 def FastmathFlags_DoNotUse : BitEnumAttr<
     "FastmathFlags",

diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 5431baa0f9f41..577d4fca9f352 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -3082,12 +3082,12 @@ def SPV_ExecutionModelAttr :
       SPV_EM_AnyHitKHR, SPV_EM_ClosestHitKHR, SPV_EM_MissKHR, SPV_EM_CallableKHR
     ]>;
 
-def SPV_FC_None         : BitEnumAttrCase<"None", 0x0000>;
-def SPV_FC_Inline       : BitEnumAttrCase<"Inline", 0x0001>;
-def SPV_FC_DontInline   : BitEnumAttrCase<"DontInline", 0x0002>;
-def SPV_FC_Pure         : BitEnumAttrCase<"Pure", 0x0004>;
-def SPV_FC_Const        : BitEnumAttrCase<"Const", 0x0008>;
-def SPV_FC_OptNoneINTEL : BitEnumAttrCase<"OptNoneINTEL", 0x10000> {
+def SPV_FC_None         : BitEnumAttrCaseNone<"None">;
+def SPV_FC_Inline       : BitEnumAttrCaseBit<"Inline", 0>;
+def SPV_FC_DontInline   : BitEnumAttrCaseBit<"DontInline", 1>;
+def SPV_FC_Pure         : BitEnumAttrCaseBit<"Pure", 2>;
+def SPV_FC_Const        : BitEnumAttrCaseBit<"Const", 3>;
+def SPV_FC_OptNoneINTEL : BitEnumAttrCaseBit<"OptNoneINTEL", 16> {
   list<Availability> availability = [
     Capability<[SPV_C_OptNoneINTEL]>
   ];
@@ -3366,62 +3366,62 @@ def SPV_ImageFormatAttr :
       SPV_IF_R8ui, SPV_IF_R64ui, SPV_IF_R64i
     ]>;
 
-def SPV_IO_None               : BitEnumAttrCase<"None", 0x0000>;
-def SPV_IO_Bias               : BitEnumAttrCase<"Bias", 0x0001> {
+def SPV_IO_None               : BitEnumAttrCaseNone<"None">;
+def SPV_IO_Bias               : BitEnumAttrCaseBit<"Bias", 0> {
   list<Availability> availability = [
     Capability<[SPV_C_Shader]>
   ];
 }
-def SPV_IO_Lod                : BitEnumAttrCase<"Lod", 0x0002>;
-def SPV_IO_Grad               : BitEnumAttrCase<"Grad", 0x0004>;
-def SPV_IO_ConstOffset        : BitEnumAttrCase<"ConstOffset", 0x0008>;
-def SPV_IO_Offset             : BitEnumAttrCase<"Offset", 0x0010> {
+def SPV_IO_Lod                : BitEnumAttrCaseBit<"Lod", 1>;
+def SPV_IO_Grad               : BitEnumAttrCaseBit<"Grad", 2>;
+def SPV_IO_ConstOffset        : BitEnumAttrCaseBit<"ConstOffset", 3>;
+def SPV_IO_Offset             : BitEnumAttrCaseBit<"Offset", 4> {
   list<Availability> availability = [
     Capability<[SPV_C_ImageGatherExtended]>
   ];
 }
-def SPV_IO_ConstOffsets       : BitEnumAttrCase<"ConstOffsets", 0x0020> {
+def SPV_IO_ConstOffsets       : BitEnumAttrCaseBit<"ConstOffsets", 5> {
   list<Availability> availability = [
     Capability<[SPV_C_ImageGatherExtended]>
   ];
 }
-def SPV_IO_Sample             : BitEnumAttrCase<"Sample", 0x0040>;
-def SPV_IO_MinLod             : BitEnumAttrCase<"MinLod", 0x0080> {
+def SPV_IO_Sample             : BitEnumAttrCaseBit<"Sample", 6>;
+def SPV_IO_MinLod             : BitEnumAttrCaseBit<"MinLod", 7> {
   list<Availability> availability = [
     Capability<[SPV_C_MinLod]>
   ];
 }
-def SPV_IO_MakeTexelAvailable : BitEnumAttrCase<"MakeTexelAvailable", 0x0100> {
+def SPV_IO_MakeTexelAvailable : BitEnumAttrCaseBit<"MakeTexelAvailable", 8> {
   list<Availability> availability = [
     MinVersion<SPV_V_1_5>,
     Capability<[SPV_C_VulkanMemoryModel]>
   ];
 }
-def SPV_IO_MakeTexelVisible   : BitEnumAttrCase<"MakeTexelVisible", 0x0200> {
+def SPV_IO_MakeTexelVisible   : BitEnumAttrCaseBit<"MakeTexelVisible", 9> {
   list<Availability> availability = [
     MinVersion<SPV_V_1_5>,
     Capability<[SPV_C_VulkanMemoryModel]>
   ];
 }
-def SPV_IO_NonPrivateTexel    : BitEnumAttrCase<"NonPrivateTexel", 0x0400> {
+def SPV_IO_NonPrivateTexel    : BitEnumAttrCaseBit<"NonPrivateTexel", 10> {
   list<Availability> availability = [
     MinVersion<SPV_V_1_5>,
     Capability<[SPV_C_VulkanMemoryModel]>
   ];
 }
-def SPV_IO_VolatileTexel      : BitEnumAttrCase<"VolatileTexel", 0x0800> {
+def SPV_IO_VolatileTexel      : BitEnumAttrCaseBit<"VolatileTexel", 11> {
   list<Availability> availability = [
     MinVersion<SPV_V_1_5>,
     Capability<[SPV_C_VulkanMemoryModel]>
   ];
 }
-def SPV_IO_SignExtend         : BitEnumAttrCase<"SignExtend", 0x1000> {
+def SPV_IO_SignExtend         : BitEnumAttrCaseBit<"SignExtend", 12> {
   list<Availability> availability = [
     MinVersion<SPV_V_1_4>
   ];
 }
-def SPV_IO_Offsets            : BitEnumAttrCase<"Offsets", 0x10000>;
-def SPV_IO_ZeroExtend         : BitEnumAttrCase<"ZeroExtend", 0x2000> {
+def SPV_IO_Offsets            : BitEnumAttrCaseBit<"Offsets", 16>;
+def SPV_IO_ZeroExtend         : BitEnumAttrCaseBit<"ZeroExtend", 13> {
   list<Availability> availability = [
     MinVersion<SPV_V_1_4>
   ];
@@ -3457,87 +3457,87 @@ def SPV_LinkageTypeAttr :
       SPV_LT_Export, SPV_LT_Import, SPV_LT_LinkOnceODR
     ]>;
 
-def SPV_LC_None                      : BitEnumAttrCase<"None", 0x0000>;
-def SPV_LC_Unroll                    : BitEnumAttrCase<"Unroll", 0x0001>;
-def SPV_LC_DontUnroll                : BitEnumAttrCase<"DontUnroll", 0x0002>;
-def SPV_LC_DependencyInfinite        : BitEnumAttrCase<"DependencyInfinite", 0x0004> {
+def SPV_LC_None                      : BitEnumAttrCaseNone<"None">;
+def SPV_LC_Unroll                    : BitEnumAttrCaseBit<"Unroll", 0>;
+def SPV_LC_DontUnroll                : BitEnumAttrCaseBit<"DontUnroll", 1>;
+def SPV_LC_DependencyInfinite        : BitEnumAttrCaseBit<"DependencyInfinite", 2> {
   list<Availability> availability = [
     MinVersion<SPV_V_1_1>
   ];
 }
-def SPV_LC_DependencyLength          : BitEnumAttrCase<"DependencyLength", 0x0008> {
+def SPV_LC_DependencyLength          : BitEnumAttrCaseBit<"DependencyLength", 3> {
   list<Availability> availability = [
     MinVersion<SPV_V_1_1>
   ];
 }
-def SPV_LC_MinIterations             : BitEnumAttrCase<"MinIterations", 0x0010> {
+def SPV_LC_MinIterations             : BitEnumAttrCaseBit<"MinIterations", 4> {
   list<Availability> availability = [
     MinVersion<SPV_V_1_4>
   ];
 }
-def SPV_LC_MaxIterations             : BitEnumAttrCase<"MaxIterations", 0x0020> {
+def SPV_LC_MaxIterations             : BitEnumAttrCaseBit<"MaxIterations", 5> {
   list<Availability> availability = [
     MinVersion<SPV_V_1_4>
   ];
 }
-def SPV_LC_IterationMultiple         : BitEnumAttrCase<"IterationMultiple", 0x0040> {
+def SPV_LC_IterationMultiple         : BitEnumAttrCaseBit<"IterationMultiple", 6> {
   list<Availability> availability = [
     MinVersion<SPV_V_1_4>
   ];
 }
-def SPV_LC_PeelCount                 : BitEnumAttrCase<"PeelCount", 0x0080> {
+def SPV_LC_PeelCount                 : BitEnumAttrCaseBit<"PeelCount", 7> {
   list<Availability> availability = [
     MinVersion<SPV_V_1_4>
   ];
 }
-def SPV_LC_PartialCount              : BitEnumAttrCase<"PartialCount", 0x0100> {
+def SPV_LC_PartialCount              : BitEnumAttrCaseBit<"PartialCount", 8> {
   list<Availability> availability = [
     MinVersion<SPV_V_1_4>
   ];
 }
-def SPV_LC_InitiationIntervalINTEL   : BitEnumAttrCase<"InitiationIntervalINTEL", 0x10000> {
+def SPV_LC_InitiationIntervalINTEL   : BitEnumAttrCaseBit<"InitiationIntervalINTEL", 16> {
   list<Availability> availability = [
     Extension<[SPV_INTEL_fpga_loop_controls]>,
     Capability<[SPV_C_FPGALoopControlsINTEL]>
   ];
 }
-def SPV_LC_LoopCoalesceINTEL         : BitEnumAttrCase<"LoopCoalesceINTEL", 0x100000> {
+def SPV_LC_LoopCoalesceINTEL         : BitEnumAttrCaseBit<"LoopCoalesceINTEL", 20> {
   list<Availability> availability = [
     Extension<[SPV_INTEL_fpga_loop_controls]>,
     Capability<[SPV_C_FPGALoopControlsINTEL]>
   ];
 }
-def SPV_LC_MaxConcurrencyINTEL       : BitEnumAttrCase<"MaxConcurrencyINTEL", 0x20000> {
+def SPV_LC_MaxConcurrencyINTEL       : BitEnumAttrCaseBit<"MaxConcurrencyINTEL", 17> {
   list<Availability> availability = [
     Extension<[SPV_INTEL_fpga_loop_controls]>,
     Capability<[SPV_C_FPGALoopControlsINTEL]>
   ];
 }
-def SPV_LC_MaxInterleavingINTEL      : BitEnumAttrCase<"MaxInterleavingINTEL", 0x200000> {
+def SPV_LC_MaxInterleavingINTEL      : BitEnumAttrCaseBit<"MaxInterleavingINTEL", 21> {
   list<Availability> availability = [
     Extension<[SPV_INTEL_fpga_loop_controls]>,
     Capability<[SPV_C_FPGALoopControlsINTEL]>
   ];
 }
-def SPV_LC_DependencyArrayINTEL      : BitEnumAttrCase<"DependencyArrayINTEL", 0x40000> {
+def SPV_LC_DependencyArrayINTEL      : BitEnumAttrCaseBit<"DependencyArrayINTEL", 18> {
   list<Availability> availability = [
     Extension<[SPV_INTEL_fpga_loop_controls]>,
     Capability<[SPV_C_FPGALoopControlsINTEL]>
   ];
 }
-def SPV_LC_SpeculatedIterationsINTEL : BitEnumAttrCase<"SpeculatedIterationsINTEL", 0x400000> {
+def SPV_LC_SpeculatedIterationsINTEL : BitEnumAttrCaseBit<"SpeculatedIterationsINTEL", 22> {
   list<Availability> availability = [
     Extension<[SPV_INTEL_fpga_loop_controls]>,
     Capability<[SPV_C_FPGALoopControlsINTEL]>
   ];
 }
-def SPV_LC_PipelineEnableINTEL       : BitEnumAttrCase<"PipelineEnableINTEL", 0x80000> {
+def SPV_LC_PipelineEnableINTEL       : BitEnumAttrCaseBit<"PipelineEnableINTEL", 19> {
   list<Availability> availability = [
     Extension<[SPV_INTEL_fpga_loop_controls]>,
     Capability<[SPV_C_FPGALoopControlsINTEL]>
   ];
 }
-def SPV_LC_NoFusionINTEL             : BitEnumAttrCase<"NoFusionINTEL", 0x800000> {
+def SPV_LC_NoFusionINTEL             : BitEnumAttrCaseBit<"NoFusionINTEL", 23> {
   list<Availability> availability = [
     Extension<[SPV_INTEL_fpga_loop_controls]>,
     Capability<[SPV_C_FPGALoopControlsINTEL]>
@@ -3555,23 +3555,23 @@ def SPV_LoopControlAttr :
       SPV_LC_PipelineEnableINTEL, SPV_LC_NoFusionINTEL
     ]>;
 
-def SPV_MA_None                 : BitEnumAttrCase<"None", 0x0000>;
-def SPV_MA_Volatile             : BitEnumAttrCase<"Volatile", 0x0001>;
-def SPV_MA_Aligned              : BitEnumAttrCase<"Aligned", 0x0002>;
-def SPV_MA_Nontemporal          : BitEnumAttrCase<"Nontemporal", 0x0004>;
-def SPV_MA_MakePointerAvailable : BitEnumAttrCase<"MakePointerAvailable", 0x0008> {
+def SPV_MA_None                 : BitEnumAttrCaseNone<"None">;
+def SPV_MA_Volatile             : BitEnumAttrCaseBit<"Volatile", 0>;
+def SPV_MA_Aligned              : BitEnumAttrCaseBit<"Aligned", 1>;
+def SPV_MA_Nontemporal          : BitEnumAttrCaseBit<"Nontemporal", 2>;
+def SPV_MA_MakePointerAvailable : BitEnumAttrCaseBit<"MakePointerAvailable", 3> {
   list<Availability> availability = [
     MinVersion<SPV_V_1_5>,
     Capability<[SPV_C_VulkanMemoryModel]>
   ];
 }
-def SPV_MA_MakePointerVisible   : BitEnumAttrCase<"MakePointerVisible", 0x0010> {
+def SPV_MA_MakePointerVisible   : BitEnumAttrCaseBit<"MakePointerVisible", 4> {
   list<Availability> availability = [
     MinVersion<SPV_V_1_5>,
     Capability<[SPV_C_VulkanMemoryModel]>
   ];
 }
-def SPV_MA_NonPrivatePointer    : BitEnumAttrCase<"NonPrivatePointer", 0x0020> {
+def SPV_MA_NonPrivatePointer    : BitEnumAttrCaseBit<"NonPrivatePointer", 5> {
   list<Availability> availability = [
     MinVersion<SPV_V_1_5>,
     Capability<[SPV_C_VulkanMemoryModel]>
@@ -3612,44 +3612,44 @@ def SPV_MemoryModelAttr :
       SPV_MM_Simple, SPV_MM_GLSL450, SPV_MM_OpenCL, SPV_MM_Vulkan
     ]>;
 
-def SPV_MS_None                   : BitEnumAttrCase<"None", 0x0000>;
-def SPV_MS_Acquire                : BitEnumAttrCase<"Acquire", 0x0002>;
-def SPV_MS_Release                : BitEnumAttrCase<"Release", 0x0004>;
-def SPV_MS_AcquireRelease         : BitEnumAttrCase<"AcquireRelease", 0x0008>;
-def SPV_MS_SequentiallyConsistent : BitEnumAttrCase<"SequentiallyConsistent", 0x0010>;
-def SPV_MS_UniformMemory          : BitEnumAttrCase<"UniformMemory", 0x0040> {
+def SPV_MS_None                   : BitEnumAttrCaseNone<"None">;
+def SPV_MS_Acquire                : BitEnumAttrCaseBit<"Acquire", 1>;
+def SPV_MS_Release                : BitEnumAttrCaseBit<"Release", 2>;
+def SPV_MS_AcquireRelease         : BitEnumAttrCaseBit<"AcquireRelease", 3>;
+def SPV_MS_SequentiallyConsistent : BitEnumAttrCaseBit<"SequentiallyConsistent", 4>;
+def SPV_MS_UniformMemory          : BitEnumAttrCaseBit<"UniformMemory", 6> {
   list<Availability> availability = [
     Capability<[SPV_C_Shader]>
   ];
 }
-def SPV_MS_SubgroupMemory         : BitEnumAttrCase<"SubgroupMemory", 0x0080>;
-def SPV_MS_WorkgroupMemory        : BitEnumAttrCase<"WorkgroupMemory", 0x0100>;
-def SPV_MS_CrossWorkgroupMemory   : BitEnumAttrCase<"CrossWorkgroupMemory", 0x0200>;
-def SPV_MS_AtomicCounterMemory    : BitEnumAttrCase<"AtomicCounterMemory", 0x0400> {
+def SPV_MS_SubgroupMemory         : BitEnumAttrCaseBit<"SubgroupMemory", 7>;
+def SPV_MS_WorkgroupMemory        : BitEnumAttrCaseBit<"WorkgroupMemory", 8>;
+def SPV_MS_CrossWorkgroupMemory   : BitEnumAttrCaseBit<"CrossWorkgroupMemory", 9>;
+def SPV_MS_AtomicCounterMemory    : BitEnumAttrCaseBit<"AtomicCounterMemory", 10> {
   list<Availability> availability = [
     Capability<[SPV_C_AtomicStorage]>
   ];
 }
-def SPV_MS_ImageMemory            : BitEnumAttrCase<"ImageMemory", 0x0800>;
-def SPV_MS_OutputMemory           : BitEnumAttrCase<"OutputMemory", 0x1000> {
+def SPV_MS_ImageMemory            : BitEnumAttrCaseBit<"ImageMemory", 11>;
+def SPV_MS_OutputMemory           : BitEnumAttrCaseBit<"OutputMemory", 12> {
   list<Availability> availability = [
     MinVersion<SPV_V_1_5>,
     Capability<[SPV_C_VulkanMemoryModel]>
   ];
 }
-def SPV_MS_MakeAvailable          : BitEnumAttrCase<"MakeAvailable", 0x2000> {
+def SPV_MS_MakeAvailable          : BitEnumAttrCaseBit<"MakeAvailable", 13> {
   list<Availability> availability = [
     MinVersion<SPV_V_1_5>,
     Capability<[SPV_C_VulkanMemoryModel]>
   ];
 }
-def SPV_MS_MakeVisible            : BitEnumAttrCase<"MakeVisible", 0x4000> {
+def SPV_MS_MakeVisible            : BitEnumAttrCaseBit<"MakeVisible", 14> {
   list<Availability> availability = [
     MinVersion<SPV_V_1_5>,
     Capability<[SPV_C_VulkanMemoryModel]>
   ];
 }
-def SPV_MS_Volatile               : BitEnumAttrCase<"Volatile", 0x8000> {
+def SPV_MS_Volatile               : BitEnumAttrCaseBit<"Volatile", 15> {
   list<Availability> availability = [
     Extension<[SPV_KHR_vulkan_memory_model]>,
     Capability<[SPV_C_VulkanMemoryModel]>
@@ -3688,9 +3688,9 @@ def SPV_ScopeAttr :
       SPV_S_Invocation, SPV_S_QueueFamily, SPV_S_ShaderCallKHR
     ]>;
 
-def SPV_SC_None        : BitEnumAttrCase<"None", 0x0000>;
-def SPV_SC_Flatten     : BitEnumAttrCase<"Flatten", 0x0001>;
-def SPV_SC_DontFlatten : BitEnumAttrCase<"DontFlatten", 0x0002>;
+def SPV_SC_None        : BitEnumAttrCaseNone<"None">;
+def SPV_SC_Flatten     : BitEnumAttrCaseBit<"Flatten", 0>;
+def SPV_SC_DontFlatten : BitEnumAttrCaseBit<"DontFlatten", 1>;
 
 def SPV_SelectionControlAttr :
     SPV_BitEnumAttr<"SelectionControl", "valid SPIR-V SelectionControl", [

diff  --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
index 826c7d0338f0b..1f501ac6b89ea 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -39,17 +39,17 @@ class Vector_Op<string mnemonic, list<OpTrait> traits = []> :
 }
 
 // The "kind" of combining function for contractions and reductions.
-def COMBINING_KIND_ADD : BitEnumAttrCase<"ADD", 0x1, "add">;
-def COMBINING_KIND_MUL : BitEnumAttrCase<"MUL", 0x2, "mul">;
-def COMBINING_KIND_MINUI : BitEnumAttrCase<"MINUI", 0x4, "minui">;
-def COMBINING_KIND_MINSI : BitEnumAttrCase<"MINSI", 0x8, "minsi">;
-def COMBINING_KIND_MINF : BitEnumAttrCase<"MINF", 0x10, "minf">;
-def COMBINING_KIND_MAXUI : BitEnumAttrCase<"MAXUI", 0x20, "maxui">;
-def COMBINING_KIND_MAXSI : BitEnumAttrCase<"MAXSI", 0x40, "maxsi">;
-def COMBINING_KIND_MAXF : BitEnumAttrCase<"MAXF", 0x80, "maxf">;
-def COMBINING_KIND_AND : BitEnumAttrCase<"AND", 0x100, "and">;
-def COMBINING_KIND_OR  : BitEnumAttrCase<"OR", 0x200, "or">;
-def COMBINING_KIND_XOR : BitEnumAttrCase<"XOR", 0x400, "xor">;
+def COMBINING_KIND_ADD : BitEnumAttrCaseBit<"ADD", 0, "add">;
+def COMBINING_KIND_MUL : BitEnumAttrCaseBit<"MUL", 1, "mul">;
+def COMBINING_KIND_MINUI : BitEnumAttrCaseBit<"MINUI", 2, "minui">;
+def COMBINING_KIND_MINSI : BitEnumAttrCaseBit<"MINSI", 3, "minsi">;
+def COMBINING_KIND_MINF : BitEnumAttrCaseBit<"MINF", 4, "minf">;
+def COMBINING_KIND_MAXUI : BitEnumAttrCaseBit<"MAXUI", 5, "maxui">;
+def COMBINING_KIND_MAXSI : BitEnumAttrCaseBit<"MAXSI", 6, "maxsi">;
+def COMBINING_KIND_MAXF : BitEnumAttrCaseBit<"MAXF", 7, "maxf">;
+def COMBINING_KIND_AND : BitEnumAttrCaseBit<"AND", 8, "and">;
+def COMBINING_KIND_OR  : BitEnumAttrCaseBit<"OR", 9, "or">;
+def COMBINING_KIND_XOR : BitEnumAttrCaseBit<"XOR", 10, "xor">;
 
 def CombiningKind : BitEnumAttr<
     "CombiningKind",

diff  --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 0d99a9f8d5cb3..6bdf1b5fbc7b8 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -1323,12 +1323,30 @@ class I64EnumAttrCase<string sym, int val, string str = sym>
 // A bit enum case stored with 32-bit IntegerAttr. `val` here is *not* the
 // ordinal number of the bit that is set. It is the 32-bit integer with only
 // one bit set.
-class BitEnumAttrCase<string sym, int val, string str = sym> :
-    EnumAttrCaseInfo<sym, val, str>,
-    SignlessIntegerAttrBase<I32, "case " # str> {
-  let predicate = CPred<
-    "$_self.cast<::mlir::IntegerAttr>().getValue().getZExtValue() & "
-    # val # "u">;
+class BitEnumAttrCase<string sym, int val, string str = sym>
+    : EnumAttrCaseInfo<sym, val, str>,
+      SignlessIntegerAttrBase<I32, "case " #str>;
+
+// The special bit enum case for no bits set (i.e. value = 0).
+class BitEnumAttrCaseNone<string sym, string str = sym>
+    : BitEnumAttrCase<sym, 0, str>;
+
+// The bit enum case for a single bit, specified by the bit position.
+// The pos argument refers to the index of the bit, and is currently
+// limited to be in the range [0, 31].
+class BitEnumAttrCaseBit<string sym, int pos, string str = sym>
+    : BitEnumAttrCase<sym, !shl(1, pos), str> {
+  assert !and(!ge(pos, 0), !le(pos, 31)),
+      "bit position must be between 0 and 31";
+}
+
+// A bit enum case for a group/list of previously declared single bits,
+// providing a convenient alias for that group.
+class BitEnumAttrCaseGroup<string sym, list<BitEnumAttrCaseBit> cases,
+                           string str = sym>
+    : BitEnumAttrCase<
+          sym, !foldl(0, cases, value, bitcase, !or(value, bitcase.value)),
+          str> {
 }
 
 // Additional information for an enum attribute.
@@ -1452,7 +1470,7 @@ class I64EnumAttr<string name, string summary, list<I64EnumAttrCase> cases> :
 // A bit enum stored with 32-bit IntegerAttr.
 //
 // Op attributes of this kind are stored as IntegerAttr. Extra verification will
-// be generated on the integer to make sure only allowed bit are set. Besides,
+// be generated on the integer to make sure only allowed bits are set. Besides,
 // helper methods are generated to parse a string separated with a specified
 // delimiter to a symbol and vice versa.
 class BitEnumAttrBase<list<BitEnumAttrCase> cases, string summary> :
@@ -1470,6 +1488,9 @@ class BitEnumAttr<string name, string summary, list<BitEnumAttrCase> cases> :
     EnumAttrInfo<name, cases, BitEnumAttrBase<cases, summary>> {
   let underlyingType = "uint32_t";
 
+  // Determine "valid" bits from enum cases for error checking
+  int validBits = !foldl(0, cases, value, bitcase, !or(value, bitcase.value));
+
   // We need to return a string because we may concatenate symbols for multiple
   // bits together.
   let symbolToStringFnRetType = "std::string";

diff  --git a/mlir/tools/mlir-tblgen/EnumsGen.cpp b/mlir/tools/mlir-tblgen/EnumsGen.cpp
index aa8841abfd73b..722ab9dcdf04d 100644
--- a/mlir/tools/mlir-tblgen/EnumsGen.cpp
+++ b/mlir/tools/mlir-tblgen/EnumsGen.cpp
@@ -193,6 +193,11 @@ static void emitSymToStrFnForBitEnum(const Record &enumDef, raw_ostream &os) {
 
   os << formatv("  auto val = static_cast<{0}>(symbol);\n",
                 enumAttr.getUnderlyingType());
+  // Emit an assertion that no bits outside the enum's valid bits are set.
+  int64_t validBits = enumDef.getValueAsInt("validBits");
+  os << formatv("  assert({0}u == ({0}u | val) && \"invalid bits set in bit "
+                "enum\");\n",
+                validBits);
   if (allBitsUnsetCase) {
     os << "  // Special case for all bits unset.\n";
     os << formatv("  if (val == 0) return \"{0}\";\n\n",
@@ -201,13 +206,11 @@ static void emitSymToStrFnForBitEnum(const Record &enumDef, raw_ostream &os) {
   os << "  ::llvm::SmallVector<::llvm::StringRef, 2> strs;\n";
   for (const auto &enumerant : enumerants) {
     // Skip the special enumerant for None.
-    if (auto val = enumerant.getValue())
-      os << formatv("  if ({0}u & val) {{ strs.push_back(\"{1}\"); "
-                    "val &= ~{0}u; }\n",
-                    val, enumerant.getStr());
+    if (int64_t val = enumerant.getValue())
+      os << formatv(
+          "  if ({0}u == ({0}u & val)) {{ strs.push_back(\"{1}\"); }\n ", val,
+          enumerant.getStr());
   }
-  // If we have unknown bit set, return an empty string to signal errors.
-  os << "\n  if (val) return \"\";\n";
   os << formatv("  return ::llvm::join(strs, \"{0}\");\n", separator);
 
   os << "}\n\n";

diff  --git a/mlir/unittests/TableGen/EnumsGenTest.cpp b/mlir/unittests/TableGen/EnumsGenTest.cpp
index a873658cdc3ab..764954242d55d 100644
--- a/mlir/unittests/TableGen/EnumsGenTest.cpp
+++ b/mlir/unittests/TableGen/EnumsGenTest.cpp
@@ -68,25 +68,25 @@ TEST(EnumsGenTest, GeneratedUnderlyingType) {
 
 TEST(EnumsGenTest, GeneratedBitEnumDefinition) {
   EXPECT_EQ(0u, static_cast<uint32_t>(BitEnumWithNone::None));
-  EXPECT_EQ(1u, static_cast<uint32_t>(BitEnumWithNone::Bit1));
-  EXPECT_EQ(4u, static_cast<uint32_t>(BitEnumWithNone::Bit3));
+  EXPECT_EQ(1u, static_cast<uint32_t>(BitEnumWithNone::Bit0));
+  EXPECT_EQ(8u, static_cast<uint32_t>(BitEnumWithNone::Bit3));
 }
 
 TEST(EnumsGenTest, GeneratedSymbolToStringFnForBitEnum) {
   EXPECT_EQ(stringifyBitEnumWithNone(BitEnumWithNone::None), "None");
-  EXPECT_EQ(stringifyBitEnumWithNone(BitEnumWithNone::Bit1), "Bit1");
+  EXPECT_EQ(stringifyBitEnumWithNone(BitEnumWithNone::Bit0), "Bit0");
   EXPECT_EQ(stringifyBitEnumWithNone(BitEnumWithNone::Bit3), "Bit3");
   EXPECT_EQ(
-      stringifyBitEnumWithNone(BitEnumWithNone::Bit1 | BitEnumWithNone::Bit3),
-      "Bit1|Bit3");
+      stringifyBitEnumWithNone(BitEnumWithNone::Bit0 | BitEnumWithNone::Bit3),
+      "Bit0|Bit3");
 }
 
 TEST(EnumsGenTest, GeneratedStringToSymbolForBitEnum) {
   EXPECT_EQ(symbolizeBitEnumWithNone("None"), BitEnumWithNone::None);
-  EXPECT_EQ(symbolizeBitEnumWithNone("Bit1"), BitEnumWithNone::Bit1);
+  EXPECT_EQ(symbolizeBitEnumWithNone("Bit0"), BitEnumWithNone::Bit0);
   EXPECT_EQ(symbolizeBitEnumWithNone("Bit3"), BitEnumWithNone::Bit3);
-  EXPECT_EQ(symbolizeBitEnumWithNone("Bit3|Bit1"),
-            BitEnumWithNone::Bit3 | BitEnumWithNone::Bit1);
+  EXPECT_EQ(symbolizeBitEnumWithNone("Bit3|Bit0"),
+            BitEnumWithNone::Bit3 | BitEnumWithNone::Bit0);
 
   EXPECT_EQ(symbolizeBitEnumWithNone("Bit2"), llvm::None);
   EXPECT_EQ(symbolizeBitEnumWithNone("Bit3|Bit4"), llvm::None);
@@ -94,11 +94,31 @@ TEST(EnumsGenTest, GeneratedStringToSymbolForBitEnum) {
   EXPECT_EQ(symbolizeBitEnumWithoutNone("None"), llvm::None);
 }
 
+TEST(EnumsGenTest, GeneratedSymbolToStringFnForGroupedBitEnum) {
+  EXPECT_EQ(stringifyBitEnumWithGroup(BitEnumWithGroup::Bit0), "Bit0");
+  EXPECT_EQ(stringifyBitEnumWithGroup(BitEnumWithGroup::Bit3), "Bit3");
+  EXPECT_EQ(stringifyBitEnumWithGroup(BitEnumWithGroup::Bits0To3),
+            "Bit0|Bit1|Bit2|Bit3|Bits0To3");
+  EXPECT_EQ(stringifyBitEnumWithGroup(BitEnumWithGroup::Bit4), "Bit4");
+  EXPECT_EQ(stringifyBitEnumWithGroup(
+                BitEnumWithGroup::Bit0 | BitEnumWithGroup::Bit1 |
+                BitEnumWithGroup::Bit2 | BitEnumWithGroup::Bit4),
+            "Bit0|Bit1|Bit2|Bit4");
+}
+
+TEST(EnumsGenTest, GeneratedStringToSymbolForGroupedBitEnum) {
+  EXPECT_EQ(symbolizeBitEnumWithGroup("Bit0"), BitEnumWithGroup::Bit0);
+  EXPECT_EQ(symbolizeBitEnumWithGroup("Bit3"), BitEnumWithGroup::Bit3);
+  EXPECT_EQ(symbolizeBitEnumWithGroup("Bit5"), llvm::None);
+  EXPECT_EQ(symbolizeBitEnumWithGroup("Bit3|Bit0"),
+            BitEnumWithGroup::Bit3 | BitEnumWithGroup::Bit0);
+}
+
 TEST(EnumsGenTest, GeneratedOperator) {
-  EXPECT_TRUE(bitEnumContains(BitEnumWithNone::Bit1 | BitEnumWithNone::Bit3,
-                              BitEnumWithNone::Bit1));
-  EXPECT_FALSE(bitEnumContains(BitEnumWithNone::Bit1 & BitEnumWithNone::Bit3,
-                               BitEnumWithNone::Bit1));
+  EXPECT_TRUE(bitEnumContains(BitEnumWithNone::Bit0 | BitEnumWithNone::Bit3,
+                              BitEnumWithNone::Bit0));
+  EXPECT_FALSE(bitEnumContains(BitEnumWithNone::Bit0 & BitEnumWithNone::Bit3,
+                               BitEnumWithNone::Bit0));
 }
 
 TEST(EnumsGenTest, GeneratedSymbolToCustomStringFn) {
@@ -152,7 +172,11 @@ TEST(EnumsGenTest, GeneratedBitAttributeClass) {
   mlir::Type intType = mlir::IntegerType::get(&ctx, 32);
   mlir::Attribute intAttr = mlir::IntegerAttr::get(
       intType,
-      static_cast<uint32_t>(BitEnumWithNone::Bit1 | BitEnumWithNone::Bit3));
+      static_cast<uint32_t>(BitEnumWithNone::Bit0 | BitEnumWithNone::Bit3));
   EXPECT_TRUE(intAttr.isa<BitEnumWithNoneAttr>());
   EXPECT_TRUE(intAttr.isa<BitEnumWithoutNoneAttr>());
+
+  intAttr = mlir::IntegerAttr::get(
+      intType, static_cast<uint32_t>(BitEnumWithGroup::Bits0To3) | (1u << 6));
+  EXPECT_FALSE(intAttr.isa<BitEnumWithGroupAttr>());
 }

diff  --git a/mlir/unittests/TableGen/enums.td b/mlir/unittests/TableGen/enums.td
index cdcc18254bd92..9dbb1b47870ce 100644
--- a/mlir/unittests/TableGen/enums.td
+++ b/mlir/unittests/TableGen/enums.td
@@ -23,15 +23,25 @@ def Case10: I32EnumAttrCase<"Case10", 10>;
 
 def I32Enum: I32EnumAttr<"I32Enum", "A test enum", [Case5, Case10]>;
 
-def Bit0 : BitEnumAttrCase<"None", 0x0000>;
-def Bit1 : BitEnumAttrCase<"Bit1", 0x0001>;
-def Bit3 : BitEnumAttrCase<"Bit3", 0x0004>;
+def NoBits : BitEnumAttrCaseNone<"None">;
+def Bit0 : BitEnumAttrCaseBit<"Bit0", 0>;
+def Bit1 : BitEnumAttrCaseBit<"Bit1", 1>;
+def Bit2 : BitEnumAttrCaseBit<"Bit2", 2>;
+def Bit3 : BitEnumAttrCaseBit<"Bit3", 3>;
+def Bit4 : BitEnumAttrCaseBit<"Bit4", 4>;
+def Bit5 : BitEnumAttrCaseBit<"Bit5", 5>;
 
 def BitEnumWithNone : BitEnumAttr<"BitEnumWithNone", "A test enum",
-                                  [Bit0, Bit1, Bit3]>;
+                                  [NoBits, Bit0, Bit3]>;
 
 def BitEnumWithoutNone : BitEnumAttr<"BitEnumWithoutNone", "A test enum",
-                                     [Bit1, Bit3]>;
+                                     [Bit0, Bit3]>;
+
+def Bits0To3 : BitEnumAttrCaseGroup<"Bits0To3",
+                                    [Bit0, Bit1, Bit2, Bit3]>;
+
+def BitEnumWithGroup : BitEnumAttr<"BitEnumWithGroup", "A test enum",
+                                   [Bit0, Bit1, Bit2, Bit3, Bit4, Bits0To3]>;
 
 def PrettyIntEnumCase1: I32EnumAttrCase<"Case1", 1, "case_one">;
 def PrettyIntEnumCase2: I32EnumAttrCase<"Case2", 2, "case_two">;


        


More information about the Mlir-commits mailing list