[Mlir-commits] [mlir] 21949de - [mlir] Width parameterization of BitEnum attributes

Mehdi Amini llvmlistbot at llvm.org
Thu Apr 7 18:21:40 PDT 2022


Author: Jeremy Furtek
Date: 2022-04-08T01:21:29Z
New Revision: 21949de62fa5ff71f24766f49ba09ddf9d65bd28

URL: https://github.com/llvm/llvm-project/commit/21949de62fa5ff71f24766f49ba09ddf9d65bd28
DIFF: https://github.com/llvm/llvm-project/commit/21949de62fa5ff71f24766f49ba09ddf9d65bd28.diff

LOG: [mlir] Width parameterization of BitEnum attributes

This diff contains:

- Parameterization of bit enum attributes in OpBase.td by bit width (e.g. 32
  and 64). Previously, all bit enums were 32 bits wide. This brings bit enum
  functionality in line with other integer attributes, and allows for bit
  enums wider than 32 bits.
- LLVM, SPIRV and Vector dialects were updated to use bit enum attributes with
  an explicit bit width

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D123095

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
    mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
    mlir/include/mlir/IR/OpBase.td
    mlir/tools/mlir-tblgen/EnumsGen.cpp
    mlir/unittests/TableGen/EnumsGenTest.cpp
    mlir/unittests/TableGen/enums.td

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 5b2233688780a..a0b6dcd026602 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -21,16 +21,16 @@ include "mlir/Interfaces/ControlFlowInterfaces.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 
-def FMFnnan     : BitEnumAttrCaseBit<"nnan", 0>;
-def FMFninf     : BitEnumAttrCaseBit<"ninf", 1>;
-def FMFnsz      : BitEnumAttrCaseBit<"nsz", 2>;
-def FMFarcp     : BitEnumAttrCaseBit<"arcp", 3>;
-def FMFcontract : BitEnumAttrCaseBit<"contract", 4>;
-def FMFafn      : BitEnumAttrCaseBit<"afn", 5>;
-def FMFreassoc  : BitEnumAttrCaseBit<"reassoc", 6>;
-def FMFfast     : BitEnumAttrCaseBit<"fast", 7>;
-
-def FastmathFlags_DoNotUse : BitEnumAttr<
+def FMFnnan     : I32BitEnumAttrCaseBit<"nnan", 0>;
+def FMFninf     : I32BitEnumAttrCaseBit<"ninf", 1>;
+def FMFnsz      : I32BitEnumAttrCaseBit<"nsz", 2>;
+def FMFarcp     : I32BitEnumAttrCaseBit<"arcp", 3>;
+def FMFcontract : I32BitEnumAttrCaseBit<"contract", 4>;
+def FMFafn      : I32BitEnumAttrCaseBit<"afn", 5>;
+def FMFreassoc  : I32BitEnumAttrCaseBit<"reassoc", 6>;
+def FMFfast     : I32BitEnumAttrCaseBit<"fast", 7>;
+
+def FastmathFlags : I32BitEnumAttr<
     "FastmathFlags",
     "LLVM fastmath flags",
     [FMFnnan, FMFninf, FMFnsz, FMFarcp, FMFcontract, FMFafn, FMFreassoc, FMFfast

diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 35f0d955833e4..616aeea03c6b8 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -78,8 +78,8 @@ class SPV_IsKnownEnumCaseFor<string name> :
 
 // Wrapper over base BitEnumAttr to set common fields.
 class SPV_BitEnumAttr<string name, string description,
-                      list<BitEnumAttrCase> cases> :
-    BitEnumAttr<name, description, cases> {
+                      list<BitEnumAttrCaseBase> cases> :
+    I32BitEnumAttr<name, description, cases> {
   let predicate = And<[
     I32Attr.predicate,
     SPV_IsKnownEnumCaseFor<name>,
@@ -3083,12 +3083,12 @@ def SPV_ExecutionModelAttr :
       SPV_EM_AnyHitKHR, SPV_EM_ClosestHitKHR, SPV_EM_MissKHR, SPV_EM_CallableKHR
     ]>;
 
-def SPV_FC_None         : BitEnumAttrCaseNone<"None">;
-def SPV_FC_Inline       : BitEnumAttrCaseBit<"Inline", 0>;
-def SPV_FC_DontInline   : BitEnumAttrCaseBit<"DontInline", 1>;
-def SPV_FC_Pure         : BitEnumAttrCaseBit<"Pure", 2>;
-def SPV_FC_Const        : BitEnumAttrCaseBit<"Const", 3>;
-def SPV_FC_OptNoneINTEL : BitEnumAttrCaseBit<"OptNoneINTEL", 16> {
+def SPV_FC_None         : I32BitEnumAttrCaseNone<"None">;
+def SPV_FC_Inline       : I32BitEnumAttrCaseBit<"Inline", 0>;
+def SPV_FC_DontInline   : I32BitEnumAttrCaseBit<"DontInline", 1>;
+def SPV_FC_Pure         : I32BitEnumAttrCaseBit<"Pure", 2>;
+def SPV_FC_Const        : I32BitEnumAttrCaseBit<"Const", 3>;
+def SPV_FC_OptNoneINTEL : I32BitEnumAttrCaseBit<"OptNoneINTEL", 16> {
   list<Availability> availability = [
     Capability<[SPV_C_OptNoneINTEL]>
   ];
@@ -3367,62 +3367,62 @@ def SPV_ImageFormatAttr :
       SPV_IF_R8ui, SPV_IF_R64ui, SPV_IF_R64i
     ]>;
 
-def SPV_IO_None               : BitEnumAttrCaseNone<"None">;
-def SPV_IO_Bias               : BitEnumAttrCaseBit<"Bias", 0> {
+def SPV_IO_None               : I32BitEnumAttrCaseNone<"None">;
+def SPV_IO_Bias               : I32BitEnumAttrCaseBit<"Bias", 0> {
   list<Availability> availability = [
     Capability<[SPV_C_Shader]>
   ];
 }
-def SPV_IO_Lod                : BitEnumAttrCaseBit<"Lod", 1>;
-def SPV_IO_Grad               : BitEnumAttrCaseBit<"Grad", 2>;
-def SPV_IO_ConstOffset        : BitEnumAttrCaseBit<"ConstOffset", 3>;
-def SPV_IO_Offset             : BitEnumAttrCaseBit<"Offset", 4> {
+def SPV_IO_Lod                : I32BitEnumAttrCaseBit<"Lod", 1>;
+def SPV_IO_Grad               : I32BitEnumAttrCaseBit<"Grad", 2>;
+def SPV_IO_ConstOffset        : I32BitEnumAttrCaseBit<"ConstOffset", 3>;
+def SPV_IO_Offset             : I32BitEnumAttrCaseBit<"Offset", 4> {
   list<Availability> availability = [
     Capability<[SPV_C_ImageGatherExtended]>
   ];
 }
-def SPV_IO_ConstOffsets       : BitEnumAttrCaseBit<"ConstOffsets", 5> {
+def SPV_IO_ConstOffsets       : I32BitEnumAttrCaseBit<"ConstOffsets", 5> {
   list<Availability> availability = [
     Capability<[SPV_C_ImageGatherExtended]>
   ];
 }
-def SPV_IO_Sample             : BitEnumAttrCaseBit<"Sample", 6>;
-def SPV_IO_MinLod             : BitEnumAttrCaseBit<"MinLod", 7> {
+def SPV_IO_Sample             : I32BitEnumAttrCaseBit<"Sample", 6>;
+def SPV_IO_MinLod             : I32BitEnumAttrCaseBit<"MinLod", 7> {
   list<Availability> availability = [
     Capability<[SPV_C_MinLod]>
   ];
 }
-def SPV_IO_MakeTexelAvailable : BitEnumAttrCaseBit<"MakeTexelAvailable", 8> {
+def SPV_IO_MakeTexelAvailable : I32BitEnumAttrCaseBit<"MakeTexelAvailable", 8> {
   list<Availability> availability = [
     MinVersion<SPV_V_1_5>,
     Capability<[SPV_C_VulkanMemoryModel]>
   ];
 }
-def SPV_IO_MakeTexelVisible   : BitEnumAttrCaseBit<"MakeTexelVisible", 9> {
+def SPV_IO_MakeTexelVisible   : I32BitEnumAttrCaseBit<"MakeTexelVisible", 9> {
   list<Availability> availability = [
     MinVersion<SPV_V_1_5>,
     Capability<[SPV_C_VulkanMemoryModel]>
   ];
 }
-def SPV_IO_NonPrivateTexel    : BitEnumAttrCaseBit<"NonPrivateTexel", 10> {
+def SPV_IO_NonPrivateTexel    : I32BitEnumAttrCaseBit<"NonPrivateTexel", 10> {
   list<Availability> availability = [
     MinVersion<SPV_V_1_5>,
     Capability<[SPV_C_VulkanMemoryModel]>
   ];
 }
-def SPV_IO_VolatileTexel      : BitEnumAttrCaseBit<"VolatileTexel", 11> {
+def SPV_IO_VolatileTexel      : I32BitEnumAttrCaseBit<"VolatileTexel", 11> {
   list<Availability> availability = [
     MinVersion<SPV_V_1_5>,
     Capability<[SPV_C_VulkanMemoryModel]>
   ];
 }
-def SPV_IO_SignExtend         : BitEnumAttrCaseBit<"SignExtend", 12> {
+def SPV_IO_SignExtend         : I32BitEnumAttrCaseBit<"SignExtend", 12> {
   list<Availability> availability = [
     MinVersion<SPV_V_1_4>
   ];
 }
-def SPV_IO_Offsets            : BitEnumAttrCaseBit<"Offsets", 16>;
-def SPV_IO_ZeroExtend         : BitEnumAttrCaseBit<"ZeroExtend", 13> {
+def SPV_IO_Offsets            : I32BitEnumAttrCaseBit<"Offsets", 16>;
+def SPV_IO_ZeroExtend         : I32BitEnumAttrCaseBit<"ZeroExtend", 13> {
   list<Availability> availability = [
     MinVersion<SPV_V_1_4>
   ];
@@ -3458,87 +3458,87 @@ def SPV_LinkageTypeAttr :
       SPV_LT_Export, SPV_LT_Import, SPV_LT_LinkOnceODR
     ]>;
 
-def SPV_LC_None                      : BitEnumAttrCaseNone<"None">;
-def SPV_LC_Unroll                    : BitEnumAttrCaseBit<"Unroll", 0>;
-def SPV_LC_DontUnroll                : BitEnumAttrCaseBit<"DontUnroll", 1>;
-def SPV_LC_DependencyInfinite        : BitEnumAttrCaseBit<"DependencyInfinite", 2> {
+def SPV_LC_None                      : I32BitEnumAttrCaseNone<"None">;
+def SPV_LC_Unroll                    : I32BitEnumAttrCaseBit<"Unroll", 0>;
+def SPV_LC_DontUnroll                : I32BitEnumAttrCaseBit<"DontUnroll", 1>;
+def SPV_LC_DependencyInfinite        : I32BitEnumAttrCaseBit<"DependencyInfinite", 2> {
   list<Availability> availability = [
     MinVersion<SPV_V_1_1>
   ];
 }
-def SPV_LC_DependencyLength          : BitEnumAttrCaseBit<"DependencyLength", 3> {
+def SPV_LC_DependencyLength          : I32BitEnumAttrCaseBit<"DependencyLength", 3> {
   list<Availability> availability = [
     MinVersion<SPV_V_1_1>
   ];
 }
-def SPV_LC_MinIterations             : BitEnumAttrCaseBit<"MinIterations", 4> {
+def SPV_LC_MinIterations             : I32BitEnumAttrCaseBit<"MinIterations", 4> {
   list<Availability> availability = [
     MinVersion<SPV_V_1_4>
   ];
 }
-def SPV_LC_MaxIterations             : BitEnumAttrCaseBit<"MaxIterations", 5> {
+def SPV_LC_MaxIterations             : I32BitEnumAttrCaseBit<"MaxIterations", 5> {
   list<Availability> availability = [
     MinVersion<SPV_V_1_4>
   ];
 }
-def SPV_LC_IterationMultiple         : BitEnumAttrCaseBit<"IterationMultiple", 6> {
+def SPV_LC_IterationMultiple         : I32BitEnumAttrCaseBit<"IterationMultiple", 6> {
   list<Availability> availability = [
     MinVersion<SPV_V_1_4>
   ];
 }
-def SPV_LC_PeelCount                 : BitEnumAttrCaseBit<"PeelCount", 7> {
+def SPV_LC_PeelCount                 : I32BitEnumAttrCaseBit<"PeelCount", 7> {
   list<Availability> availability = [
     MinVersion<SPV_V_1_4>
   ];
 }
-def SPV_LC_PartialCount              : BitEnumAttrCaseBit<"PartialCount", 8> {
+def SPV_LC_PartialCount              : I32BitEnumAttrCaseBit<"PartialCount", 8> {
   list<Availability> availability = [
     MinVersion<SPV_V_1_4>
   ];
 }
-def SPV_LC_InitiationIntervalINTEL   : BitEnumAttrCaseBit<"InitiationIntervalINTEL", 16> {
+def SPV_LC_InitiationIntervalINTEL   : I32BitEnumAttrCaseBit<"InitiationIntervalINTEL", 16> {
   list<Availability> availability = [
     Extension<[SPV_INTEL_fpga_loop_controls]>,
     Capability<[SPV_C_FPGALoopControlsINTEL]>
   ];
 }
-def SPV_LC_LoopCoalesceINTEL         : BitEnumAttrCaseBit<"LoopCoalesceINTEL", 20> {
+def SPV_LC_LoopCoalesceINTEL         : I32BitEnumAttrCaseBit<"LoopCoalesceINTEL", 20> {
   list<Availability> availability = [
     Extension<[SPV_INTEL_fpga_loop_controls]>,
     Capability<[SPV_C_FPGALoopControlsINTEL]>
   ];
 }
-def SPV_LC_MaxConcurrencyINTEL       : BitEnumAttrCaseBit<"MaxConcurrencyINTEL", 17> {
+def SPV_LC_MaxConcurrencyINTEL       : I32BitEnumAttrCaseBit<"MaxConcurrencyINTEL", 17> {
   list<Availability> availability = [
     Extension<[SPV_INTEL_fpga_loop_controls]>,
     Capability<[SPV_C_FPGALoopControlsINTEL]>
   ];
 }
-def SPV_LC_MaxInterleavingINTEL      : BitEnumAttrCaseBit<"MaxInterleavingINTEL", 21> {
+def SPV_LC_MaxInterleavingINTEL      : I32BitEnumAttrCaseBit<"MaxInterleavingINTEL", 21> {
   list<Availability> availability = [
     Extension<[SPV_INTEL_fpga_loop_controls]>,
     Capability<[SPV_C_FPGALoopControlsINTEL]>
   ];
 }
-def SPV_LC_DependencyArrayINTEL      : BitEnumAttrCaseBit<"DependencyArrayINTEL", 18> {
+def SPV_LC_DependencyArrayINTEL      : I32BitEnumAttrCaseBit<"DependencyArrayINTEL", 18> {
   list<Availability> availability = [
     Extension<[SPV_INTEL_fpga_loop_controls]>,
     Capability<[SPV_C_FPGALoopControlsINTEL]>
   ];
 }
-def SPV_LC_SpeculatedIterationsINTEL : BitEnumAttrCaseBit<"SpeculatedIterationsINTEL", 22> {
+def SPV_LC_SpeculatedIterationsINTEL : I32BitEnumAttrCaseBit<"SpeculatedIterationsINTEL", 22> {
   list<Availability> availability = [
     Extension<[SPV_INTEL_fpga_loop_controls]>,
     Capability<[SPV_C_FPGALoopControlsINTEL]>
   ];
 }
-def SPV_LC_PipelineEnableINTEL       : BitEnumAttrCaseBit<"PipelineEnableINTEL", 19> {
+def SPV_LC_PipelineEnableINTEL       : I32BitEnumAttrCaseBit<"PipelineEnableINTEL", 19> {
   list<Availability> availability = [
     Extension<[SPV_INTEL_fpga_loop_controls]>,
     Capability<[SPV_C_FPGALoopControlsINTEL]>
   ];
 }
-def SPV_LC_NoFusionINTEL             : BitEnumAttrCaseBit<"NoFusionINTEL", 23> {
+def SPV_LC_NoFusionINTEL             : I32BitEnumAttrCaseBit<"NoFusionINTEL", 23> {
   list<Availability> availability = [
     Extension<[SPV_INTEL_fpga_loop_controls]>,
     Capability<[SPV_C_FPGALoopControlsINTEL]>
@@ -3556,23 +3556,23 @@ def SPV_LoopControlAttr :
       SPV_LC_PipelineEnableINTEL, SPV_LC_NoFusionINTEL
     ]>;
 
-def SPV_MA_None                 : BitEnumAttrCaseNone<"None">;
-def SPV_MA_Volatile             : BitEnumAttrCaseBit<"Volatile", 0>;
-def SPV_MA_Aligned              : BitEnumAttrCaseBit<"Aligned", 1>;
-def SPV_MA_Nontemporal          : BitEnumAttrCaseBit<"Nontemporal", 2>;
-def SPV_MA_MakePointerAvailable : BitEnumAttrCaseBit<"MakePointerAvailable", 3> {
+def SPV_MA_None                 : I32BitEnumAttrCaseNone<"None">;
+def SPV_MA_Volatile             : I32BitEnumAttrCaseBit<"Volatile", 0>;
+def SPV_MA_Aligned              : I32BitEnumAttrCaseBit<"Aligned", 1>;
+def SPV_MA_Nontemporal          : I32BitEnumAttrCaseBit<"Nontemporal", 2>;
+def SPV_MA_MakePointerAvailable : I32BitEnumAttrCaseBit<"MakePointerAvailable", 3> {
   list<Availability> availability = [
     MinVersion<SPV_V_1_5>,
     Capability<[SPV_C_VulkanMemoryModel]>
   ];
 }
-def SPV_MA_MakePointerVisible   : BitEnumAttrCaseBit<"MakePointerVisible", 4> {
+def SPV_MA_MakePointerVisible   : I32BitEnumAttrCaseBit<"MakePointerVisible", 4> {
   list<Availability> availability = [
     MinVersion<SPV_V_1_5>,
     Capability<[SPV_C_VulkanMemoryModel]>
   ];
 }
-def SPV_MA_NonPrivatePointer    : BitEnumAttrCaseBit<"NonPrivatePointer", 5> {
+def SPV_MA_NonPrivatePointer    : I32BitEnumAttrCaseBit<"NonPrivatePointer", 5> {
   list<Availability> availability = [
     MinVersion<SPV_V_1_5>,
     Capability<[SPV_C_VulkanMemoryModel]>
@@ -3613,44 +3613,44 @@ def SPV_MemoryModelAttr :
       SPV_MM_Simple, SPV_MM_GLSL450, SPV_MM_OpenCL, SPV_MM_Vulkan
     ]>;
 
-def SPV_MS_None                   : BitEnumAttrCaseNone<"None">;
-def SPV_MS_Acquire                : BitEnumAttrCaseBit<"Acquire", 1>;
-def SPV_MS_Release                : BitEnumAttrCaseBit<"Release", 2>;
-def SPV_MS_AcquireRelease         : BitEnumAttrCaseBit<"AcquireRelease", 3>;
-def SPV_MS_SequentiallyConsistent : BitEnumAttrCaseBit<"SequentiallyConsistent", 4>;
-def SPV_MS_UniformMemory          : BitEnumAttrCaseBit<"UniformMemory", 6> {
+def SPV_MS_None                   : I32BitEnumAttrCaseNone<"None">;
+def SPV_MS_Acquire                : I32BitEnumAttrCaseBit<"Acquire", 1>;
+def SPV_MS_Release                : I32BitEnumAttrCaseBit<"Release", 2>;
+def SPV_MS_AcquireRelease         : I32BitEnumAttrCaseBit<"AcquireRelease", 3>;
+def SPV_MS_SequentiallyConsistent : I32BitEnumAttrCaseBit<"SequentiallyConsistent", 4>;
+def SPV_MS_UniformMemory          : I32BitEnumAttrCaseBit<"UniformMemory", 6> {
   list<Availability> availability = [
     Capability<[SPV_C_Shader]>
   ];
 }
-def SPV_MS_SubgroupMemory         : BitEnumAttrCaseBit<"SubgroupMemory", 7>;
-def SPV_MS_WorkgroupMemory        : BitEnumAttrCaseBit<"WorkgroupMemory", 8>;
-def SPV_MS_CrossWorkgroupMemory   : BitEnumAttrCaseBit<"CrossWorkgroupMemory", 9>;
-def SPV_MS_AtomicCounterMemory    : BitEnumAttrCaseBit<"AtomicCounterMemory", 10> {
+def SPV_MS_SubgroupMemory         : I32BitEnumAttrCaseBit<"SubgroupMemory", 7>;
+def SPV_MS_WorkgroupMemory        : I32BitEnumAttrCaseBit<"WorkgroupMemory", 8>;
+def SPV_MS_CrossWorkgroupMemory   : I32BitEnumAttrCaseBit<"CrossWorkgroupMemory", 9>;
+def SPV_MS_AtomicCounterMemory    : I32BitEnumAttrCaseBit<"AtomicCounterMemory", 10> {
   list<Availability> availability = [
     Capability<[SPV_C_AtomicStorage]>
   ];
 }
-def SPV_MS_ImageMemory            : BitEnumAttrCaseBit<"ImageMemory", 11>;
-def SPV_MS_OutputMemory           : BitEnumAttrCaseBit<"OutputMemory", 12> {
+def SPV_MS_ImageMemory            : I32BitEnumAttrCaseBit<"ImageMemory", 11>;
+def SPV_MS_OutputMemory           : I32BitEnumAttrCaseBit<"OutputMemory", 12> {
   list<Availability> availability = [
     MinVersion<SPV_V_1_5>,
     Capability<[SPV_C_VulkanMemoryModel]>
   ];
 }
-def SPV_MS_MakeAvailable          : BitEnumAttrCaseBit<"MakeAvailable", 13> {
+def SPV_MS_MakeAvailable          : I32BitEnumAttrCaseBit<"MakeAvailable", 13> {
   list<Availability> availability = [
     MinVersion<SPV_V_1_5>,
     Capability<[SPV_C_VulkanMemoryModel]>
   ];
 }
-def SPV_MS_MakeVisible            : BitEnumAttrCaseBit<"MakeVisible", 14> {
+def SPV_MS_MakeVisible            : I32BitEnumAttrCaseBit<"MakeVisible", 14> {
   list<Availability> availability = [
     MinVersion<SPV_V_1_5>,
     Capability<[SPV_C_VulkanMemoryModel]>
   ];
 }
-def SPV_MS_Volatile               : BitEnumAttrCaseBit<"Volatile", 15> {
+def SPV_MS_Volatile               : I32BitEnumAttrCaseBit<"Volatile", 15> {
   list<Availability> availability = [
     Extension<[SPV_KHR_vulkan_memory_model]>,
     Capability<[SPV_C_VulkanMemoryModel]>
@@ -3689,9 +3689,9 @@ def SPV_ScopeAttr :
       SPV_S_Invocation, SPV_S_QueueFamily, SPV_S_ShaderCallKHR
     ]>;
 
-def SPV_SC_None        : BitEnumAttrCaseNone<"None">;
-def SPV_SC_Flatten     : BitEnumAttrCaseBit<"Flatten", 0>;
-def SPV_SC_DontFlatten : BitEnumAttrCaseBit<"DontFlatten", 1>;
+def SPV_SC_None        : I32BitEnumAttrCaseNone<"None">;
+def SPV_SC_Flatten     : I32BitEnumAttrCaseBit<"Flatten", 0>;
+def SPV_SC_DontFlatten : I32BitEnumAttrCaseBit<"DontFlatten", 1>;
 
 def SPV_SelectionControlAttr :
     SPV_BitEnumAttr<"SelectionControl", "valid SPIR-V SelectionControl", [

diff  --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 4ffbb1a93d8c3..f45754e72a72a 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -31,19 +31,19 @@ class Vector_Op<string mnemonic, list<Trait> traits = []> :
     Op<Vector_Dialect, mnemonic, traits>;
 
 // The "kind" of combining function for contractions and reductions.
-def COMBINING_KIND_ADD : BitEnumAttrCaseBit<"ADD", 0, "add">;
-def COMBINING_KIND_MUL : BitEnumAttrCaseBit<"MUL", 1, "mul">;
-def COMBINING_KIND_MINUI : BitEnumAttrCaseBit<"MINUI", 2, "minui">;
-def COMBINING_KIND_MINSI : BitEnumAttrCaseBit<"MINSI", 3, "minsi">;
-def COMBINING_KIND_MINF : BitEnumAttrCaseBit<"MINF", 4, "minf">;
-def COMBINING_KIND_MAXUI : BitEnumAttrCaseBit<"MAXUI", 5, "maxui">;
-def COMBINING_KIND_MAXSI : BitEnumAttrCaseBit<"MAXSI", 6, "maxsi">;
-def COMBINING_KIND_MAXF : BitEnumAttrCaseBit<"MAXF", 7, "maxf">;
-def COMBINING_KIND_AND : BitEnumAttrCaseBit<"AND", 8, "and">;
-def COMBINING_KIND_OR  : BitEnumAttrCaseBit<"OR", 9, "or">;
-def COMBINING_KIND_XOR : BitEnumAttrCaseBit<"XOR", 10, "xor">;
-
-def CombiningKind : BitEnumAttr<
+def COMBINING_KIND_ADD : I32BitEnumAttrCaseBit<"ADD", 0, "add">;
+def COMBINING_KIND_MUL : I32BitEnumAttrCaseBit<"MUL", 1, "mul">;
+def COMBINING_KIND_MINUI : I32BitEnumAttrCaseBit<"MINUI", 2, "minui">;
+def COMBINING_KIND_MINSI : I32BitEnumAttrCaseBit<"MINSI", 3, "minsi">;
+def COMBINING_KIND_MINF : I32BitEnumAttrCaseBit<"MINF", 4, "minf">;
+def COMBINING_KIND_MAXUI : I32BitEnumAttrCaseBit<"MAXUI", 5, "maxui">;
+def COMBINING_KIND_MAXSI : I32BitEnumAttrCaseBit<"MAXSI", 6, "maxsi">;
+def COMBINING_KIND_MAXF : I32BitEnumAttrCaseBit<"MAXF", 7, "maxf">;
+def COMBINING_KIND_AND : I32BitEnumAttrCaseBit<"AND", 8, "and">;
+def COMBINING_KIND_OR  : I32BitEnumAttrCaseBit<"OR", 9, "or">;
+def COMBINING_KIND_XOR : I32BitEnumAttrCaseBit<"XOR", 10, "xor">;
+
+def CombiningKind : I32BitEnumAttr<
     "CombiningKind",
     "Kind of combining function for contractions and reductions",
     [COMBINING_KIND_ADD, COMBINING_KIND_MUL, COMBINING_KIND_MINUI,

diff  --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 416d05d49c39a..92eadefaee0e5 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -1334,34 +1334,72 @@ class I32EnumAttrCase<string sym, int val, string str = sym>
 class I64EnumAttrCase<string sym, int val, string str = sym>
     : IntEnumAttrCaseBase<I64, sym, str, val>;
 
-// A bit enum case stored with 32-bit IntegerAttr. `val` here is *not* the
-// ordinal number of the bit that is set. It is the 32-bit integer with only
-// one bit set.
-class BitEnumAttrCase<string sym, int val, string str = sym>
-    : EnumAttrCaseInfo<sym, val, str>,
-      SignlessIntegerAttrBase<I32, "case " #str>;
-
-// The special bit enum case for no bits set (i.e. value = 0).
-class BitEnumAttrCaseNone<string sym, string str = sym>
-    : BitEnumAttrCase<sym, 0, str>;
-
-// The bit enum case for a single bit, specified by the bit position.
-// The pos argument refers to the index of the bit, and is currently
-// limited to be in the range [0, 31].
-class BitEnumAttrCaseBit<string sym, int pos, string str = sym>
-    : BitEnumAttrCase<sym, !shl(1, pos), str> {
-  assert !and(!ge(pos, 0), !le(pos, 31)),
-      "bit position must be between 0 and 31";
-}
-
-// A bit enum case for a group/list of previously declared single bits,
-// providing a convenient alias for that group.
-class BitEnumAttrCaseGroup<string sym, list<BitEnumAttrCaseBit> cases,
-                           string str = sym>
-    : BitEnumAttrCase<
-          sym, !foldl(0, cases, value, bitcase, !or(value, bitcase.value)),
-          str> {
-}
+// A bit enum case stored with an IntegerAttr. `val` here is *not* the ordinal
+// number of a bit that is set. It is an integer value with bits set to match
+// the case.
+class BitEnumAttrCaseBase<I intType, string sym, int val, string str = sym> :
+    EnumAttrCaseInfo<sym, val, str>,
+    SignlessIntegerAttrBase<intType, "case " #str>;
+
+// A bit enum case stored with a 32-bit IntegerAttr. `val` here is *not* the
+// ordinal number of a bit that is set. It is a 32-bit integer value with bits
+// set to match the case. 
+class I32BitEnumAttrCase<string sym, int val, string str = sym>
+    : BitEnumAttrCaseBase<I32, sym, val, str>;
+
+// A bit enum case stored with a 64-bit IntegerAttr. `val` here is *not* the
+// ordinal number of a bit that is set. It is a 64-bit integer value with bits
+// set to match the case.
+class I64BitEnumAttrCase<string sym, int val, string str = sym>
+    : BitEnumAttrCaseBase<I64, sym, val, str>;
+
+// The special bit enum case for I32 with no bits set (i.e. value = 0).
+class I32BitEnumAttrCaseNone<string sym, string str = sym>
+    : I32BitEnumAttrCase<sym, 0, str>;
+
+// The special bit enum case for I64 with no bits set (i.e. value = 0).
+class I64BitEnumAttrCaseNone<string sym, string str = sym>
+    : I64BitEnumAttrCase<sym, 0, str>;
+
+// A bit enum case for a single bit, specified by a bit position.
+// The pos argument refers to the index of the bit, and is limited
+// to be in the range [0, bitwidth).
+class BitEnumAttrCaseBit<I intType, string sym, int pos, string str = sym>
+    : BitEnumAttrCaseBase<intType, sym, !shl(1, pos), str> {
+  assert !and(!ge(pos, 0), !lt(pos, intType.bitwidth)),
+      "bit position larger than underlying storage";
+}
+
+// A bit enum case for a single bit in a 32-bit enum, specified by the
+// bit position.
+class I32BitEnumAttrCaseBit<string sym, int pos, string str = sym>
+    : BitEnumAttrCaseBit<I32, sym, pos, str>;
+
+// A bit enum case for a single bit in a 64-bit enum, specified by the
+// bit position.
+class I64BitEnumAttrCaseBit<string sym, int pos, string str = sym>
+    : BitEnumAttrCaseBit<I64, sym, pos, str>;
+
+
+// A bit enum case for a group/list of previously declared cases, providing
+// a convenient alias for that group.
+class BitEnumAttrCaseGroup<I intType, string sym,
+                           list<BitEnumAttrCaseBase> cases, string str = sym>
+    : BitEnumAttrCaseBase<intType, sym,
+          !foldl(0, cases, value, bitcase, !or(value, bitcase.value)),
+          str>;
+
+// A 32-bit enum case for a group/list of previously declared cases, providing
+// a convenient alias for that group.
+class I32BitEnumAttrCaseGroup<string sym, list<BitEnumAttrCaseBase> cases,
+                              string str = sym>
+    : BitEnumAttrCaseGroup<I32, sym, cases, str>;
+
+// A 64-bit enum case for a group/list of previously declared cases, providing
+// a convenient alias for that group.
+class I64BitEnumAttrCaseGroup<string sym, list<BitEnumAttrCaseBase> cases,
+                              string str = sym>
+    : BitEnumAttrCaseGroup<I64, sym, cases, str>;
 
 // Additional information for an enum attribute.
 class EnumAttrInfo<
@@ -1481,16 +1519,17 @@ class I64EnumAttr<string name, string summary, list<I64EnumAttrCase> cases> :
   let underlyingType = "uint64_t";
 }
 
-// A bit enum stored with 32-bit IntegerAttr.
+// A bit enum stored with an IntegerAttr.
 //
 // Op attributes of this kind are stored as IntegerAttr. Extra verification will
 // be generated on the integer to make sure only allowed bits are set. Besides,
 // helper methods are generated to parse a string separated with a specified
 // delimiter to a symbol and vice versa.
-class BitEnumAttrBase<list<BitEnumAttrCase> cases, string summary> :
-    SignlessIntegerAttrBase<I32, summary> {
+class BitEnumAttrBase<I intType, list<BitEnumAttrCaseBase> cases,
+                      string summary>
+    : SignlessIntegerAttrBase<intType, summary> {
   let predicate = And<[
-    I32Attr.predicate,
+    SignlessIntegerAttrBase<intType, summary>.predicate,
     // Make sure we don't have unknown bit set.
     CPred<"!($_self.cast<::mlir::IntegerAttr>().getValue().getZExtValue() & (~("
           # !interleave(!foreach(case, cases, case.value # "u"), "|") #
@@ -1498,10 +1537,9 @@ class BitEnumAttrBase<list<BitEnumAttrCase> cases, string summary> :
   ]>;
 }
 
-class BitEnumAttr<string name, string summary, list<BitEnumAttrCase> cases> :
-    EnumAttrInfo<name, cases, BitEnumAttrBase<cases, summary>> {
-  let underlyingType = "uint32_t";
-
+class BitEnumAttr<I intType, string name, string summary,
+                  list<BitEnumAttrCaseBase> cases>
+    : EnumAttrInfo<name, cases, BitEnumAttrBase<intType, cases, summary>> {
   // Determine "valid" bits from enum cases for error checking
   int validBits = !foldl(0, cases, value, bitcase, !or(value, bitcase.value));
 
@@ -1513,6 +1551,18 @@ class BitEnumAttr<string name, string summary, list<BitEnumAttrCase> cases> :
   string separator = "|";
 }
 
+class I32BitEnumAttr<string name, string summary,
+                     list<BitEnumAttrCaseBase> cases>
+    : BitEnumAttr<I32, name, summary, cases> {
+  let underlyingType = "uint32_t";
+}
+
+class I64BitEnumAttr<string name, string summary,
+                     list<BitEnumAttrCaseBase> cases>
+    : BitEnumAttr<I64, name, summary, cases> {
+  let underlyingType = "uint64_t";
+}
+
 //===----------------------------------------------------------------------===//
 // Composite attribute kinds
 

diff  --git a/mlir/tools/mlir-tblgen/EnumsGen.cpp b/mlir/tools/mlir-tblgen/EnumsGen.cpp
index 722ab9dcdf04d..3365ff02b0df0 100644
--- a/mlir/tools/mlir-tblgen/EnumsGen.cpp
+++ b/mlir/tools/mlir-tblgen/EnumsGen.cpp
@@ -404,8 +404,8 @@ static void emitUnderlyingToSymFnForBitEnum(const Record &enumDef,
     if (auto val = enumerant.getValue())
       values.push_back(std::string(formatv("{0}u", val)));
   }
-  os << formatv("  if (value & ~({0})) return llvm::None;\n",
-                llvm::join(values, " | "));
+  os << formatv("  if (value & ~static_cast<{0}>({1})) return llvm::None;\n",
+                underlyingType, llvm::join(values, " | "));
   os << formatv("  return static_cast<{0}>(value);\n", enumName);
   os << "}\n";
 }

diff  --git a/mlir/unittests/TableGen/EnumsGenTest.cpp b/mlir/unittests/TableGen/EnumsGenTest.cpp
index 764954242d55d..82dbe119cb846 100644
--- a/mlir/unittests/TableGen/EnumsGenTest.cpp
+++ b/mlir/unittests/TableGen/EnumsGenTest.cpp
@@ -79,6 +79,8 @@ TEST(EnumsGenTest, GeneratedSymbolToStringFnForBitEnum) {
   EXPECT_EQ(
       stringifyBitEnumWithNone(BitEnumWithNone::Bit0 | BitEnumWithNone::Bit3),
       "Bit0|Bit3");
+  EXPECT_EQ(2u, static_cast<uint64_t>(BitEnum64_Test::Bit1));
+  EXPECT_EQ(144115188075855872u, static_cast<uint64_t>(BitEnum64_Test::Bit57));
 }
 
 TEST(EnumsGenTest, GeneratedStringToSymbolForBitEnum) {

diff  --git a/mlir/unittests/TableGen/enums.td b/mlir/unittests/TableGen/enums.td
index 9dbb1b47870ce..142f41403ce9f 100644
--- a/mlir/unittests/TableGen/enums.td
+++ b/mlir/unittests/TableGen/enums.td
@@ -23,25 +23,31 @@ def Case10: I32EnumAttrCase<"Case10", 10>;
 
 def I32Enum: I32EnumAttr<"I32Enum", "A test enum", [Case5, Case10]>;
 
-def NoBits : BitEnumAttrCaseNone<"None">;
-def Bit0 : BitEnumAttrCaseBit<"Bit0", 0>;
-def Bit1 : BitEnumAttrCaseBit<"Bit1", 1>;
-def Bit2 : BitEnumAttrCaseBit<"Bit2", 2>;
-def Bit3 : BitEnumAttrCaseBit<"Bit3", 3>;
-def Bit4 : BitEnumAttrCaseBit<"Bit4", 4>;
-def Bit5 : BitEnumAttrCaseBit<"Bit5", 5>;
-
-def BitEnumWithNone : BitEnumAttr<"BitEnumWithNone", "A test enum",
-                                  [NoBits, Bit0, Bit3]>;
-
-def BitEnumWithoutNone : BitEnumAttr<"BitEnumWithoutNone", "A test enum",
-                                     [Bit0, Bit3]>;
-
-def Bits0To3 : BitEnumAttrCaseGroup<"Bits0To3",
-                                    [Bit0, Bit1, Bit2, Bit3]>;
-
-def BitEnumWithGroup : BitEnumAttr<"BitEnumWithGroup", "A test enum",
-                                   [Bit0, Bit1, Bit2, Bit3, Bit4, Bits0To3]>;
+def NoBits : I32BitEnumAttrCaseNone<"None">;
+def Bit0 : I32BitEnumAttrCaseBit<"Bit0", 0>;
+def Bit1 : I32BitEnumAttrCaseBit<"Bit1", 1>;
+def Bit2 : I32BitEnumAttrCaseBit<"Bit2", 2>;
+def Bit3 : I32BitEnumAttrCaseBit<"Bit3", 3>;
+def Bit4 : I32BitEnumAttrCaseBit<"Bit4", 4>;
+def Bit5 : I32BitEnumAttrCaseBit<"Bit5", 5>;
+
+def BitEnumWithNone : I32BitEnumAttr<"BitEnumWithNone", "A test enum",
+                                     [NoBits, Bit0, Bit3]>;
+
+def BitEnumWithoutNone : I32BitEnumAttr<"BitEnumWithoutNone", "A test enum",
+                                        [Bit0, Bit3]>;
+
+def Bits0To3 : I32BitEnumAttrCaseGroup<"Bits0To3",
+                                       [Bit0, Bit1, Bit2, Bit3]>;
+
+def BitEnumWithGroup : I32BitEnumAttr<"BitEnumWithGroup", "A test enum",
+                                      [Bit0, Bit1, Bit2, Bit3, Bit4, Bits0To3]>;
+
+def BitEnum64_None : I64BitEnumAttrCaseNone<"None">;
+def BitEnum64_57   : I64BitEnumAttrCaseBit<"Bit57", 57>;
+def BitEnum64_1    : I64BitEnumAttrCaseBit<"Bit1", 1>;
+def BitEnum64_Test : I64BitEnumAttr<"BitEnum64_Test", "A 64-bit test enum",
+                                    [BitEnum64_None, BitEnum64_1, BitEnum64_57]>;
 
 def PrettyIntEnumCase1: I32EnumAttrCase<"Case1", 1, "case_one">;
 def PrettyIntEnumCase2: I32EnumAttrCase<"Case2", 2, "case_two">;


        


More information about the Mlir-commits mailing list