[Mlir-commits] [mlir] a29fffc - [mlir][spirv] Migrate to use specialized enum attributes

Lei Zhang llvmlistbot at llvm.org
Tue Aug 9 11:15:14 PDT 2022


Author: Lei Zhang
Date: 2022-08-09T14:14:54-04:00
New Revision: a29fffc4752a10494a74bb9f8db9b885c8aa5af1

URL: https://github.com/llvm/llvm-project/commit/a29fffc4752a10494a74bb9f8db9b885c8aa5af1
DIFF: https://github.com/llvm/llvm-project/commit/a29fffc4752a10494a74bb9f8db9b885c8aa5af1.diff

LOG: [mlir][spirv] Migrate to use specialized enum attributes

Previously we used IntegerAttr to back all SPIR-V enum
attributes. As a result, all such attributes were printed as
IntegerAttr in the IR, which is barely readable and breaks
round-tripping of the IR. This commit switches to using
`EnumAttr` as the base directly so that we can have separate
attribute definitions and better IR printing.

Reviewed By: kuhar

Differential Revision: https://reviews.llvm.org/D131311

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAvailability.td
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBarrierOps.td
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
    mlir/include/mlir/TableGen/Attribute.h
    mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp
    mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
    mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
    mlir/lib/TableGen/Attribute.cpp
    mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
    mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
    mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
    mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
    mlir/test/Conversion/GPUToSPIRV/simple.mlir
    mlir/test/Conversion/LinalgToSPIRV/linalg-to-spirv.mlir
    mlir/test/Conversion/MemRefToSPIRV/map-storage-class.mlir
    mlir/test/Dialect/SPIRV/IR/atomic-ops.mlir
    mlir/test/Dialect/SPIRV/IR/availability.mlir
    mlir/test/Dialect/SPIRV/IR/barrier-ops.mlir
    mlir/test/Dialect/SPIRV/IR/group-ops.mlir
    mlir/test/Dialect/SPIRV/IR/memory-ops.mlir
    mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
    mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
    mlir/test/Dialect/SPIRV/IR/target-and-abi.mlir
    mlir/test/Dialect/SPIRV/IR/target-env.mlir
    mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
    mlir/test/Target/SPIRV/barrier-ops.mlir
    mlir/test/Target/SPIRV/group-ops.mlir
    mlir/test/Target/SPIRV/non-uniform-ops.mlir
    mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
    mlir/unittests/Dialect/SPIRV/SerializationTest.cpp
    mlir/utils/spirv/gen_spirv_dialect.py

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAvailability.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAvailability.td
index b4ac754c7bfbc..82310cf00bb57 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAvailability.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAvailability.td
@@ -56,7 +56,7 @@ class Availability {
   string instance = ?;
 }
 
-class MinVersionBase<string name, I32EnumAttr scheme, I32EnumAttrCase min>
+class MinVersionBase<string name, EnumAttr scheme, I32EnumAttrCase min>
     : Availability {
   let interfaceName = name;
 
@@ -69,13 +69,13 @@ class MinVersionBase<string name, I32EnumAttr scheme, I32EnumAttrCase min>
                   "std::max(*$overall, $instance)); "
     "} else { $overall = $instance; }}";
   let initializer = "::llvm::None";
-  let instanceType = scheme.cppNamespace # "::" # scheme.className;
+  let instanceType = scheme.cppNamespace # "::" # scheme.enum.className;
 
-  let instance = scheme.cppNamespace # "::" # scheme.className # "::" #
+  let instance = scheme.cppNamespace # "::" # scheme.enum.className # "::" #
                  min.symbol;
 }
 
-class MaxVersionBase<string name, I32EnumAttr scheme, I32EnumAttrCase max>
+class MaxVersionBase<string name, EnumAttr scheme, I32EnumAttrCase max>
     : Availability {
   let interfaceName = name;
 
@@ -88,9 +88,9 @@ class MaxVersionBase<string name, I32EnumAttr scheme, I32EnumAttrCase max>
                   "std::min(*$overall, $instance)); "
     "} else { $overall = $instance; }}";
   let initializer = "::llvm::None";
-  let instanceType = scheme.cppNamespace # "::" # scheme.className;
+  let instanceType = scheme.cppNamespace # "::" # scheme.enum.className;
 
-  let instance = scheme.cppNamespace # "::" # scheme.className # "::" #
+  let instance = scheme.cppNamespace # "::" # scheme.enum.className # "::" #
                  max.symbol;
 }
 

diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBarrierOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBarrierOps.td
index e00d8fa0c842c..09e547df4705e 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBarrierOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBarrierOps.td
@@ -77,8 +77,6 @@ def SPV_ControlBarrierOp : SPV_Op<"ControlBarrier", []> {
 
   let results = (outs);
 
-  let autogenSerialization = 0;
-
   let assemblyFormat = [{
     $execution_scope `,` $memory_scope `,` $memory_semantics attr-dict
   }];
@@ -129,8 +127,6 @@ def SPV_MemoryBarrierOp : SPV_Op<"MemoryBarrier", []> {
 
   let results = (outs);
 
-  let autogenSerialization = 0;
-
   let assemblyFormat = "$memory_scope `,` $memory_semantics attr-dict";
 }
 

diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 394ce1e917dff..0cfb3ead3642c 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -82,43 +82,31 @@ def SPIRV_Dialect : Dialect {
 // Utility definitions
 //===----------------------------------------------------------------------===//
 
-// A predicate that checks whether `$_self` is a known enum case for the
-// enum class with `name`.
-class SPV_IsKnownEnumCaseFor<string name> :
-    CPred<"::mlir::spirv::symbolize" # name # "("
-          "$_self.cast<IntegerAttr>().getValue().getZExtValue()).has_value()">;
-
 // Wrapper over base BitEnumAttr to set common fields.
-class SPV_BitEnumAttr<string name, string description,
-                      list<BitEnumAttrCaseBase> cases> :
-    I32BitEnumAttr<name, description, cases> {
-  let predicate = And<[
-    I32Attr.predicate,
-    SPV_IsKnownEnumCaseFor<name>,
-  ]>;
+class SPV_BitEnum<string name, string description,
+                  list<BitEnumAttrCaseBase> cases>
+    : I32BitEnumAttr<name, description, cases> {
+  let genSpecializedAttr = 0;
   let cppNamespace = "::mlir::spirv";
 }
-
-// Wrapper over base I32EnumAttr to set common fields.
-class SPV_I32EnumAttr<string name, string description,
-                      list<I32EnumAttrCase> cases> :
-    I32EnumAttr<name, description, cases> {
-  let predicate = And<[
-    I32Attr.predicate,
-    SPV_IsKnownEnumCaseFor<name>,
-  ]>;
-  let cppNamespace = "::mlir::spirv";
+class SPV_BitEnumAttr<string name, string description, string mnemonic,
+                      list<BitEnumAttrCaseBase> cases> :
+    EnumAttr<SPIRV_Dialect, SPV_BitEnum<name, description, cases>, mnemonic> {
+  let assemblyFormat = "`<` $value `>`";
 }
 
 // Wrapper over base I32EnumAttr to set common fields.
-class SPV_Enum<string name, string description, list<I32EnumAttrCase> cases>
+class SPV_I32Enum<string name, string description,
+                  list<I32EnumAttrCase> cases>
     : I32EnumAttr<name, description, cases> {
   let genSpecializedAttr = 0;
   let cppNamespace = "::mlir::spirv";
 }
-class SPV_EnumAttr<string name, string description, string mnemonic,
+class SPV_I32EnumAttr<string name, string description, string mnemonic,
                       list<I32EnumAttrCase> cases> :
-    EnumAttr<SPIRV_Dialect, SPV_Enum<name, description, cases>, mnemonic>;
+    EnumAttr<SPIRV_Dialect, SPV_I32Enum<name, description, cases>, mnemonic> {
+  let assemblyFormat = "`<` $value `>`";
+}
 
 //===----------------------------------------------------------------------===//
 // SPIR-V availability definitions
@@ -132,7 +120,8 @@ def SPV_V_1_4 : I32EnumAttrCase<"V_1_4", 4, "v1.4">;
 def SPV_V_1_5 : I32EnumAttrCase<"V_1_5", 5, "v1.5">;
 def SPV_V_1_6 : I32EnumAttrCase<"V_1_6", 6, "v1.6">;
 
-def SPV_VersionAttr : SPV_I32EnumAttr<"Version", "valid SPIR-V version", [
+def SPV_VersionAttr : SPV_I32EnumAttr<
+  "Version", "valid SPIR-V version", "version", [
     SPV_V_1_0, SPV_V_1_1, SPV_V_1_2, SPV_V_1_3, SPV_V_1_4, SPV_V_1_5,
     SPV_V_1_6]>;
 
@@ -284,7 +273,7 @@ def SPV_DT_Other         : I32EnumAttrCase<"Other", 3>;
 // Information missing.
 def SPV_DT_Unknown       : I32EnumAttrCase<"Unknown", 4>;
 
-def SPV_DeviceTypeAttr : SPV_EnumAttr<
+def SPV_DeviceTypeAttr : SPV_I32EnumAttr<
   "DeviceType", "valid SPIR-V device types", "device_type", [
     SPV_DT_Other, SPV_DT_IntegratedGPU, SPV_DT_DiscreteGPU,
     SPV_DT_CPU, SPV_DT_Unknown
@@ -300,7 +289,7 @@ def SPV_V_Qualcomm    : I32EnumAttrCase<"Qualcomm", 6>;
 def SPV_V_SwiftShader : I32EnumAttrCase<"SwiftShader", 7>;
 def SPV_V_Unknown     : I32EnumAttrCase<"Unknown", 0xff>;
 
-def SPV_VendorAttr : SPV_EnumAttr<
+def SPV_VendorAttr : SPV_I32EnumAttr<
   "Vendor", "recognized SPIR-V vendor strings", "vendor", [
     SPV_V_AMD, SPV_V_Apple, SPV_V_ARM, SPV_V_Imagination,
     SPV_V_Intel, SPV_V_NVIDIA, SPV_V_Qualcomm, SPV_V_SwiftShader,
@@ -418,7 +407,7 @@ def SPV_NV_ray_tracing_motion_blur       : I32EnumAttrCase<"SPV_NV_ray_tracing_m
 def SPV_NVX_multiview_per_view_attributes : I32EnumAttrCase<"SPV_NVX_multiview_per_view_attributes", 5015>;
 
 def SPV_ExtensionAttr :
-    SPV_EnumAttr<"Extension", "supported SPIR-V extensions", "ext", [
+    SPV_I32EnumAttr<"Extension", "supported SPIR-V extensions", "ext", [
       SPV_KHR_16bit_storage, SPV_KHR_8bit_storage, SPV_KHR_device_group,
       SPV_KHR_float_controls, SPV_KHR_physical_storage_buffer, SPV_KHR_multiview,
       SPV_KHR_no_integer_wrap_decoration, SPV_KHR_post_depth_coverage,
@@ -1402,7 +1391,7 @@ def SPV_C_ShaderStereoViewNV                          : I32EnumAttrCase<"ShaderS
 }
 
 def SPV_CapabilityAttr :
-    SPV_I32EnumAttr<"Capability", "valid SPIR-V Capability", [
+    SPV_I32EnumAttr<"Capability", "valid SPIR-V Capability", "capability", [
       SPV_C_Matrix, SPV_C_Addresses, SPV_C_Linkage, SPV_C_Kernel, SPV_C_Float16,
       SPV_C_Float64, SPV_C_Int64, SPV_C_Groups, SPV_C_Int16, SPV_C_Int8,
       SPV_C_Sampled1D, SPV_C_SampledBuffer, SPV_C_GroupNonUniform, SPV_C_ShaderLayer,
@@ -1514,7 +1503,7 @@ def SPV_AM_PhysicalStorageBuffer64 : I32EnumAttrCase<"PhysicalStorageBuffer64",
 }
 
 def SPV_AddressingModelAttr :
-    SPV_I32EnumAttr<"AddressingModel", "valid SPIR-V AddressingModel", [
+    SPV_I32EnumAttr<"AddressingModel", "valid SPIR-V AddressingModel", "addressing_model", [
       SPV_AM_Logical, SPV_AM_Physical32, SPV_AM_Physical64,
       SPV_AM_PhysicalStorageBuffer64
     ]>;
@@ -2049,7 +2038,7 @@ def SPV_BI_CullMaskKHR                 : I32EnumAttrCase<"CullMaskKHR", 6021> {
 }
 
 def SPV_BuiltInAttr :
-    SPV_I32EnumAttr<"BuiltIn", "valid SPIR-V BuiltIn", [
+    SPV_I32EnumAttr<"BuiltIn", "valid SPIR-V BuiltIn", "built_in", [
       SPV_BI_Position, SPV_BI_PointSize, SPV_BI_ClipDistance, SPV_BI_CullDistance,
       SPV_BI_VertexId, SPV_BI_InstanceId, SPV_BI_PrimitiveId, SPV_BI_InvocationId,
       SPV_BI_Layer, SPV_BI_ViewportIndex, SPV_BI_TessLevelOuter,
@@ -2610,7 +2599,7 @@ def SPV_D_MediaBlockIOINTEL                  : I32EnumAttrCase<"MediaBlockIOINTE
 }
 
 def SPV_DecorationAttr :
-    SPV_I32EnumAttr<"Decoration", "valid SPIR-V Decoration", [
+    SPV_I32EnumAttr<"Decoration", "valid SPIR-V Decoration", "decoration", [
       SPV_D_RelaxedPrecision, SPV_D_SpecId, SPV_D_Block, SPV_D_BufferBlock,
       SPV_D_RowMajor, SPV_D_ColMajor, SPV_D_ArrayStride, SPV_D_MatrixStride,
       SPV_D_GLSLShared, SPV_D_GLSLPacked, SPV_D_CPacked, SPV_D_BuiltIn,
@@ -2679,7 +2668,7 @@ def SPV_D_SubpassData : I32EnumAttrCase<"SubpassData", 6> {
 }
 
 def SPV_DimAttr :
-    SPV_I32EnumAttr<"Dim", "valid SPIR-V Dim", [
+    SPV_I32EnumAttr<"Dim", "valid SPIR-V Dim", "dim", [
       SPV_D_1D, SPV_D_2D, SPV_D_3D, SPV_D_Cube, SPV_D_Rect, SPV_D_Buffer,
       SPV_D_SubpassData
     ]>;
@@ -3093,7 +3082,7 @@ def SPV_EM_NamedBarrierCountINTEL           : I32EnumAttrCase<"NamedBarrierCount
 }
 
 def SPV_ExecutionModeAttr :
-    SPV_I32EnumAttr<"ExecutionMode", "valid SPIR-V ExecutionMode", [
+    SPV_I32EnumAttr<"ExecutionMode", "valid SPIR-V ExecutionMode", "execution_mode", [
       SPV_EM_Invocations, SPV_EM_SpacingEqual, SPV_EM_SpacingFractionalEven,
       SPV_EM_SpacingFractionalOdd, SPV_EM_VertexOrderCw, SPV_EM_VertexOrderCcw,
       SPV_EM_PixelCenterInteger, SPV_EM_OriginUpperLeft, SPV_EM_OriginLowerLeft,
@@ -3203,7 +3192,7 @@ def SPV_EM_CallableKHR            : I32EnumAttrCase<"CallableKHR", 5318> {
 }
 
 def SPV_ExecutionModelAttr :
-    SPV_I32EnumAttr<"ExecutionModel", "valid SPIR-V ExecutionModel", [
+    SPV_I32EnumAttr<"ExecutionModel", "valid SPIR-V ExecutionModel", "execution_model", [
       SPV_EM_Vertex, SPV_EM_TessellationControl, SPV_EM_TessellationEvaluation,
       SPV_EM_Geometry, SPV_EM_Fragment, SPV_EM_GLCompute, SPV_EM_Kernel,
       SPV_EM_TaskNV, SPV_EM_MeshNV, SPV_EM_RayGenerationKHR, SPV_EM_IntersectionKHR,
@@ -3222,7 +3211,7 @@ def SPV_FC_OptNoneINTEL : I32BitEnumAttrCaseBit<"OptNoneINTEL", 16> {
 }
 
 def SPV_FunctionControlAttr :
-    SPV_BitEnumAttr<"FunctionControl", "valid SPIR-V FunctionControl", [
+    SPV_BitEnumAttr<"FunctionControl", "valid SPIR-V FunctionControl", "function_control", [
       SPV_FC_None, SPV_FC_Inline, SPV_FC_DontInline, SPV_FC_Pure, SPV_FC_Const,
       SPV_FC_OptNoneINTEL
     ]>;
@@ -3268,7 +3257,7 @@ def SPV_GO_PartitionedExclusiveScanNV : I32EnumAttrCase<"PartitionedExclusiveSca
 }
 
 def SPV_GroupOperationAttr :
-    SPV_I32EnumAttr<"GroupOperation", "valid SPIR-V GroupOperation", [
+    SPV_I32EnumAttr<"GroupOperation", "valid SPIR-V GroupOperation", "group_operation", [
       SPV_GO_Reduce, SPV_GO_InclusiveScan, SPV_GO_ExclusiveScan,
       SPV_GO_ClusteredReduce, SPV_GO_PartitionedReduceNV,
       SPV_GO_PartitionedInclusiveScanNV, SPV_GO_PartitionedExclusiveScanNV
@@ -3482,7 +3471,7 @@ def SPV_IF_R64i         : I32EnumAttrCase<"R64i", 41> {
 }
 
 def SPV_ImageFormatAttr :
-    SPV_I32EnumAttr<"ImageFormat", "valid SPIR-V ImageFormat", [
+    SPV_I32EnumAttr<"ImageFormat", "valid SPIR-V ImageFormat", "image_format", [
       SPV_IF_Unknown, SPV_IF_Rgba32f, SPV_IF_Rgba16f, SPV_IF_R32f, SPV_IF_Rgba8,
       SPV_IF_Rgba8Snorm, SPV_IF_Rg32f, SPV_IF_Rg16f, SPV_IF_R11fG11fB10f,
       SPV_IF_R16f, SPV_IF_Rgba16, SPV_IF_Rgb10A2, SPV_IF_Rg16, SPV_IF_Rg8,
@@ -3561,7 +3550,7 @@ def SPV_IO_Nontemporal        : I32BitEnumAttrCaseBit<"Nontemporal", 14> {
 }
 
 def SPV_ImageOperandsAttr :
-    SPV_BitEnumAttr<"ImageOperands", "valid SPIR-V ImageOperands", [
+    SPV_BitEnumAttr<"ImageOperands", "valid SPIR-V ImageOperands", "image_operands", [
       SPV_IO_None, SPV_IO_Bias, SPV_IO_Lod, SPV_IO_Grad, SPV_IO_ConstOffset,
       SPV_IO_Offset, SPV_IO_ConstOffsets, SPV_IO_Sample, SPV_IO_MinLod,
       SPV_IO_MakeTexelAvailable, SPV_IO_MakeTexelVisible, SPV_IO_NonPrivateTexel,
@@ -3587,7 +3576,7 @@ def SPV_LT_LinkOnceODR : I32EnumAttrCase<"LinkOnceODR", 2> {
 }
 
 def SPV_LinkageTypeAttr :
-    SPV_I32EnumAttr<"LinkageType", "valid SPIR-V LinkageType", [
+    SPV_I32EnumAttr<"LinkageType", "valid SPIR-V LinkageType", "linkage_type", [
       SPV_LT_Export, SPV_LT_Import, SPV_LT_LinkOnceODR
     ]>;
 
@@ -3679,7 +3668,7 @@ def SPV_LC_NoFusionINTEL             : I32BitEnumAttrCaseBit<"NoFusionINTEL", 23
 }
 
 def SPV_LoopControlAttr :
-    SPV_BitEnumAttr<"LoopControl", "valid SPIR-V LoopControl", [
+    SPV_BitEnumAttr<"LoopControl", "valid SPIR-V LoopControl", "loop_control", [
       SPV_LC_None, SPV_LC_Unroll, SPV_LC_DontUnroll, SPV_LC_DependencyInfinite,
       SPV_LC_DependencyLength, SPV_LC_MinIterations, SPV_LC_MaxIterations,
       SPV_LC_IterationMultiple, SPV_LC_PeelCount, SPV_LC_PartialCount,
@@ -3725,7 +3714,7 @@ def SPV_MA_NoAliasINTELMask     : I32BitEnumAttrCaseBit<"NoAliasINTELMask", 17>
 }
 
 def SPV_MemoryAccessAttr :
-    SPV_BitEnumAttr<"MemoryAccess", "valid SPIR-V MemoryAccess", [
+    SPV_BitEnumAttr<"MemoryAccess", "valid SPIR-V MemoryAccess", "memory_access", [
       SPV_MA_None, SPV_MA_Volatile, SPV_MA_Aligned, SPV_MA_Nontemporal,
       SPV_MA_MakePointerAvailable, SPV_MA_MakePointerVisible,
       SPV_MA_NonPrivatePointer, SPV_MA_AliasScopeINTELMask, SPV_MA_NoAliasINTELMask
@@ -3754,7 +3743,7 @@ def SPV_MM_Vulkan  : I32EnumAttrCase<"Vulkan", 3> {
 }
 
 def SPV_MemoryModelAttr :
-    SPV_I32EnumAttr<"MemoryModel", "valid SPIR-V MemoryModel", [
+    SPV_I32EnumAttr<"MemoryModel", "valid SPIR-V MemoryModel", "memory_model", [
       SPV_MM_Simple, SPV_MM_GLSL450, SPV_MM_OpenCL, SPV_MM_Vulkan
     ]>;
 
@@ -3803,7 +3792,7 @@ def SPV_MS_Volatile               : I32BitEnumAttrCaseBit<"Volatile", 15> {
 }
 
 def SPV_MemorySemanticsAttr :
-    SPV_BitEnumAttr<"MemorySemantics", "valid SPIR-V MemorySemantics", [
+    SPV_BitEnumAttr<"MemorySemantics", "valid SPIR-V MemorySemantics", "memory_semantics", [
       SPV_MS_None, SPV_MS_Acquire, SPV_MS_Release, SPV_MS_AcquireRelease,
       SPV_MS_SequentiallyConsistent, SPV_MS_UniformMemory, SPV_MS_SubgroupMemory,
       SPV_MS_WorkgroupMemory, SPV_MS_CrossWorkgroupMemory,
@@ -3829,7 +3818,7 @@ def SPV_S_ShaderCallKHR : I32EnumAttrCase<"ShaderCallKHR", 6> {
 }
 
 def SPV_ScopeAttr :
-    SPV_I32EnumAttr<"Scope", "valid SPIR-V Scope", [
+    SPV_I32EnumAttr<"Scope", "valid SPIR-V Scope", "scope", [
       SPV_S_CrossDevice, SPV_S_Device, SPV_S_Workgroup, SPV_S_Subgroup,
       SPV_S_Invocation, SPV_S_QueueFamily, SPV_S_ShaderCallKHR
     ]>;
@@ -3839,7 +3828,7 @@ def SPV_SC_Flatten     : I32BitEnumAttrCaseBit<"Flatten", 0>;
 def SPV_SC_DontFlatten : I32BitEnumAttrCaseBit<"DontFlatten", 1>;
 
 def SPV_SelectionControlAttr :
-    SPV_BitEnumAttr<"SelectionControl", "valid SPIR-V SelectionControl", [
+    SPV_BitEnumAttr<"SelectionControl", "valid SPIR-V SelectionControl", "selection_control", [
       SPV_SC_None, SPV_SC_Flatten, SPV_SC_DontFlatten
     ]>;
 
@@ -3947,7 +3936,7 @@ def SPV_SC_HostOnlyINTEL           : I32EnumAttrCase<"HostOnlyINTEL", 5937> {
 }
 
 def SPV_StorageClassAttr :
-    SPV_I32EnumAttr<"StorageClass", "valid SPIR-V StorageClass", [
+    SPV_I32EnumAttr<"StorageClass", "valid SPIR-V StorageClass", "storage_class", [
       SPV_SC_UniformConstant, SPV_SC_Input, SPV_SC_Uniform, SPV_SC_Output,
       SPV_SC_Workgroup, SPV_SC_CrossWorkgroup, SPV_SC_Private, SPV_SC_Function,
       SPV_SC_Generic, SPV_SC_PushConstant, SPV_SC_AtomicCounter, SPV_SC_Image,
@@ -3965,34 +3954,32 @@ def SPV_IDI_NoDepth      : I32EnumAttrCase<"NoDepth", 0>;
 def SPV_IDI_IsDepth      : I32EnumAttrCase<"IsDepth", 1>;
 def SPV_IDI_DepthUnknown : I32EnumAttrCase<"DepthUnknown", 2>;
 
-def SPV_DepthAttr :
-    SPV_I32EnumAttr<"ImageDepthInfo", "valid SPIR-V Image Depth specification",
-      [SPV_IDI_NoDepth, SPV_IDI_IsDepth, SPV_IDI_DepthUnknown]>;
+def SPV_DepthAttr : SPV_I32EnumAttr<
+  "ImageDepthInfo", "valid SPIR-V Image Depth specification",
+  "image_depth_info", [SPV_IDI_NoDepth, SPV_IDI_IsDepth, SPV_IDI_DepthUnknown]>;
 
 def SPV_IAI_NonArrayed : I32EnumAttrCase<"NonArrayed", 0>;
 def SPV_IAI_Arrayed    : I32EnumAttrCase<"Arrayed", 1>;
 
-def SPV_ArrayedAttr :
-    SPV_I32EnumAttr<
-      "ImageArrayedInfo", "valid SPIR-V Image Arrayed specification",
-      [SPV_IAI_NonArrayed, SPV_IAI_Arrayed]>;
+def SPV_ArrayedAttr : SPV_I32EnumAttr<
+  "ImageArrayedInfo", "valid SPIR-V Image Arrayed specification",
+  "image_arrayed_info", [SPV_IAI_NonArrayed, SPV_IAI_Arrayed]>;
 
 def SPV_ISI_SingleSampled : I32EnumAttrCase<"SingleSampled", 0>;
 def SPV_ISI_MultiSampled  : I32EnumAttrCase<"MultiSampled", 1>;
 
-def SPV_SamplingAttr:
-    SPV_I32EnumAttr<
-      "ImageSamplingInfo", "valid SPIR-V Image Sampling specification",
-      [SPV_ISI_SingleSampled, SPV_ISI_MultiSampled]>;
+def SPV_SamplingAttr: SPV_I32EnumAttr<
+  "ImageSamplingInfo", "valid SPIR-V Image Sampling specification",
+  "image_sampling_info", [SPV_ISI_SingleSampled, SPV_ISI_MultiSampled]>;
 
 def SPV_ISUI_SamplerUnknown : I32EnumAttrCase<"SamplerUnknown", 0>;
 def SPV_ISUI_NeedSampler    : I32EnumAttrCase<"NeedSampler", 1>;
 def SPV_ISUI_NoSampler      : I32EnumAttrCase<"NoSampler", 2>;
 
-def SPV_SamplerUseAttr:
-    SPV_I32EnumAttr<
-      "ImageSamplerUseInfo", "valid SPIR-V Sampler Use specification",
-      [SPV_ISUI_SamplerUnknown, SPV_ISUI_NeedSampler, SPV_ISUI_NoSampler]>;
+def SPV_SamplerUseAttr: SPV_I32EnumAttr<
+  "ImageSamplerUseInfo", "valid SPIR-V Sampler Use specification",
+  "image_sampler_use_info",
+  [SPV_ISUI_SamplerUnknown, SPV_ISUI_NeedSampler, SPV_ISUI_NoSampler]>;
 
 //===----------------------------------------------------------------------===//
 // SPIR-V attribute definitions
@@ -4326,7 +4313,7 @@ def SPV_OC_OpAssumeTrueKHR             : I32EnumAttrCase<"OpAssumeTrueKHR", 5630
 def SPV_OC_OpAtomicFAddEXT             : I32EnumAttrCase<"OpAtomicFAddEXT", 6035>;
 
 def SPV_OpcodeAttr :
-    SPV_I32EnumAttr<"Opcode", "valid SPIR-V instructions", [
+    SPV_I32EnumAttr<"Opcode", "valid SPIR-V instructions", "opcode", [
       SPV_OC_OpNop, SPV_OC_OpUndef, SPV_OC_OpSourceContinued, SPV_OC_OpSource,
       SPV_OC_OpSourceExtension, SPV_OC_OpName, SPV_OC_OpMemberName, SPV_OC_OpString,
       SPV_OC_OpLine, SPV_OC_OpExtension, SPV_OC_OpExtInstImport, SPV_OC_OpExtInst,

diff  --git a/mlir/include/mlir/TableGen/Attribute.h b/mlir/include/mlir/TableGen/Attribute.h
index 6b8ceb1771f44..f2098bdb015ee 100644
--- a/mlir/include/mlir/TableGen/Attribute.h
+++ b/mlir/include/mlir/TableGen/Attribute.h
@@ -113,6 +113,9 @@ class Attribute : public AttrConstraint {
 
   // Returns the dialect for the attribute if defined.
   Dialect getDialect() const;
+
+  // Returns the TableGen definition this Attribute was constructed from.
+  const llvm::Record &getDef() const;
 };
 
 // Wrapper class providing helper methods for accessing MLIR constant attribute

diff  --git a/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp
index 2d8504093e3ce..0ba6708e027b7 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp
@@ -15,8 +15,8 @@
 #include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h"
 #include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
-#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "llvm/ADT/StringExtras.h"

diff  --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
index 7f6c50c892e87..2ccdfce32cf46 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
 #include "mlir/Dialect/SPIRV/Utils/LayoutUtils.h"
 #include "mlir/IR/BuiltinOps.h"
@@ -643,15 +644,15 @@ class ExecutionModePattern
     // this entry point's execution mode. We set it to be:
     //   __spv__{SPIR-V module name}_{function name}_execution_mode_info_{mode}
     ModuleOp module = op->getParentOfType<ModuleOp>();
-    IntegerAttr executionModeAttr = op.execution_modeAttr();
+    spirv::ExecutionModeAttr executionModeAttr = op.execution_modeAttr();
     std::string moduleName;
     if (module.getName().has_value())
       moduleName = "_" + module.getName().value().str();
     else
       moduleName = "";
-    std::string executionModeInfoName =
-        llvm::formatv("__spv_{0}_{1}_execution_mode_info_{2}", moduleName,
-                      op.fn().str(), executionModeAttr.getValue());
+    std::string executionModeInfoName = llvm::formatv(
+        "__spv_{0}_{1}_execution_mode_info_{2}", moduleName, op.fn().str(),
+        static_cast<uint32_t>(executionModeAttr.getValue()));
 
     MLIRContext *context = rewriter.getContext();
     OpBuilder::InsertionGuard guard(rewriter);
@@ -684,8 +685,10 @@ class ExecutionModePattern
     // Initialize the struct and set the execution mode value.
     rewriter.setInsertionPoint(block, block->begin());
     Value structValue = rewriter.create<LLVM::UndefOp>(loc, structType);
-    Value executionMode =
-        rewriter.create<LLVM::ConstantOp>(loc, llvmI32Type, executionModeAttr);
+    Value executionMode = rewriter.create<LLVM::ConstantOp>(
+        loc, llvmI32Type,
+        rewriter.getI32IntegerAttr(
+            static_cast<uint32_t>(executionModeAttr.getValue())));
     structValue = rewriter.create<LLVM::InsertValueOp>(
         loc, structType, structValue, executionMode,
         ArrayAttr::get(context,
@@ -1391,8 +1394,8 @@ class VectorShufflePattern
     auto llvmI32Type = IntegerType::get(context, 32);
     Value targetOp = rewriter.create<LLVM::UndefOp>(loc, dstType);
     for (unsigned i = 0; i < componentsArray.size(); i++) {
-      if (componentsArray[i].isa<IntegerAttr>())
-        op.emitError("unable to support non-constant component");
+      if (!componentsArray[i].isa<IntegerAttr>())
+        return op.emitError("unable to support non-constant component");
 
       int indexVal = componentsArray[i].cast<IntegerAttr>().getInt();
       if (indexVal == -1)

diff  --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 9a0771ec6dc05..f7758f904675f 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/SPIRV/IR/ParserUtils.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVOpTraits.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
 #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
@@ -174,19 +175,16 @@ parseEnumStrAttr(EnumClass &value, OpAsmParser &parser,
   NamedAttrList attr;
   auto loc = parser.getCurrentLocation();
   if (parser.parseAttribute(attrVal, parser.getBuilder().getNoneType(),
-                            attrName, attr)) {
+                            attrName, attr))
     return failure();
-  }
-  if (!attrVal.isa<StringAttr>()) {
+  if (!attrVal.isa<StringAttr>())
     return parser.emitError(loc, "expected ")
            << attrName << " attribute specified as string";
-  }
   auto attrOptional =
       spirv::symbolizeEnum<EnumClass>(attrVal.cast<StringAttr>().getValue());
-  if (!attrOptional) {
+  if (!attrOptional)
     return parser.emitError(loc, "invalid ")
            << attrName << " attribute specification: " << attrVal;
-  }
   value = *attrOptional;
   return success();
 }
@@ -194,50 +192,52 @@ parseEnumStrAttr(EnumClass &value, OpAsmParser &parser,
 /// Parses the next string attribute in `parser` as an enumerant of the given
 /// `EnumClass` and inserts the enumerant into `state` as an 32-bit integer
 /// attribute with the enum class's name as attribute name.
-template <typename EnumClass>
+template <typename EnumAttrClass,
+          typename EnumClass = typename EnumAttrClass::ValueType>
 static ParseResult
 parseEnumStrAttr(EnumClass &value, OpAsmParser &parser, OperationState &state,
                  StringRef attrName = spirv::attributeName<EnumClass>()) {
-  if (parseEnumStrAttr(value, parser)) {
+  if (parseEnumStrAttr(value, parser))
     return failure();
-  }
-  state.addAttribute(attrName, parser.getBuilder().getI32IntegerAttr(
-                                   llvm::bit_cast<int32_t>(value)));
+  state.addAttribute(attrName,
+                     parser.getBuilder().getAttr<EnumAttrClass>(value));
   return success();
 }
 
 /// Parses the next keyword in `parser` as an enumerant of the given `EnumClass`
 /// and inserts the enumerant into `state` as an 32-bit integer attribute with
 /// the enum class's name as attribute name.
-template <typename EnumClass>
+template <typename EnumAttrClass,
+          typename EnumClass = typename EnumAttrClass::ValueType>
 static ParseResult
 parseEnumKeywordAttr(EnumClass &value, OpAsmParser &parser,
                      OperationState &state,
                      StringRef attrName = spirv::attributeName<EnumClass>()) {
-  if (parseEnumKeywordAttr(value, parser)) {
+  if (parseEnumKeywordAttr(value, parser))
     return failure();
-  }
-  state.addAttribute(attrName, parser.getBuilder().getI32IntegerAttr(
-                                   llvm::bit_cast<int32_t>(value)));
+  state.addAttribute(attrName,
+                     parser.getBuilder().getAttr<EnumAttrClass>(value));
   return success();
 }
 
 /// Parses Function, Selection and Loop control attributes. If no control is
 /// specified, "None" is used as a default.
-template <typename EnumClass>
+template <typename EnumAttrClass, typename EnumClass>
 static ParseResult
 parseControlAttribute(OpAsmParser &parser, OperationState &state,
                       StringRef attrName = spirv::attributeName<EnumClass>()) {
   if (succeeded(parser.parseOptionalKeyword(kControl))) {
     EnumClass control;
-    if (parser.parseLParen() || parseEnumKeywordAttr(control, parser, state) ||
+    if (parser.parseLParen() ||
+        parseEnumKeywordAttr<EnumAttrClass>(control, parser, state) ||
         parser.parseRParen())
       return failure();
     return success();
   }
   // Set control to "None" otherwise.
   Builder builder = parser.getBuilder();
-  state.addAttribute(attrName, builder.getI32IntegerAttr(0));
+  state.addAttribute(attrName,
+                     builder.getAttr<EnumAttrClass>(static_cast<EnumClass>(0)));
   return success();
 }
 
@@ -256,10 +256,9 @@ static ParseResult parseMemoryAccessAttributes(OpAsmParser &parser,
   }
 
   spirv::MemoryAccess memoryAccessAttr;
-  if (parseEnumStrAttr(memoryAccessAttr, parser, state,
-                       kMemoryAccessAttrName)) {
+  if (parseEnumStrAttr<spirv::MemoryAccessAttr>(memoryAccessAttr, parser, state,
+                                                kMemoryAccessAttrName))
     return failure();
-  }
 
   if (spirv::bitEnumContains(memoryAccessAttr, spirv::MemoryAccess::Aligned)) {
     // Parse integer attribute for alignment.
@@ -287,10 +286,9 @@ static ParseResult parseSourceMemoryAccessAttributes(OpAsmParser &parser,
   }
 
   spirv::MemoryAccess memoryAccessAttr;
-  if (parseEnumStrAttr(memoryAccessAttr, parser, state,
-                       kSourceMemoryAccessAttrName)) {
+  if (parseEnumStrAttr<spirv::MemoryAccessAttr>(memoryAccessAttr, parser, state,
+                                                kSourceMemoryAccessAttrName))
     return failure();
-  }
 
   if (spirv::bitEnumContains(memoryAccessAttr, spirv::MemoryAccess::Aligned)) {
     // Parse integer attribute for alignment.
@@ -479,15 +477,15 @@ static LogicalResult verifyMemoryAccessAttribute(MemoryOpTy memoryOp) {
     return success();
   }
 
-  auto memAccessVal = memAccessAttr.template cast<IntegerAttr>();
-  auto memAccess = spirv::symbolizeMemoryAccess(memAccessVal.getInt());
+  auto memAccess = memAccessAttr.template cast<spirv::MemoryAccessAttr>();
 
   if (!memAccess) {
     return memoryOp.emitOpError("invalid memory access specifier: ")
-           << memAccessVal;
+           << memAccessAttr;
   }
 
-  if (spirv::bitEnumContains(*memAccess, spirv::MemoryAccess::Aligned)) {
+  if (spirv::bitEnumContains(memAccess.getValue(),
+                             spirv::MemoryAccess::Aligned)) {
     if (!op->getAttr(kAlignmentAttrName)) {
       return memoryOp.emitOpError("missing alignment value");
     }
@@ -523,15 +521,15 @@ static LogicalResult verifySourceMemoryAccessAttribute(MemoryOpTy memoryOp) {
     return success();
   }
 
-  auto memAccessVal = memAccessAttr.template cast<IntegerAttr>();
-  auto memAccess = spirv::symbolizeMemoryAccess(memAccessVal.getInt());
+  auto memAccess = memAccessAttr.template cast<spirv::MemoryAccessAttr>();
 
   if (!memAccess) {
     return memoryOp.emitOpError("invalid memory access specifier: ")
-           << memAccessVal;
+           << memAccess;
   }
 
-  if (spirv::bitEnumContains(*memAccess, spirv::MemoryAccess::Aligned)) {
+  if (spirv::bitEnumContains(memAccess.getValue(),
+                             spirv::MemoryAccess::Aligned)) {
     if (!op->getAttr(kSourceAlignmentAttrName)) {
       return memoryOp.emitOpError("missing alignment value");
     }
@@ -770,8 +768,10 @@ static ParseResult parseAtomicUpdateOp(OpAsmParser &parser,
   OpAsmParser::UnresolvedOperand ptrInfo, valueInfo;
   Type type;
   SMLoc loc;
-  if (parseEnumStrAttr(scope, parser, state, kMemoryScopeAttrName) ||
-      parseEnumStrAttr(memoryScope, parser, state, kSemanticsAttrName) ||
+  if (parseEnumStrAttr<spirv::ScopeAttr>(scope, parser, state,
+                                         kMemoryScopeAttrName) ||
+      parseEnumStrAttr<spirv::MemorySemanticsAttr>(memoryScope, parser, state,
+                                                   kSemanticsAttrName) ||
       parser.parseOperandList(operandInfo, (hasValue ? 2 : 1)) ||
       parser.getCurrentLocation(&loc) || parser.parseColonType(type))
     return failure();
@@ -793,14 +793,11 @@ static ParseResult parseAtomicUpdateOp(OpAsmParser &parser,
 // Prints an atomic update op.
 static void printAtomicUpdateOp(Operation *op, OpAsmPrinter &printer) {
   printer << " \"";
-  auto scopeAttr = op->getAttrOfType<IntegerAttr>(kMemoryScopeAttrName);
-  printer << spirv::stringifyScope(
-                 static_cast<spirv::Scope>(scopeAttr.getInt()))
-          << "\" \"";
-  auto memorySemanticsAttr = op->getAttrOfType<IntegerAttr>(kSemanticsAttrName);
-  printer << spirv::stringifyMemorySemantics(
-                 static_cast<spirv::MemorySemantics>(
-                     memorySemanticsAttr.getInt()))
+  auto scopeAttr = op->getAttrOfType<spirv::ScopeAttr>(kMemoryScopeAttrName);
+  printer << spirv::stringifyScope(scopeAttr.getValue()) << "\" \"";
+  auto memorySemanticsAttr =
+      op->getAttrOfType<spirv::MemorySemanticsAttr>(kSemanticsAttrName);
+  printer << spirv::stringifyMemorySemantics(memorySemanticsAttr.getValue())
           << "\" " << op->getOperands() << " : " << op->getOperand(0).getType();
 }
 
@@ -834,8 +831,9 @@ static LogicalResult verifyAtomicUpdateOp(Operation *op) {
                              "pointer operand's pointee type ")
              << elementType << ", but found " << valueType;
   }
-  auto memorySemantics = static_cast<spirv::MemorySemantics>(
-      op->getAttrOfType<IntegerAttr>(kSemanticsAttrName).getInt());
+  auto memorySemantics =
+      op->getAttrOfType<spirv::MemorySemanticsAttr>(kSemanticsAttrName)
+          .getValue();
   if (failed(verifyMemorySemantics(op, memorySemantics))) {
     return failure();
   }
@@ -847,10 +845,10 @@ static ParseResult parseGroupNonUniformArithmeticOp(OpAsmParser &parser,
   spirv::Scope executionScope;
   spirv::GroupOperation groupOperation;
   OpAsmParser::UnresolvedOperand valueInfo;
-  if (parseEnumStrAttr(executionScope, parser, state,
-                       kExecutionScopeAttrName) ||
-      parseEnumStrAttr(groupOperation, parser, state,
-                       kGroupOperationAttrName) ||
+  if (parseEnumStrAttr<spirv::ScopeAttr>(executionScope, parser, state,
+                                         kExecutionScopeAttrName) ||
+      parseEnumStrAttr<spirv::GroupOperationAttr>(groupOperation, parser, state,
+                                                  kGroupOperationAttrName) ||
       parser.parseOperand(valueInfo))
     return failure();
 
@@ -880,15 +878,17 @@ static ParseResult parseGroupNonUniformArithmeticOp(OpAsmParser &parser,
 
 static void printGroupNonUniformArithmeticOp(Operation *groupOp,
                                              OpAsmPrinter &printer) {
-  printer << " \""
-          << stringifyScope(static_cast<spirv::Scope>(
-                 groupOp->getAttrOfType<IntegerAttr>(kExecutionScopeAttrName)
-                     .getInt()))
-          << "\" \""
-          << stringifyGroupOperation(static_cast<spirv::GroupOperation>(
-                 groupOp->getAttrOfType<IntegerAttr>(kGroupOperationAttrName)
-                     .getInt()))
-          << "\" " << groupOp->getOperand(0);
+  printer
+      << " \""
+      << stringifyScope(
+             groupOp->getAttrOfType<spirv::ScopeAttr>(kExecutionScopeAttrName)
+                 .getValue())
+      << "\" \""
+      << stringifyGroupOperation(groupOp
+                                     ->getAttrOfType<spirv::GroupOperationAttr>(
+                                         kGroupOperationAttrName)
+                                     .getValue())
+      << "\" " << groupOp->getOperand(0);
 
   if (groupOp->getNumOperands() > 1)
     printer << " " << kClusterSize << '(' << groupOp->getOperand(1) << ')';
@@ -896,14 +896,16 @@ static void printGroupNonUniformArithmeticOp(Operation *groupOp,
 }
 
 static LogicalResult verifyGroupNonUniformArithmeticOp(Operation *groupOp) {
-  spirv::Scope scope = static_cast<spirv::Scope>(
-      groupOp->getAttrOfType<IntegerAttr>(kExecutionScopeAttrName).getInt());
+  spirv::Scope scope =
+      groupOp->getAttrOfType<spirv::ScopeAttr>(kExecutionScopeAttrName)
+          .getValue();
   if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
     return groupOp->emitOpError(
         "execution scope must be 'Workgroup' or 'Subgroup'");
 
-  spirv::GroupOperation operation = static_cast<spirv::GroupOperation>(
-      groupOp->getAttrOfType<IntegerAttr>(kGroupOperationAttrName).getInt());
+  spirv::GroupOperation operation =
+      groupOp->getAttrOfType<spirv::GroupOperationAttr>(kGroupOperationAttrName)
+          .getValue();
   if (operation == spirv::GroupOperation::ClusteredReduce &&
       groupOp->getNumOperands() == 1)
     return groupOp->emitOpError("cluster size operand must be provided for "
@@ -1145,11 +1147,12 @@ static ParseResult parseAtomicCompareExchangeImpl(OpAsmParser &parser,
   spirv::MemorySemantics equalSemantics, unequalSemantics;
   SmallVector<OpAsmParser::UnresolvedOperand, 3> operandInfo;
   Type type;
-  if (parseEnumStrAttr(memoryScope, parser, state, kMemoryScopeAttrName) ||
-      parseEnumStrAttr(equalSemantics, parser, state,
-                       kEqualSemanticsAttrName) ||
-      parseEnumStrAttr(unequalSemantics, parser, state,
-                       kUnequalSemanticsAttrName) ||
+  if (parseEnumStrAttr<spirv::ScopeAttr>(memoryScope, parser, state,
+                                         kMemoryScopeAttrName) ||
+      parseEnumStrAttr<spirv::MemorySemanticsAttr>(
+          equalSemantics, parser, state, kEqualSemanticsAttrName) ||
+      parseEnumStrAttr<spirv::MemorySemanticsAttr>(
+          unequalSemantics, parser, state, kUnequalSemanticsAttrName) ||
       parser.parseOperandList(operandInfo, 3))
     return failure();
 
@@ -1267,8 +1270,10 @@ ParseResult spirv::AtomicExchangeOp::parse(OpAsmParser &parser,
   spirv::MemorySemantics semantics;
   SmallVector<OpAsmParser::UnresolvedOperand, 2> operandInfo;
   Type type;
-  if (parseEnumStrAttr(memoryScope, parser, result, kMemoryScopeAttrName) ||
-      parseEnumStrAttr(semantics, parser, result, kSemanticsAttrName) ||
+  if (parseEnumStrAttr<spirv::ScopeAttr>(memoryScope, parser, result,
+                                         kMemoryScopeAttrName) ||
+      parseEnumStrAttr<spirv::MemorySemanticsAttr>(semantics, parser, result,
+                                                   kSemanticsAttrName) ||
       parser.parseOperandList(operandInfo, 2))
     return failure();
 
@@ -2075,7 +2080,7 @@ ParseResult spirv::EntryPointOp::parse(OpAsmParser &parser,
   SmallVector<Attribute, 4> interfaceVars;
 
   FlatSymbolRefAttr fn;
-  if (parseEnumStrAttr(execModel, parser, result) ||
+  if (parseEnumStrAttr<spirv::ExecutionModelAttr>(execModel, parser, result) ||
       parser.parseAttribute(fn, Type(), kFnNameAttrName, result.attributes)) {
     return failure();
   }
@@ -2132,7 +2137,7 @@ ParseResult spirv::ExecutionModeOp::parse(OpAsmParser &parser,
   spirv::ExecutionMode execMode;
   Attribute fn;
   if (parser.parseAttribute(fn, kFnNameAttrName, result.attributes) ||
-      parseEnumStrAttr(execMode, parser, result)) {
+      parseEnumStrAttr<spirv::ExecutionModeAttr>(execMode, parser, result)) {
     return failure();
   }
 
@@ -2220,7 +2225,7 @@ ParseResult spirv::FuncOp::parse(OpAsmParser &parser, OperationState &result) {
 
   // Parse the optional function control keyword.
   spirv::FunctionControl fnControl;
-  if (parseEnumStrAttr(fnControl, parser, result))
+  if (parseEnumStrAttr<spirv::FunctionControlAttr>(fnControl, parser, result))
     return failure();
 
   // If additional attributes are present, parse them.
@@ -2308,7 +2313,7 @@ void spirv::FuncOp::build(OpBuilder &builder, OperationState &state,
                      builder.getStringAttr(name));
   state.addAttribute(getTypeAttrName(), TypeAttr::get(type));
   state.addAttribute(spirv::attributeName<spirv::FunctionControl>(),
-                     builder.getI32IntegerAttr(static_cast<uint32_t>(control)));
+                     builder.getAttr<spirv::FunctionControlAttr>(control));
   state.attributes.append(attrs.begin(), attrs.end());
   state.addRegion();
 }
@@ -2997,14 +3002,14 @@ LogicalResult spirv::LoadOp::verify() {
 //===----------------------------------------------------------------------===//
 
 void spirv::LoopOp::build(OpBuilder &builder, OperationState &state) {
-  state.addAttribute("loop_control",
-                     builder.getI32IntegerAttr(
-                         static_cast<uint32_t>(spirv::LoopControl::None)));
+  state.addAttribute("loop_control", builder.getAttr<spirv::LoopControlAttr>(
+                                         spirv::LoopControl::None));
   state.addRegion();
 }
 
 ParseResult spirv::LoopOp::parse(OpAsmParser &parser, OperationState &result) {
-  if (parseControlAttribute<spirv::LoopControl>(parser, result))
+  if (parseControlAttribute<spirv::LoopControlAttr, spirv::LoopControl>(parser,
+                                                                        result))
     return failure();
   return parser.parseRegion(*result.addRegion(), /*arguments=*/{});
 }
@@ -3195,9 +3200,9 @@ void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
                             Optional<StringRef> name) {
   state.addAttribute(
       "addressing_model",
-      builder.getI32IntegerAttr(static_cast<int32_t>(addressingModel)));
-  state.addAttribute("memory_model", builder.getI32IntegerAttr(
-                                         static_cast<int32_t>(memoryModel)));
+      builder.getAttr<spirv::AddressingModelAttr>(addressingModel));
+  state.addAttribute("memory_model",
+                     builder.getAttr<spirv::MemoryModelAttr>(memoryModel));
   OpBuilder::InsertionGuard guard(builder);
   builder.createBlock(state.addRegion());
   if (vceTriple)
@@ -3219,8 +3224,10 @@ ParseResult spirv::ModuleOp::parse(OpAsmParser &parser,
   // Parse attributes
   spirv::AddressingModel addrModel;
   spirv::MemoryModel memoryModel;
-  if (::parseEnumKeywordAttr(addrModel, parser, result) ||
-      ::parseEnumKeywordAttr(memoryModel, parser, result))
+  if (::parseEnumKeywordAttr<spirv::AddressingModelAttr>(addrModel, parser,
+                                                         result) ||
+      ::parseEnumKeywordAttr<spirv::MemoryModelAttr>(memoryModel, parser,
+                                                     result))
     return failure();
 
   if (succeeded(parser.parseOptionalKeyword("requires"))) {
@@ -3401,7 +3408,8 @@ LogicalResult spirv::SelectOp::verify() {
 
 ParseResult spirv::SelectionOp::parse(OpAsmParser &parser,
                                       OperationState &result) {
-  if (parseControlAttribute<spirv::SelectionControl>(parser, result))
+  if (parseControlAttribute<spirv::SelectionControlAttr,
+                            spirv::SelectionControl>(parser, result))
     return failure();
   return parser.parseRegion(*result.addRegion(), /*arguments=*/{});
 }
@@ -3666,8 +3674,8 @@ ParseResult spirv::VariableOp::parse(OpAsmParser &parser,
       return failure();
   }
 
-  auto attr = parser.getBuilder().getI32IntegerAttr(
-      llvm::bit_cast<int32_t>(ptrType.getStorageClass()));
+  auto attr = parser.getBuilder().getAttr<spirv::StorageClassAttr>(
+      ptrType.getStorageClass());
   result.addAttribute(spirv::attributeName<spirv::StorageClass>(), attr);
 
   return success();

diff  --git a/mlir/lib/TableGen/Attribute.cpp b/mlir/lib/TableGen/Attribute.cpp
index a8240a39ffb92..0b2841c296d3f 100644
--- a/mlir/lib/TableGen/Attribute.cpp
+++ b/mlir/lib/TableGen/Attribute.cpp
@@ -132,6 +132,8 @@ Dialect Attribute::getDialect() const {
   return Dialect(nullptr);
 }
 
+const llvm::Record &Attribute::getDef() const { return *def; }
+
 ConstantAttr::ConstantAttr(const DefInit *init) : def(init->getDef()) {
   assert(def->isSubClassOf("ConstantAttr") &&
          "must be subclass of TableGen 'ConstantAttr' class");

diff  --git a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
index 9566b9ed1bbe8..c6787d79ffe7b 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
@@ -12,6 +12,7 @@
 
 #include "Deserializer.h"
 
+#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/Location.h"
@@ -406,35 +407,6 @@ Deserializer::processOp<spirv::ExecutionModeOp>(ArrayRef<uint32_t> words) {
   return success();
 }
 
-template <>
-LogicalResult
-Deserializer::processOp<spirv::ControlBarrierOp>(ArrayRef<uint32_t> operands) {
-  if (operands.size() != 3) {
-    return emitError(
-        unknownLoc,
-        "OpControlBarrier must have execution scope <id>, memory scope <id> "
-        "and memory semantics <id>");
-  }
-
-  SmallVector<IntegerAttr, 3> argAttrs;
-  for (auto operand : operands) {
-    auto argAttr = getConstantInt(operand);
-    if (!argAttr) {
-      return emitError(unknownLoc,
-                       "expected 32-bit integer constant from <id> ")
-             << operand << " for OpControlBarrier";
-    }
-    argAttrs.push_back(argAttr);
-  }
-
-  opBuilder.create<spirv::ControlBarrierOp>(
-      unknownLoc, argAttrs[0].cast<spirv::ScopeAttr>(),
-      argAttrs[1].cast<spirv::ScopeAttr>(),
-      argAttrs[2].cast<spirv::MemorySemanticsAttr>());
-
-  return success();
-}
-
 template <>
 LogicalResult
 Deserializer::processOp<spirv::FunctionCallOp>(ArrayRef<uint32_t> operands) {
@@ -477,31 +449,6 @@ Deserializer::processOp<spirv::FunctionCallOp>(ArrayRef<uint32_t> operands) {
   return success();
 }
 
-template <>
-LogicalResult
-Deserializer::processOp<spirv::MemoryBarrierOp>(ArrayRef<uint32_t> operands) {
-  if (operands.size() != 2) {
-    return emitError(unknownLoc, "OpMemoryBarrier must have memory scope <id> "
-                                 "and memory semantics <id>");
-  }
-
-  SmallVector<IntegerAttr, 2> argAttrs;
-  for (auto operand : operands) {
-    auto argAttr = getConstantInt(operand);
-    if (!argAttr) {
-      return emitError(unknownLoc,
-                       "expected 32-bit integer constant from <id> ")
-             << operand << " for OpMemoryBarrier";
-    }
-    argAttrs.push_back(argAttr);
-  }
-
-  opBuilder.create<spirv::MemoryBarrierOp>(
-      unknownLoc, argAttrs[0].cast<spirv::ScopeAttr>(),
-      argAttrs[1].cast<spirv::MemorySemanticsAttr>());
-  return success();
-}
-
 template <>
 LogicalResult
 Deserializer::processOp<spirv::CopyMemoryOp>(ArrayRef<uint32_t> words) {
@@ -538,8 +485,9 @@ Deserializer::processOp<spirv::CopyMemoryOp>(ArrayRef<uint32_t> words) {
 
   if (wordIndex < words.size()) {
     auto attrValue = words[wordIndex++];
-    attributes.push_back(opBuilder.getNamedAttr(
-        "memory_access", opBuilder.getI32IntegerAttr(attrValue)));
+    auto attr = opBuilder.getAttr<spirv::MemoryAccessAttr>(
+        static_cast<spirv::MemoryAccess>(attrValue));
+    attributes.push_back(opBuilder.getNamedAttr("memory_access", attr));
     isAlignedAttr = (attrValue == 2);
   }
 
@@ -549,9 +497,10 @@ Deserializer::processOp<spirv::CopyMemoryOp>(ArrayRef<uint32_t> words) {
   }
 
   if (wordIndex < words.size()) {
-    attributes.push_back(opBuilder.getNamedAttr(
-        "source_memory_access",
-        opBuilder.getI32IntegerAttr(words[wordIndex++])));
+    auto attrValue = words[wordIndex++];
+    auto attr = opBuilder.getAttr<spirv::MemoryAccessAttr>(
+        static_cast<spirv::MemoryAccess>(attrValue));
+    attributes.push_back(opBuilder.getNamedAttr("source_memory_access", attr));
   }
 
   if (wordIndex < words.size()) {

diff  --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 383bdc4c905a9..e4cfc4b380e46 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -216,10 +216,11 @@ spirv::Deserializer::processMemoryModel(ArrayRef<uint32_t> operands) {
 
   (*module)->setAttr(
       "addressing_model",
-      opBuilder.getI32IntegerAttr(llvm::bit_cast<int32_t>(operands.front())));
-  (*module)->setAttr(
-      "memory_model",
-      opBuilder.getI32IntegerAttr(llvm::bit_cast<int32_t>(operands.back())));
+      opBuilder.getAttr<spirv::AddressingModelAttr>(
+          static_cast<spirv::AddressingModel>(operands.front())));
+  (*module)->setAttr("memory_model",
+                     opBuilder.getAttr<spirv::MemoryModelAttr>(
+                         static_cast<spirv::MemoryModel>(operands.back())));
 
   return success();
 }

diff  --git a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
index ad1724f9269f0..22fff80440048 100644
--- a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
@@ -13,6 +13,7 @@
 #include "Serializer.h"
 
 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
 #include "mlir/IR/RegionGraphTraits.h"
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Target/SPIRV/SPIRVBinaryUtils.h"
@@ -277,8 +278,8 @@ LogicalResult Serializer::processVariableOp(spirv::VariableOp op) {
   operands.push_back(resultID);
   auto attr = op->getAttr(spirv::attributeName<spirv::StorageClass>());
   if (attr) {
-    operands.push_back(static_cast<uint32_t>(
-        attr.cast<IntegerAttr>().getValue().getZExtValue()));
+    operands.push_back(
+        static_cast<uint32_t>(attr.cast<spirv::StorageClassAttr>().getValue()));
   }
   elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());
   for (auto arg : op.getODSOperands(0)) {
@@ -565,27 +566,6 @@ Serializer::processOp<spirv::EntryPointOp>(spirv::EntryPointOp op) {
   return success();
 }
 
-template <>
-LogicalResult
-Serializer::processOp<spirv::ControlBarrierOp>(spirv::ControlBarrierOp op) {
-  StringRef argNames[] = {"execution_scope", "memory_scope",
-                          "memory_semantics"};
-  SmallVector<uint32_t, 3> operands;
-
-  for (auto argName : argNames) {
-    auto argIntAttr = op->getAttrOfType<IntegerAttr>(argName);
-    auto operand = prepareConstantInt(op.getLoc(), argIntAttr);
-    if (!operand) {
-      return failure();
-    }
-    operands.push_back(operand);
-  }
-
-  encodeInstructionInto(functionBody, spirv::Opcode::OpControlBarrier,
-                        operands);
-  return success();
-}
-
 template <>
 LogicalResult
 Serializer::processOp<spirv::ExecutionModeOp>(spirv::ExecutionModeOp op) {
@@ -615,25 +595,6 @@ Serializer::processOp<spirv::ExecutionModeOp>(spirv::ExecutionModeOp op) {
   return success();
 }
 
-template <>
-LogicalResult
-Serializer::processOp<spirv::MemoryBarrierOp>(spirv::MemoryBarrierOp op) {
-  StringRef argNames[] = {"memory_scope", "memory_semantics"};
-  SmallVector<uint32_t, 2> operands;
-
-  for (auto argName : argNames) {
-    auto argIntAttr = op->getAttrOfType<IntegerAttr>(argName);
-    auto operand = prepareConstantInt(op.getLoc(), argIntAttr);
-    if (!operand) {
-      return failure();
-    }
-    operands.push_back(operand);
-  }
-
-  encodeInstructionInto(functionBody, spirv::Opcode::OpMemoryBarrier, operands);
-  return success();
-}
-
 template <>
 LogicalResult
 Serializer::processOp<spirv::FunctionCallOp>(spirv::FunctionCallOp op) {
@@ -674,8 +635,8 @@ Serializer::processOp<spirv::CopyMemoryOp>(spirv::CopyMemoryOp op) {
   }
 
   if (auto attr = op->getAttr("memory_access")) {
-    operands.push_back(static_cast<uint32_t>(
-        attr.cast<IntegerAttr>().getValue().getZExtValue()));
+    operands.push_back(
+        static_cast<uint32_t>(attr.cast<spirv::MemoryAccessAttr>().getValue()));
   }
 
   elidedAttrs.push_back("memory_access");
@@ -688,8 +649,8 @@ Serializer::processOp<spirv::CopyMemoryOp>(spirv::CopyMemoryOp op) {
   elidedAttrs.push_back("alignment");
 
   if (auto attr = op->getAttr("source_memory_access")) {
-    operands.push_back(static_cast<uint32_t>(
-        attr.cast<IntegerAttr>().getValue().getZExtValue()));
+    operands.push_back(
+        static_cast<uint32_t>(attr.cast<spirv::MemoryAccessAttr>().getValue()));
   }
 
   elidedAttrs.push_back("source_memory_access");

diff  --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index 47542a41bbb71..b8be9433e94e8 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -14,6 +14,7 @@
 
 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Target/SPIRV/SPIRVBinaryUtils.h"
@@ -23,6 +24,7 @@
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/ADT/bit.h"
 #include "llvm/Support/Debug.h"
+#include <cstdint>
 
 #define DEBUG_TYPE "spirv-serialization"
 
@@ -192,8 +194,11 @@ void Serializer::processExtension() {
 }
 
 void Serializer::processMemoryModel() {
-  uint32_t mm = module->getAttrOfType<IntegerAttr>("memory_model").getInt();
-  uint32_t am = module->getAttrOfType<IntegerAttr>("addressing_model").getInt();
+  auto mm = static_cast<uint32_t>( // Emitted 2nd: operands are {am, mm}.
+      module->getAttrOfType<spirv::MemoryModelAttr>("memory_model").getValue());
+  auto am = static_cast<uint32_t>( // Emitted 1st, ahead of mm.
+      module->getAttrOfType<spirv::AddressingModelAttr>("addressing_model")
+          .getValue());

   encodeInstructionInto(memoryModel, spirv::Opcode::OpMemoryModel, {am, mm});
 }

diff  --git a/mlir/test/Conversion/GPUToSPIRV/simple.mlir b/mlir/test/Conversion/GPUToSPIRV/simple.mlir
index 5d1e9168bf607..c6b45ba491905 100644
--- a/mlir/test/Conversion/GPUToSPIRV/simple.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/simple.mlir
@@ -112,7 +112,7 @@ module attributes {gpu.container_module} {
     // CHECK-LABEL: spv.func @barrier
     gpu.func @barrier(%arg0 : f32, %arg1 : memref<12xf32>) kernel
       attributes {spv.entry_point_abi = #spv.entry_point_abi<local_size = dense<[32, 4, 1]>: vector<3xi32>>} {
-      // CHECK: spv.ControlBarrier Workgroup, Workgroup, "AcquireRelease|WorkgroupMemory"
+      // CHECK: spv.ControlBarrier <Workgroup>, <Workgroup>, <AcquireRelease|WorkgroupMemory>
       gpu.barrier
       gpu.return
     }

diff  --git a/mlir/test/Conversion/LinalgToSPIRV/linalg-to-spirv.mlir b/mlir/test/Conversion/LinalgToSPIRV/linalg-to-spirv.mlir
index 1a7ef8e0b32c8..3b0af88a299e6 100644
--- a/mlir/test/Conversion/LinalgToSPIRV/linalg-to-spirv.mlir
+++ b/mlir/test/Conversion/LinalgToSPIRV/linalg-to-spirv.mlir
@@ -32,7 +32,7 @@ module attributes {
 // CHECK:        %[[ADD:.+]] = spv.GroupNonUniformIAdd "Subgroup" "Reduce" %[[VAL]] : i32
 
 // CHECK:        %[[OUTPTR:.+]] = spv.AccessChain %[[OUTPUT]][%[[ZERO]], %[[ZERO]]]
-// CHECK:        %[[ELECT:.+]] = spv.GroupNonUniformElect Subgroup : i1
+// CHECK:        %[[ELECT:.+]] = spv.GroupNonUniformElect <Subgroup> : i1
 
 // CHECK:        spv.mlir.selection {
 // CHECK:          spv.BranchConditional %[[ELECT]], ^bb1, ^bb2

diff  --git a/mlir/test/Conversion/MemRefToSPIRV/map-storage-class.mlir b/mlir/test/Conversion/MemRefToSPIRV/map-storage-class.mlir
index 5f474208ec47f..1e3908618e986 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/map-storage-class.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/map-storage-class.mlir
@@ -1,32 +1,30 @@
 // RUN: mlir-opt -split-input-file -allow-unregistered-dialect -map-memref-spirv-storage-class='client-api=vulkan' -verify-diagnostics %s -o - | FileCheck %s --check-prefix=VULKAN
 
 // Vulkan Mappings:
-//   0 -> StorageBuffer (12)
-//   1 -> Generic (8)
-//   3 -> Workgroup (4)
-//   4 -> Uniform (2)
-// TODO: create a StorageClass wrapper class so we can print the symbolc
-// storage class (instead of the backing IntegerAttr) and be able to
-// round trip the IR.
+//   0 -> StorageBuffer
+//   1 -> Generic
+//   2 -> [null]
+//   3 -> Workgroup
+//   4 -> Uniform
 
 // VULKAN-LABEL: func @operand_result
 func.func @operand_result() {
-  // VULKAN: memref<f32, 12 : i32>
+  // VULKAN: memref<f32, #spv.storage_class<StorageBuffer>>
   %0 = "dialect.memref_producer"() : () -> (memref<f32>)
-  // VULKAN: memref<4xi32, 8 : i32>
+  // VULKAN: memref<4xi32, #spv.storage_class<Generic>>
   %1 = "dialect.memref_producer"() : () -> (memref<4xi32, 1>)
-  // VULKAN: memref<?x4xf16, 4 : i32>
+  // VULKAN: memref<?x4xf16, #spv.storage_class<Workgroup>>
   %2 = "dialect.memref_producer"() : () -> (memref<?x4xf16, 3>)
-  // VULKAN: memref<*xf16, 2 : i32>
+  // VULKAN: memref<*xf16, #spv.storage_class<Uniform>>
   %3 = "dialect.memref_producer"() : () -> (memref<*xf16, 4>)
 
 
   "dialect.memref_consumer"(%0) : (memref<f32>) -> ()
-  // VULKAN: memref<4xi32, 8 : i32>
+  // VULKAN: memref<4xi32, #spv.storage_class<Generic>>
   "dialect.memref_consumer"(%1) : (memref<4xi32, 1>) -> ()
-  // VULKAN: memref<?x4xf16, 4 : i32>
+  // VULKAN: memref<?x4xf16, #spv.storage_class<Workgroup>>
   "dialect.memref_consumer"(%2) : (memref<?x4xf16, 3>) -> ()
-  // VULKAN: memref<*xf16, 2 : i32>
+  // VULKAN: memref<*xf16, #spv.storage_class<Uniform>>
   "dialect.memref_consumer"(%3) : (memref<*xf16, 4>) -> ()
 
   return
@@ -36,7 +34,7 @@ func.func @operand_result() {
 
 // VULKAN-LABEL: func @type_attribute
 func.func @type_attribute() {
-  // VULKAN: attr = memref<i32, 8 : i32>
+  // VULKAN: attr = memref<i32, #spv.storage_class<Generic>>
   "dialect.memref_producer"() { attr = memref<i32, 1> } : () -> ()
   return
 }
@@ -45,9 +43,9 @@ func.func @type_attribute() {
 
 // VULKAN-LABEL: func @function_io
 func.func @function_io
-  // VULKAN-SAME: (%{{.+}}: memref<f64, 8 : i32>, %{{.+}}: memref<4xi32, 4 : i32>)
+  // VULKAN-SAME: (%{{.+}}: memref<f64, #spv.storage_class<Generic>>, %{{.+}}: memref<4xi32, #spv.storage_class<Workgroup>>)
   (%arg0: memref<f64, 1>, %arg1: memref<4xi32, 3>)
-  // VULKAN-SAME: -> (memref<f64, 8 : i32>, memref<4xi32, 4 : i32>)
+  // VULKAN-SAME: -> (memref<f64, #spv.storage_class<Generic>>, memref<4xi32, #spv.storage_class<Workgroup>>)
   -> (memref<f64, 1>, memref<4xi32, 3>) {
   return %arg0, %arg1: memref<f64, 1>, memref<4xi32, 3>
 }
@@ -57,8 +55,8 @@ func.func @function_io
 // VULKAN: func @region
 func.func @region(%cond: i1, %arg0: memref<f32, 1>) {
   scf.if %cond {
-    //      VULKAN: "dialect.memref_consumer"(%{{.+}}) {attr = memref<i64, 4 : i32>}
-    // VULKAN-SAME: (memref<f32, 8 : i32>) -> memref<f32, 8 : i32>
+    //      VULKAN: "dialect.memref_consumer"(%{{.+}}) {attr = memref<i64, #spv.storage_class<Workgroup>>}
+    // VULKAN-SAME: (memref<f32, #spv.storage_class<Generic>>) -> memref<f32, #spv.storage_class<Generic>>
     %0 = "dialect.memref_consumer"(%arg0) { attr = memref<i64, 3> } : (memref<f32, 1>) -> (memref<f32, 1>)
   }
   return

diff  --git a/mlir/test/Dialect/SPIRV/IR/atomic-ops.mlir b/mlir/test/Dialect/SPIRV/IR/atomic-ops.mlir
index ed7e8bc72c8ac..9889422fa31b8 100644
--- a/mlir/test/Dialect/SPIRV/IR/atomic-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/atomic-ops.mlir
@@ -14,7 +14,7 @@ func.func @atomic_and(%ptr : !spv.ptr<i32, StorageBuffer>, %value : i32) -> i32
 
 func.func @atomic_and(%ptr : !spv.ptr<f32, StorageBuffer>, %value : i32) -> i32 {
   // expected-error @+1 {{pointer operand must point to an integer value, found 'f32'}}
-  %0 = "spv.AtomicAnd"(%ptr, %value) {memory_scope = 4: i32, semantics = 0x4 : i32} : (!spv.ptr<f32, StorageBuffer>, i32) -> (i32)
+  %0 = "spv.AtomicAnd"(%ptr, %value) {memory_scope = #spv.scope<Workgroup>, semantics = #spv.memory_semantics<AcquireRelease>} : (!spv.ptr<f32, StorageBuffer>, i32) -> (i32)
   return %0 : i32
 }
 
@@ -23,7 +23,7 @@ func.func @atomic_and(%ptr : !spv.ptr<f32, StorageBuffer>, %value : i32) -> i32
 
 func.func @atomic_and(%ptr : !spv.ptr<i32, StorageBuffer>, %value : i64) -> i64 {
   // expected-error @+1 {{expected value to have the same type as the pointer operand's pointee type 'i32', but found 'i64'}}
-  %0 = "spv.AtomicAnd"(%ptr, %value) {memory_scope = 2: i32, semantics = 0x8 : i32} : (!spv.ptr<i32, StorageBuffer>, i64) -> (i64)
+  %0 = "spv.AtomicAnd"(%ptr, %value) {memory_scope = #spv.scope<Workgroup>, semantics = #spv.memory_semantics<AcquireRelease>} : (!spv.ptr<i32, StorageBuffer>, i64) -> (i64)
   return %0 : i64
 }
 
@@ -51,7 +51,7 @@ func.func @atomic_compare_exchange(%ptr: !spv.ptr<i32, Workgroup>, %value: i32,
 
 func.func @atomic_compare_exchange(%ptr: !spv.ptr<i32, Workgroup>, %value: i64, %comparator: i32) -> i32 {
   // expected-error @+1 {{value operand must have the same type as the op result, but found 'i64' vs 'i32'}}
-  %0 = "spv.AtomicCompareExchange"(%ptr, %value, %comparator) {memory_scope = 4: i32, equal_semantics = 0x4: i32, unequal_semantics = 0x2:i32} : (!spv.ptr<i32, Workgroup>, i64, i32) -> (i32)
+  %0 = "spv.AtomicCompareExchange"(%ptr, %value, %comparator) {memory_scope = #spv.scope<Workgroup>, equal_semantics = #spv.memory_semantics<AcquireRelease>, unequal_semantics = #spv.memory_semantics<AcquireRelease>} : (!spv.ptr<i32, Workgroup>, i64, i32) -> (i32)
   return %0: i32
 }
 
@@ -59,7 +59,7 @@ func.func @atomic_compare_exchange(%ptr: !spv.ptr<i32, Workgroup>, %value: i64,
 
 func.func @atomic_compare_exchange(%ptr: !spv.ptr<i32, Workgroup>, %value: i32, %comparator: i16) -> i32 {
   // expected-error @+1 {{comparator operand must have the same type as the op result, but found 'i16' vs 'i32'}}
-  %0 = "spv.AtomicCompareExchange"(%ptr, %value, %comparator) {memory_scope = 4: i32, equal_semantics = 0x4: i32, unequal_semantics = 0x2:i32} : (!spv.ptr<i32, Workgroup>, i32, i16) -> (i32)
+  %0 = "spv.AtomicCompareExchange"(%ptr, %value, %comparator) {memory_scope = #spv.scope<Workgroup>, equal_semantics = #spv.memory_semantics<AcquireRelease>, unequal_semantics = #spv.memory_semantics<AcquireRelease>} : (!spv.ptr<i32, Workgroup>, i32, i16) -> (i32)
   return %0: i32
 }
 
@@ -67,7 +67,7 @@ func.func @atomic_compare_exchange(%ptr: !spv.ptr<i32, Workgroup>, %value: i32,
 
 func.func @atomic_compare_exchange(%ptr: !spv.ptr<i64, Workgroup>, %value: i32, %comparator: i32) -> i32 {
   // expected-error @+1 {{pointer operand's pointee type must have the same as the op result type, but found 'i64' vs 'i32'}}
-  %0 = "spv.AtomicCompareExchange"(%ptr, %value, %comparator) {memory_scope = 4: i32, equal_semantics = 0x4: i32, unequal_semantics = 0x2:i32} : (!spv.ptr<i64, Workgroup>, i32, i32) -> (i32)
+  %0 = "spv.AtomicCompareExchange"(%ptr, %value, %comparator) {memory_scope = #spv.scope<Workgroup>, equal_semantics = #spv.memory_semantics<AcquireRelease>, unequal_semantics = #spv.memory_semantics<AcquireRelease>} : (!spv.ptr<i64, Workgroup>, i32, i32) -> (i32)
   return %0: i32
 }
 
@@ -87,7 +87,7 @@ func.func @atomic_compare_exchange_weak(%ptr: !spv.ptr<i32, Workgroup>, %value:
 
 func.func @atomic_compare_exchange_weak(%ptr: !spv.ptr<i32, Workgroup>, %value: i64, %comparator: i32) -> i32 {
   // expected-error @+1 {{value operand must have the same type as the op result, but found 'i64' vs 'i32'}}
-  %0 = "spv.AtomicCompareExchangeWeak"(%ptr, %value, %comparator) {memory_scope = 4: i32, equal_semantics = 0x4: i32, unequal_semantics = 0x2:i32} : (!spv.ptr<i32, Workgroup>, i64, i32) -> (i32)
+  %0 = "spv.AtomicCompareExchangeWeak"(%ptr, %value, %comparator) {memory_scope = #spv.scope<Workgroup>, equal_semantics = #spv.memory_semantics<AcquireRelease>, unequal_semantics = #spv.memory_semantics<AcquireRelease>} : (!spv.ptr<i32, Workgroup>, i64, i32) -> (i32)
   return %0: i32
 }
 
@@ -95,7 +95,7 @@ func.func @atomic_compare_exchange_weak(%ptr: !spv.ptr<i32, Workgroup>, %value:
 
 func.func @atomic_compare_exchange_weak(%ptr: !spv.ptr<i32, Workgroup>, %value: i32, %comparator: i16) -> i32 {
   // expected-error @+1 {{comparator operand must have the same type as the op result, but found 'i16' vs 'i32'}}
-  %0 = "spv.AtomicCompareExchangeWeak"(%ptr, %value, %comparator) {memory_scope = 4: i32, equal_semantics = 0x4: i32, unequal_semantics = 0x2:i32} : (!spv.ptr<i32, Workgroup>, i32, i16) -> (i32)
+  %0 = "spv.AtomicCompareExchangeWeak"(%ptr, %value, %comparator) {memory_scope = #spv.scope<Workgroup>, equal_semantics = #spv.memory_semantics<AcquireRelease>, unequal_semantics = #spv.memory_semantics<AcquireRelease>} : (!spv.ptr<i32, Workgroup>, i32, i16) -> (i32)
   return %0: i32
 }
 
@@ -103,7 +103,7 @@ func.func @atomic_compare_exchange_weak(%ptr: !spv.ptr<i32, Workgroup>, %value:
 
 func.func @atomic_compare_exchange_weak(%ptr: !spv.ptr<i64, Workgroup>, %value: i32, %comparator: i32) -> i32 {
   // expected-error @+1 {{pointer operand's pointee type must have the same as the op result type, but found 'i64' vs 'i32'}}
-  %0 = "spv.AtomicCompareExchangeWeak"(%ptr, %value, %comparator) {memory_scope = 4: i32, equal_semantics = 0x4: i32, unequal_semantics = 0x2:i32} : (!spv.ptr<i64, Workgroup>, i32, i32) -> (i32)
+  %0 = "spv.AtomicCompareExchangeWeak"(%ptr, %value, %comparator) {memory_scope = #spv.scope<Workgroup>, equal_semantics = #spv.memory_semantics<AcquireRelease>, unequal_semantics = #spv.memory_semantics<AcquireRelease>} : (!spv.ptr<i64, Workgroup>, i32, i32) -> (i32)
   return %0: i32
 }
 
@@ -123,7 +123,7 @@ func.func @atomic_exchange(%ptr: !spv.ptr<i32, Workgroup>, %value: i32) -> i32 {
 
 func.func @atomic_exchange(%ptr: !spv.ptr<i32, Workgroup>, %value: i64) -> i32 {
   // expected-error @+1 {{value operand must have the same type as the op result, but found 'i64' vs 'i32'}}
-  %0 = "spv.AtomicExchange"(%ptr, %value) {memory_scope = 4: i32, semantics = 0x4: i32} : (!spv.ptr<i32, Workgroup>, i64) -> (i32)
+  %0 = "spv.AtomicExchange"(%ptr, %value) {memory_scope = #spv.scope<Workgroup>, semantics = #spv.memory_semantics<AcquireRelease>} : (!spv.ptr<i32, Workgroup>, i64) -> (i32)
   return %0: i32
 }
 
@@ -131,7 +131,7 @@ func.func @atomic_exchange(%ptr: !spv.ptr<i32, Workgroup>, %value: i64) -> i32 {
 
 func.func @atomic_exchange(%ptr: !spv.ptr<i64, Workgroup>, %value: i32) -> i32 {
   // expected-error @+1 {{pointer operand's pointee type must have the same as the op result type, but found 'i64' vs 'i32'}}
-  %0 = "spv.AtomicExchange"(%ptr, %value) {memory_scope = 4: i32, semantics = 0x4: i32} : (!spv.ptr<i64, Workgroup>, i32) -> (i32)
+  %0 = "spv.AtomicExchange"(%ptr, %value) {memory_scope = #spv.scope<Workgroup>, semantics = #spv.memory_semantics<AcquireRelease>} : (!spv.ptr<i64, Workgroup>, i32) -> (i32)
   return %0: i32
 }
 
@@ -253,7 +253,7 @@ func.func @atomic_fadd(%ptr : !spv.ptr<f32, StorageBuffer>, %value : f32) -> f32
 
 func.func @atomic_fadd(%ptr : !spv.ptr<i32, StorageBuffer>, %value : f32) -> f32 {
   // expected-error @+1 {{pointer operand must point to an float value, found 'i32'}}
-  %0 = "spv.AtomicFAddEXT"(%ptr, %value) {memory_scope = 4: i32, semantics = 0x4 : i32} : (!spv.ptr<i32, StorageBuffer>, f32) -> (f32)
+  %0 = "spv.AtomicFAddEXT"(%ptr, %value) {memory_scope = #spv.scope<Workgroup>, semantics = #spv.memory_semantics<AcquireRelease>} : (!spv.ptr<i32, StorageBuffer>, f32) -> (f32)
   return %0 : f32
 }
 
@@ -261,7 +261,7 @@ func.func @atomic_fadd(%ptr : !spv.ptr<i32, StorageBuffer>, %value : f32) -> f32
 
 func.func @atomic_fadd(%ptr : !spv.ptr<f32, StorageBuffer>, %value : f64) -> f64 {
   // expected-error @+1 {{expected value to have the same type as the pointer operand's pointee type 'f32', but found 'f64'}}
-  %0 = "spv.AtomicFAddEXT"(%ptr, %value) {memory_scope = 2: i32, semantics = 0x8 : i32} : (!spv.ptr<f32, StorageBuffer>, f64) -> (f64)
+  %0 = "spv.AtomicFAddEXT"(%ptr, %value) {memory_scope = #spv.scope<Device>, semantics = #spv.memory_semantics<AcquireRelease>} : (!spv.ptr<f32, StorageBuffer>, f64) -> (f64)
   return %0 : f64
 }
 

diff  --git a/mlir/test/Dialect/SPIRV/IR/availability.mlir b/mlir/test/Dialect/SPIRV/IR/availability.mlir
index 089b70cf949bb..810fe53faa953 100644
--- a/mlir/test/Dialect/SPIRV/IR/availability.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/availability.mlir
@@ -26,7 +26,7 @@ func.func @subgroup_ballot(%predicate: i1) -> vector<4xi32> {
   // CHECK: max version: v1.6
   // CHECK: extensions: [ ]
   // CHECK: capabilities: [ [GroupNonUniformBallot] ]
-  %0 = spv.GroupNonUniformBallot Workgroup %predicate : vector<4xi32>
+  %0 = spv.GroupNonUniformBallot <Workgroup> %predicate : vector<4xi32>
   return %0: vector<4xi32>
 }
 

diff  --git a/mlir/test/Dialect/SPIRV/IR/barrier-ops.mlir b/mlir/test/Dialect/SPIRV/IR/barrier-ops.mlir
index 45d0a7430244e..931426f32848b 100644
--- a/mlir/test/Dialect/SPIRV/IR/barrier-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/barrier-ops.mlir
@@ -5,16 +5,17 @@
 //===----------------------------------------------------------------------===//
 
 func.func @control_barrier_0() -> () {
-  // CHECK:  spv.ControlBarrier Workgroup, Device, "Acquire|UniformMemory"
-  spv.ControlBarrier Workgroup, Device, "Acquire|UniformMemory"
+  // CHECK: spv.ControlBarrier <Workgroup>, <Device>, <Acquire|UniformMemory>
+  spv.ControlBarrier <Workgroup>, <Device>, <Acquire|UniformMemory>
   return
 }
 
 // -----
 
 func.func @control_barrier_1() -> () {
-  // expected-error @+1 {{expected string or keyword containing one of the following enum values}}
-  spv.ControlBarrier Something, Device, "Acquire|UniformMemory"
+  // expected-error @+2 {{to be one of}}
+  // expected-error @+1 {{failed to parse SPV_ScopeAttr}}
+  spv.ControlBarrier <Something>, <Device>, <Acquire|UniformMemory>
   return
 }
 
@@ -26,16 +27,16 @@ func.func @control_barrier_1() -> () {
 //===----------------------------------------------------------------------===//
 
 func.func @memory_barrier_0() -> () {
-  // CHECK: spv.MemoryBarrier Device, "Acquire|UniformMemory"
-  spv.MemoryBarrier Device, "Acquire|UniformMemory"
+  // CHECK: spv.MemoryBarrier <Device>, <Acquire|UniformMemory>
+  spv.MemoryBarrier <Device>, <Acquire|UniformMemory>
   return
 }
 
 // -----
 
 func.func @memory_barrier_1() -> () {
-  // CHECK: spv.MemoryBarrier Workgroup, Acquire
-  spv.MemoryBarrier Workgroup, Acquire
+  // CHECK: spv.MemoryBarrier <Workgroup>, <Acquire>
+  spv.MemoryBarrier <Workgroup>, <Acquire>
   return
 }
 
@@ -43,7 +44,7 @@ func.func @memory_barrier_1() -> () {
 
 func.func @memory_barrier_2() -> () {
  // expected-error @+1 {{expected at most one of these four memory constraints to be set: `Acquire`, `Release`,`AcquireRelease` or `SequentiallyConsistent`}}
-  spv.MemoryBarrier Device, "Acquire|Release"
+  spv.MemoryBarrier <Device>, <Acquire|Release>
   return
 }
 

diff  --git a/mlir/test/Dialect/SPIRV/IR/group-ops.mlir b/mlir/test/Dialect/SPIRV/IR/group-ops.mlir
index 103e41016648b..a62f6fffa1616 100644
--- a/mlir/test/Dialect/SPIRV/IR/group-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/group-ops.mlir
@@ -17,24 +17,24 @@ func.func @subgroup_ballot(%predicate: i1) -> vector<4xi32> {
 //===----------------------------------------------------------------------===//
 
 func.func @group_broadcast_scalar(%value: f32, %localid: i32 ) -> f32 {
-  // CHECK: spv.GroupBroadcast Workgroup %{{.*}}, %{{.*}} : f32, i32
-  %0 = spv.GroupBroadcast Workgroup %value, %localid : f32, i32
+  // CHECK: spv.GroupBroadcast <Workgroup> %{{.*}}, %{{.*}} : f32, i32
+  %0 = spv.GroupBroadcast <Workgroup> %value, %localid : f32, i32
   return %0: f32
 }
 
 // -----
 
 func.func @group_broadcast_scalar_vector(%value: f32, %localid: vector<3xi32> ) -> f32 {
-  // CHECK: spv.GroupBroadcast Workgroup %{{.*}}, %{{.*}} : f32, vector<3xi32>
-  %0 = spv.GroupBroadcast Workgroup %value, %localid : f32, vector<3xi32>
+  // CHECK: spv.GroupBroadcast <Workgroup> %{{.*}}, %{{.*}} : f32, vector<3xi32>
+  %0 = spv.GroupBroadcast <Workgroup> %value, %localid : f32, vector<3xi32>
   return %0: f32
 }
 
 // -----
 
 func.func @group_broadcast_vector(%value: vector<4xf32>, %localid: vector<3xi32> ) -> vector<4xf32> {
-  // CHECK: spv.GroupBroadcast Subgroup %{{.*}}, %{{.*}} : vector<4xf32>, vector<3xi32>
-  %0 = spv.GroupBroadcast Subgroup %value, %localid : vector<4xf32>, vector<3xi32>
+  // CHECK: spv.GroupBroadcast <Subgroup> %{{.*}}, %{{.*}} : vector<4xf32>, vector<3xi32>
+  %0 = spv.GroupBroadcast <Subgroup> %value, %localid : vector<4xf32>, vector<3xi32>
   return %0: vector<4xf32>
 }
 
@@ -42,7 +42,7 @@ func.func @group_broadcast_vector(%value: vector<4xf32>, %localid: vector<3xi32>
 
 func.func @group_broadcast_negative_scope(%value: f32, %localid: vector<3xi32> ) -> f32 {
   // expected-error @+1 {{execution scope must be 'Workgroup' or 'Subgroup'}} 
-  %0 = spv.GroupBroadcast Device %value, %localid : f32, vector<3xi32>
+  %0 = spv.GroupBroadcast <Device> %value, %localid : f32, vector<3xi32>
   return %0: f32
 }
 
@@ -50,7 +50,7 @@ func.func @group_broadcast_negative_scope(%value: f32, %localid: vector<3xi32> )
 
 func.func @group_broadcast_negative_locid_dtype(%value: f32, %localid: vector<3xf32> ) -> f32 {
   // expected-error @+1 {{operand #1 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values}}
-  %0 = spv.GroupBroadcast Subgroup %value, %localid : f32, vector<3xf32>
+  %0 = spv.GroupBroadcast <Subgroup> %value, %localid : f32, vector<3xf32>
   return %0: f32
 }
 
@@ -58,7 +58,7 @@ func.func @group_broadcast_negative_locid_dtype(%value: f32, %localid: vector<3x
 
 func.func @group_broadcast_negative_locid_vec4(%value: f32, %localid: vector<4xi32> ) -> f32 {
   // expected-error @+1 {{localid is a vector and can be with only  2 or 3 components, actual number is 4}}
-  %0 = spv.GroupBroadcast Subgroup %value, %localid : f32, vector<4xi32>
+  %0 = spv.GroupBroadcast <Subgroup> %value, %localid : f32, vector<4xi32>
   return %0: f32
 }
 

diff  --git a/mlir/test/Dialect/SPIRV/IR/memory-ops.mlir b/mlir/test/Dialect/SPIRV/IR/memory-ops.mlir
index 58e72570ed497..f2b3979ef5f6f 100644
--- a/mlir/test/Dialect/SPIRV/IR/memory-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/memory-ops.mlir
@@ -198,7 +198,7 @@ func.func @load_none_access() -> () {
   %0 = spv.Variable : !spv.ptr<f32, Function>
   // CHECK: spv.Load
   // CHECK-SAME: ["None"]
-  %1 = "spv.Load"(%0) {memory_access = 0 : i32} : (!spv.ptr<f32, Function>) -> (f32)
+  %1 = "spv.Load"(%0) {memory_access = #spv.memory_access<None>} : (!spv.ptr<f32, Function>) -> (f32)
   return
 }
 
@@ -207,7 +207,7 @@ func.func @volatile_load() -> () {
   %0 = spv.Variable : !spv.ptr<f32, Function>
   // CHECK: spv.Load
   // CHECK-SAME: ["Volatile"]
-  %1 = "spv.Load"(%0) {memory_access = 1 : i32} : (!spv.ptr<f32, Function>) -> (f32)
+  %1 = "spv.Load"(%0) {memory_access = #spv.memory_access<Volatile>} : (!spv.ptr<f32, Function>) -> (f32)
   return
 }
 
@@ -216,7 +216,7 @@ func.func @aligned_load() -> () {
   %0 = spv.Variable : !spv.ptr<f32, Function>
   // CHECK: spv.Load
   // CHECK-SAME: ["Aligned", 4]
-  %1 = "spv.Load"(%0) {memory_access = 2 : i32, alignment = 4 : i32} : (!spv.ptr<f32, Function>) -> (f32)
+  %1 = "spv.Load"(%0) {memory_access = #spv.memory_access<Aligned>, alignment = 4 : i32} : (!spv.ptr<f32, Function>) -> (f32)
   return
 }
 
@@ -225,7 +225,7 @@ func.func @volatile_aligned_load() -> () {
   %0 = spv.Variable : !spv.ptr<f32, Function>
   // CHECK: spv.Load
   // CHECK-SAME: ["Volatile|Aligned", 4]
-  %1 = "spv.Load"(%0) {memory_access = 3 : i32, alignment = 4 : i32} : (!spv.ptr<f32, Function>) -> (f32)
+  %1 = "spv.Load"(%0) {memory_access = #spv.memory_access<Volatile|Aligned>, alignment = 4 : i32} : (!spv.ptr<f32, Function>) -> (f32)
   return
 }
 
@@ -588,7 +588,7 @@ func.func @copy_memory_invalid_maa() {
   %0 = spv.Variable : !spv.ptr<f32, Function>
   %1 = spv.Variable : !spv.ptr<f32, Function>
   // expected-error @+1 {{missing alignment value}}
-  "spv.CopyMemory"(%0, %1) {memory_access=0x0002 : i32} : (!spv.ptr<f32, Function>, !spv.ptr<f32, Function>) -> ()
+  "spv.CopyMemory"(%0, %1) {memory_access=#spv.memory_access<Aligned>} : (!spv.ptr<f32, Function>, !spv.ptr<f32, Function>) -> ()
   spv.Return
 }
 
@@ -598,7 +598,7 @@ func.func @copy_memory_invalid_source_maa() {
   %0 = spv.Variable : !spv.ptr<f32, Function>
   %1 = spv.Variable : !spv.ptr<f32, Function>
   // expected-error @+1 {{invalid alignment specification with non-aligned memory access specification}}
-  "spv.CopyMemory"(%0, %1) {source_memory_access=0x0001 : i32, memory_access=0x0002 : i32, source_alignment=8 : i32, alignment=4 : i32} : (!spv.ptr<f32, Function>, !spv.ptr<f32, Function>) -> ()
+  "spv.CopyMemory"(%0, %1) {source_memory_access=#spv.memory_access<Volatile>, memory_access=#spv.memory_access<Aligned>, source_alignment=8 : i32, alignment=4 : i32} : (!spv.ptr<f32, Function>, !spv.ptr<f32, Function>) -> ()
   spv.Return
 }
 
@@ -608,7 +608,7 @@ func.func @copy_memory_invalid_source_maa2() {
   %0 = spv.Variable : !spv.ptr<f32, Function>
   %1 = spv.Variable : !spv.ptr<f32, Function>
   // expected-error @+1 {{missing alignment value}}
-  "spv.CopyMemory"(%0, %1) {source_memory_access=0x0002 : i32, memory_access=0x0002 : i32, alignment=4 : i32} : (!spv.ptr<f32, Function>, !spv.ptr<f32, Function>) -> ()
+  "spv.CopyMemory"(%0, %1) {source_memory_access=#spv.memory_access<Aligned>, memory_access=#spv.memory_access<Aligned>, alignment=4 : i32} : (!spv.ptr<f32, Function>, !spv.ptr<f32, Function>) -> ()
   spv.Return
 }
 
@@ -619,16 +619,16 @@ func.func @copy_memory_print_maa() {
   %1 = spv.Variable : !spv.ptr<f32, Function>
 
   // CHECK: spv.CopyMemory "Function" %{{.*}}, "Function" %{{.*}} ["Volatile"] : f32
-  "spv.CopyMemory"(%0, %1) {memory_access=0x0001 : i32} : (!spv.ptr<f32, Function>, !spv.ptr<f32, Function>) -> ()
+  "spv.CopyMemory"(%0, %1) {memory_access=#spv.memory_access<Volatile>} : (!spv.ptr<f32, Function>, !spv.ptr<f32, Function>) -> ()
 
   // CHECK: spv.CopyMemory "Function" %{{.*}}, "Function" %{{.*}} ["Aligned", 4] : f32
-  "spv.CopyMemory"(%0, %1) {memory_access=0x0002 : i32, alignment=4 : i32} : (!spv.ptr<f32, Function>, !spv.ptr<f32, Function>) -> ()
+  "spv.CopyMemory"(%0, %1) {memory_access=#spv.memory_access<Aligned>, alignment=4 : i32} : (!spv.ptr<f32, Function>, !spv.ptr<f32, Function>) -> ()
 
   // CHECK: spv.CopyMemory "Function" %{{.*}}, "Function" %{{.*}} ["Aligned", 4], ["Volatile"] : f32
-  "spv.CopyMemory"(%0, %1) {source_memory_access=0x0001 : i32, memory_access=0x0002 : i32, alignment=4 : i32} : (!spv.ptr<f32, Function>, !spv.ptr<f32, Function>) -> ()
+  "spv.CopyMemory"(%0, %1) {source_memory_access=#spv.memory_access<Volatile>, memory_access=#spv.memory_access<Aligned>, alignment=4 : i32} : (!spv.ptr<f32, Function>, !spv.ptr<f32, Function>) -> ()
 
   // CHECK: spv.CopyMemory "Function" %{{.*}}, "Function" %{{.*}} ["Aligned", 4], ["Aligned", 8] : f32
-  "spv.CopyMemory"(%0, %1) {source_memory_access=0x0002 : i32, memory_access=0x0002 : i32, source_alignment=8 : i32, alignment=4 : i32} : (!spv.ptr<f32, Function>, !spv.ptr<f32, Function>) -> ()
+  "spv.CopyMemory"(%0, %1) {source_memory_access=#spv.memory_access<Aligned>, memory_access=#spv.memory_access<Aligned>, source_alignment=8 : i32, alignment=4 : i32} : (!spv.ptr<f32, Function>, !spv.ptr<f32, Function>) -> ()
 
   spv.Return
 }

diff  --git a/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir b/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
index 512693e40afef..8c7f6f168c2d4 100644
--- a/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
@@ -5,8 +5,8 @@
 //===----------------------------------------------------------------------===//
 
 func.func @group_non_uniform_ballot(%predicate: i1) -> vector<4xi32> {
-  // CHECK: %{{.*}} = spv.GroupNonUniformBallot Workgroup %{{.*}}: vector<4xi32>
-  %0 = spv.GroupNonUniformBallot Workgroup %predicate : vector<4xi32>
+  // CHECK: %{{.*}} = spv.GroupNonUniformBallot <Workgroup> %{{.*}}: vector<4xi32>
+  %0 = spv.GroupNonUniformBallot <Workgroup> %predicate : vector<4xi32>
   return %0: vector<4xi32>
 }
 
@@ -14,7 +14,7 @@ func.func @group_non_uniform_ballot(%predicate: i1) -> vector<4xi32> {
 
 func.func @group_non_uniform_ballot(%predicate: i1) -> vector<4xi32> {
   // expected-error @+1 {{execution scope must be 'Workgroup' or 'Subgroup'}}
-  %0 = spv.GroupNonUniformBallot Device %predicate : vector<4xi32>
+  %0 = spv.GroupNonUniformBallot <Device> %predicate : vector<4xi32>
   return %0: vector<4xi32>
 }
 
@@ -22,7 +22,7 @@ func.func @group_non_uniform_ballot(%predicate: i1) -> vector<4xi32> {
 
 func.func @group_non_uniform_ballot(%predicate: i1) -> vector<4xsi32> {
   // expected-error @+1 {{op result #0 must be vector of 8/16/32/64-bit signless/unsigned integer values of length 4, but got 'vector<4xsi32>'}}
-  %0 = spv.GroupNonUniformBallot Workgroup %predicate : vector<4xsi32>
+  %0 = spv.GroupNonUniformBallot <Workgroup> %predicate : vector<4xsi32>
   return %0: vector<4xsi32>
 }
 
@@ -34,8 +34,8 @@ func.func @group_non_uniform_ballot(%predicate: i1) -> vector<4xsi32> {
 
 func.func @group_non_uniform_broadcast_scalar(%value: f32) -> f32 {
   %one = spv.Constant 1 : i32
-  // CHECK: spv.GroupNonUniformBroadcast Workgroup %{{.*}}, %{{.*}} : f32, i32
-  %0 = spv.GroupNonUniformBroadcast Workgroup %value, %one : f32, i32
+  // CHECK: spv.GroupNonUniformBroadcast <Workgroup> %{{.*}}, %{{.*}} : f32, i32
+  %0 = spv.GroupNonUniformBroadcast <Workgroup> %value, %one : f32, i32
   return %0: f32
 }
 
@@ -43,8 +43,8 @@ func.func @group_non_uniform_broadcast_scalar(%value: f32) -> f32 {
 
 func.func @group_non_uniform_broadcast_vector(%value: vector<4xf32>) -> vector<4xf32> {
   %one = spv.Constant 1 : i32
-  // CHECK: spv.GroupNonUniformBroadcast Subgroup %{{.*}}, %{{.*}} : vector<4xf32>, i32
-  %0 = spv.GroupNonUniformBroadcast Subgroup %value, %one : vector<4xf32>, i32
+  // CHECK: spv.GroupNonUniformBroadcast <Subgroup> %{{.*}}, %{{.*}} : vector<4xf32>, i32
+  %0 = spv.GroupNonUniformBroadcast <Subgroup> %value, %one : vector<4xf32>, i32
   return %0: vector<4xf32>
 }
 
@@ -53,7 +53,7 @@ func.func @group_non_uniform_broadcast_vector(%value: vector<4xf32>) -> vector<4
 func.func @group_non_uniform_broadcast_negative_scope(%value: f32, %localid: i32 ) -> f32 {
   %one = spv.Constant 1 : i32
   // expected-error @+1 {{execution scope must be 'Workgroup' or 'Subgroup'}} 
-  %0 = spv.GroupNonUniformBroadcast Device %value, %one : f32, i32
+  %0 = spv.GroupNonUniformBroadcast <Device> %value, %one : f32, i32
   return %0: f32
 }
 
@@ -61,7 +61,7 @@ func.func @group_non_uniform_broadcast_negative_scope(%value: f32, %localid: i32
 
 func.func @group_non_uniform_broadcast_negative_non_const(%value: f32, %localid: i32) -> f32 {
   // expected-error @+1 {{id must be the result of a constant op}}
-  %0 = spv.GroupNonUniformBroadcast Subgroup %value, %localid : f32, i32
+  %0 = spv.GroupNonUniformBroadcast <Subgroup> %value, %localid : f32, i32
   return %0: f32
 }
 
@@ -73,8 +73,8 @@ func.func @group_non_uniform_broadcast_negative_non_const(%value: f32, %localid:
 
 // CHECK-LABEL: @group_non_uniform_elect
 func.func @group_non_uniform_elect() -> i1 {
-  // CHECK: %{{.+}} = spv.GroupNonUniformElect Workgroup : i1
-  %0 = spv.GroupNonUniformElect Workgroup : i1
+  // CHECK: %{{.+}} = spv.GroupNonUniformElect <Workgroup> : i1
+  %0 = spv.GroupNonUniformElect <Workgroup> : i1
   return %0: i1
 }
 
@@ -82,7 +82,7 @@ func.func @group_non_uniform_elect() -> i1 {
 
 func.func @group_non_uniform_elect() -> i1 {
   // expected-error @+1 {{execution scope must be 'Workgroup' or 'Subgroup'}}
-  %0 = spv.GroupNonUniformElect CrossDevice : i1
+  %0 = spv.GroupNonUniformElect <CrossDevice> : i1
   return %0: i1
 }
 

diff  --git a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
index 8c64e4570dece..793f3361854bf 100644
--- a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
@@ -819,7 +819,7 @@ spv.module Logical GLSL450 {
     %0 = spv.Variable : !spv.ptr<i32, Function>
 
     // expected-error @+1 {{invalid enclosed op}}
-    %1 = spv.SpecConstantOperation wraps "spv.Load"(%0) {memory_access = 0 : i32} : (!spv.ptr<i32, Function>) -> i32
+    %1 = spv.SpecConstantOperation wraps "spv.Load"(%0) {memory_access = #spv.memory_access<None>} : (!spv.ptr<i32, Function>) -> i32
     spv.Return
   }
 }

diff  --git a/mlir/test/Dialect/SPIRV/IR/target-and-abi.mlir b/mlir/test/Dialect/SPIRV/IR/target-and-abi.mlir
index 56b67d660601b..cbe390dca35f4 100644
--- a/mlir/test/Dialect/SPIRV/IR/target-and-abi.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/target-and-abi.mlir
@@ -165,11 +165,11 @@ func.func @target_env_cooperative_matrix() attributes{
   // CHECK-SAME: #spv.coop_matrix_props<
   // CHECK-SAME:   m_size = 8, n_size = 8, k_size = 32,
   // CHECK-SAME:   a_type = i8, b_type = i8, c_type = i32,
-  // CHECK-SAME:   result_type = i32, scope = 3 : i32>
+  // CHECK-SAME:   result_type = i32, scope = <Subgroup>>
   // CHECK-SAME: #spv.coop_matrix_props<
   // CHECK-SAME:   m_size = 8, n_size = 8, k_size = 16,
   // CHECK-SAME:   a_type = f16, b_type = f16, c_type = f16,
-  // CHECK-SAME:   result_type = f16, scope = 3 : i32>
+  // CHECK-SAME:   result_type = f16, scope = <Subgroup>>
   spv.target_env = #spv.target_env<
   #spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class,
                             SPV_NV_cooperative_matrix]>,
@@ -182,7 +182,7 @@ func.func @target_env_cooperative_matrix() attributes{
       b_type = i8,
       c_type = i32,
       result_type = i32,
-      scope = 3 : i32
+      scope = #spv.scope<Subgroup>
     >, #spv.coop_matrix_props<
       m_size = 8,
       n_size = 8,
@@ -191,7 +191,7 @@ func.func @target_env_cooperative_matrix() attributes{
       b_type = f16,
       c_type = f16,
       result_type = f16,
-      scope = 3 : i32
+      scope = #spv.scope<Subgroup>
     >]
   >>
 } { return }

diff  --git a/mlir/test/Dialect/SPIRV/IR/target-env.mlir b/mlir/test/Dialect/SPIRV/IR/target-env.mlir
index e6e5b97c18fb9..8a72b76ef1340 100644
--- a/mlir/test/Dialect/SPIRV/IR/target-env.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/target-env.mlir
@@ -59,7 +59,7 @@ func.func @cmp_exchange_weak_unsupported_version(%ptr: !spv.ptr<i32, Workgroup>,
 func.func @group_non_uniform_ballot_suitable_version(%predicate: i1) -> vector<4xi32> attributes {
   spv.target_env = #spv.target_env<#spv.vce<v1.4, [GroupNonUniformBallot], []>, #spv.resource_limits<>>
 } {
-  // CHECK: spv.GroupNonUniformBallot Workgroup
+  // CHECK: spv.GroupNonUniformBallot <Workgroup>
   %0 = "test.convert_to_group_non_uniform_ballot_op"(%predicate): (i1) -> (vector<4xi32>)
   return %0: vector<4xi32>
 }

diff  --git a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
index eeb7a0b209457..0c00058842594 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
@@ -27,7 +27,7 @@ spv.module Logical GLSL450 attributes {
     #spv.vce<v1.5, [Shader, GroupNonUniformBallot], []>, #spv.resource_limits<>>
 } {
   spv.func @group_non_uniform_ballot(%predicate : i1) -> vector<4xi32> "None" {
-    %0 = spv.GroupNonUniformBallot Workgroup %predicate : vector<4xi32>
+    %0 = spv.GroupNonUniformBallot <Workgroup> %predicate : vector<4xi32>
     spv.ReturnValue %0: vector<4xi32>
   }
 }

diff  --git a/mlir/test/Target/SPIRV/barrier-ops.mlir b/mlir/test/Target/SPIRV/barrier-ops.mlir
index a3b80e06a35e3..d3700f0c92dfc 100644
--- a/mlir/test/Target/SPIRV/barrier-ops.mlir
+++ b/mlir/test/Target/SPIRV/barrier-ops.mlir
@@ -2,23 +2,23 @@
 
 spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
   spv.func @memory_barrier_0() -> () "None" {
-    // CHECK: spv.MemoryBarrier Device, "Release|UniformMemory"
-    spv.MemoryBarrier Device, "Release|UniformMemory"
+    // CHECK: spv.MemoryBarrier <Device>, <Release|UniformMemory>
+    spv.MemoryBarrier <Device>, <Release|UniformMemory>
     spv.Return
   }
   spv.func @memory_barrier_1() -> () "None" {
-    // CHECK: spv.MemoryBarrier Subgroup, "AcquireRelease|SubgroupMemory"
-    spv.MemoryBarrier Subgroup, "AcquireRelease|SubgroupMemory"
+    // CHECK: spv.MemoryBarrier <Subgroup>, <AcquireRelease|SubgroupMemory>
+    spv.MemoryBarrier <Subgroup>, <AcquireRelease|SubgroupMemory>
     spv.Return
   }
   spv.func @control_barrier_0() -> () "None" {
-    // CHECK: spv.ControlBarrier Device, Workgroup, "Release|UniformMemory"
-    spv.ControlBarrier Device, Workgroup, "Release|UniformMemory"
+    // CHECK: spv.ControlBarrier <Device>, <Workgroup>, <Release|UniformMemory>
+    spv.ControlBarrier <Device>, <Workgroup>, <Release|UniformMemory>
     spv.Return
   }
   spv.func @control_barrier_1() -> () "None" {
-    // CHECK: spv.ControlBarrier Workgroup, Invocation, "AcquireRelease|UniformMemory"
-    spv.ControlBarrier Workgroup, Invocation, "AcquireRelease|UniformMemory"
+    // CHECK: spv.ControlBarrier <Workgroup>, <Invocation>, <AcquireRelease|UniformMemory>
+    spv.ControlBarrier <Workgroup>, <Invocation>, <AcquireRelease|UniformMemory>
     spv.Return
   }
 }

diff  --git a/mlir/test/Target/SPIRV/group-ops.mlir b/mlir/test/Target/SPIRV/group-ops.mlir
index 6442e00492e00..27d917b3d49ac 100644
--- a/mlir/test/Target/SPIRV/group-ops.mlir
+++ b/mlir/test/Target/SPIRV/group-ops.mlir
@@ -9,14 +9,14 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
   }
   // CHECK-LABEL: @group_broadcast_1
   spv.func @group_broadcast_1(%value: f32, %localid: i32 ) -> f32 "None" {
-    // CHECK: spv.GroupBroadcast Workgroup %{{.*}}, %{{.*}} : f32, i32
-    %0 = spv.GroupBroadcast Workgroup %value, %localid : f32, i32
+    // CHECK: spv.GroupBroadcast <Workgroup> %{{.*}}, %{{.*}} : f32, i32
+    %0 = spv.GroupBroadcast <Workgroup> %value, %localid : f32, i32
     spv.ReturnValue %0: f32
   }
   // CHECK-LABEL: @group_broadcast_2
   spv.func @group_broadcast_2(%value: f32, %localid: vector<3xi32> ) -> f32 "None" {
-    // CHECK: spv.GroupBroadcast Workgroup %{{.*}}, %{{.*}} : f32, vector<3xi32>
-    %0 = spv.GroupBroadcast Workgroup %value, %localid : f32, vector<3xi32>
+    // CHECK: spv.GroupBroadcast <Workgroup> %{{.*}}, %{{.*}} : f32, vector<3xi32>
+    %0 = spv.GroupBroadcast <Workgroup> %value, %localid : f32, vector<3xi32>
     spv.ReturnValue %0: f32
   }
   // CHECK-LABEL: @subgroup_block_read_intel

diff  --git a/mlir/test/Target/SPIRV/non-uniform-ops.mlir b/mlir/test/Target/SPIRV/non-uniform-ops.mlir
index e0d576f1b52cf..d429b20e378bd 100644
--- a/mlir/test/Target/SPIRV/non-uniform-ops.mlir
+++ b/mlir/test/Target/SPIRV/non-uniform-ops.mlir
@@ -3,23 +3,23 @@
 spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
   // CHECK-LABEL: @group_non_uniform_ballot
   spv.func @group_non_uniform_ballot(%predicate: i1) -> vector<4xi32> "None" {
-    // CHECK: %{{.*}} = spv.GroupNonUniformBallot Workgroup %{{.*}}: vector<4xi32>
-  %0 = spv.GroupNonUniformBallot Workgroup %predicate : vector<4xi32>
+    // CHECK: %{{.*}} = spv.GroupNonUniformBallot <Workgroup> %{{.*}}: vector<4xi32>
+  %0 = spv.GroupNonUniformBallot <Workgroup> %predicate : vector<4xi32>
     spv.ReturnValue %0: vector<4xi32>
   }
 
   // CHECK-LABEL: @group_non_uniform_broadcast
   spv.func @group_non_uniform_broadcast(%value: f32) -> f32 "None" {
     %one = spv.Constant 1 : i32
-    // CHECK: spv.GroupNonUniformBroadcast Subgroup %{{.*}}, %{{.*}} : f32, i32
-    %0 = spv.GroupNonUniformBroadcast Subgroup %value, %one : f32, i32
+    // CHECK: spv.GroupNonUniformBroadcast <Subgroup> %{{.*}}, %{{.*}} : f32, i32
+    %0 = spv.GroupNonUniformBroadcast <Subgroup> %value, %one : f32, i32
     spv.ReturnValue %0: f32
   }
 
   // CHECK-LABEL: @group_non_uniform_elect
   spv.func @group_non_uniform_elect() -> i1 "None" {
-    // CHECK: %{{.+}} = spv.GroupNonUniformElect Workgroup : i1
-    %0 = spv.GroupNonUniformElect Workgroup : i1
+    // CHECK: %{{.+}} = spv.GroupNonUniformElect <Workgroup> : i1
+    %0 = spv.GroupNonUniformElect <Workgroup> : i1
     spv.ReturnValue %0: i1
   }
 

diff  --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
index b1fc3c177bcbf..1db8c96bdc99c 100644
--- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
+++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
@@ -519,10 +519,24 @@ static void emitAttributeSerialization(const Attribute &attr,
      << formatv("if (auto attr = {0}->getAttr(\"{1}\")) {{\n", opVar, attrName);
   if (attr.getAttrDefName() == "SPV_ScopeAttr" ||
       attr.getAttrDefName() == "SPV_MemorySemanticsAttr") {
+    // These two enums are encoded in the SPIR-V blob as <id>s of constants, but
+    // the SPIR-V dialect uses the constant value directly as the attribute, so
+    // they must be handled separately from the other enum attributes.
+    EnumAttr baseEnum(attr.getDef().getValueAsDef("enum"));
     os << tabs
        << formatv("  {0}.push_back(prepareConstantInt({1}.getLoc(), "
-                  "attr.cast<IntegerAttr>()));\n",
-                  operandList, opVar);
+                  "Builder({1}).getI32IntegerAttr(static_cast<uint32_t>("
+                  "attr.cast<{2}::{3}Attr>().getValue()))));\n",
+                  operandList, opVar, baseEnum.getCppNamespace(),
+                  baseEnum.getEnumClassName());
+  } else if (attr.isSubClassOf("SPV_BitEnumAttr") ||
+             attr.isSubClassOf("SPV_I32EnumAttr")) {
+    EnumAttr baseEnum(attr.getDef().getValueAsDef("enum"));
+    os << tabs
+       << formatv("  {0}.push_back(static_cast<uint32_t>("
+                  "attr.cast<{1}::{2}Attr>().getValue()));\n",
+                  operandList, baseEnum.getCppNamespace(),
+                  baseEnum.getEnumClassName());
   } else if (attr.getAttrDefName() == "I32ArrayAttr") {
     // Serialize all the elements of the array
     os << tabs << "  for (auto attrElem : attr.cast<ArrayAttr>()) {\n";
@@ -531,7 +545,7 @@ static void emitAttributeSerialization(const Attribute &attr,
                   "attrElem.cast<IntegerAttr>().getValue().getZExtValue()));\n",
                   operandList);
     os << tabs << "  }\n";
-  } else if (attr.isEnumAttr() || attr.getAttrDefName() == "I32Attr") {
+  } else if (attr.getAttrDefName() == "I32Attr") {
     os << tabs
        << formatv("  {0}.push_back(static_cast<uint32_t>("
                   "attr.cast<IntegerAttr>().getValue().getZExtValue()));\n",
@@ -797,10 +811,25 @@ static void emitAttributeDeserialization(const Attribute &attr,
                                          raw_ostream &os) {
   if (attr.getAttrDefName() == "SPV_ScopeAttr" ||
       attr.getAttrDefName() == "SPV_MemorySemanticsAttr") {
+    // These two enums are encoded as <id> to constant values in SPIR-V blob,
+    // but we directly use the constant value as attribute in SPIR-V dialect. So
+    // need to handle them separately from normal enum attributes.
+    EnumAttr baseEnum(attr.getDef().getValueAsDef("enum"));
     os << tabs
        << formatv("{0}.push_back(opBuilder.getNamedAttr(\"{1}\", "
-                  "getConstantInt({2}[{3}++])));\n",
-                  attrList, attrName, words, wordIndex);
+                  "opBuilder.getAttr<{2}::{3}Attr>(static_cast<{2}::{3}>("
+                  "getConstantInt({4}[{5}++]).getValue().getZExtValue()))));\n",
+                  attrList, attrName, baseEnum.getCppNamespace(),
+                  baseEnum.getEnumClassName(), words, wordIndex);
+  } else if (attr.isSubClassOf("SPV_BitEnumAttr") ||
+             attr.isSubClassOf("SPV_I32EnumAttr")) {
+    EnumAttr baseEnum(attr.getDef().getValueAsDef("enum"));
+    os << tabs
+       << formatv("  {0}.push_back(opBuilder.getNamedAttr(\"{1}\", "
+                  "opBuilder.getAttr<{2}::{3}Attr>("
+                  "static_cast<{2}::{3}>({4}[{5}++]))));\n",
+                  attrList, attrName, baseEnum.getCppNamespace(),
+                  baseEnum.getEnumClassName(), words, wordIndex);
   } else if (attr.getAttrDefName() == "I32ArrayAttr") {
     os << tabs << "SmallVector<Attribute, 4> attrListElems;\n";
     os << tabs << formatv("while ({0} < {1}.size()) {{\n", wordIndex, words);
@@ -815,7 +844,7 @@ static void emitAttributeDeserialization(const Attribute &attr,
        << formatv("{0}.push_back(opBuilder.getNamedAttr(\"{1}\", "
                   "opBuilder.getArrayAttr(attrListElems)));\n",
                   attrList, attrName);
-  } else if (attr.isEnumAttr() || attr.getAttrDefName() == "I32Attr") {
+  } else if (attr.getAttrDefName() == "I32Attr") {
     os << tabs
        << formatv("{0}.push_back(opBuilder.getNamedAttr(\"{1}\", "
                   "opBuilder.getI32IntegerAttr({2}[{3}++])));\n",
@@ -1257,11 +1286,12 @@ static void emitAvailabilityImpl(const Operator &srcOp, raw_ostream &os) {
   for (const Availability &avail : opAvailabilities)
     availClasses.try_emplace(avail.getClass(), avail);
   for (const NamedAttribute &namedAttr : srcOp.getAttributes()) {
-    const auto *enumAttr = llvm::dyn_cast<EnumAttr>(&namedAttr.attr);
-    if (!enumAttr)
+    if (!namedAttr.attr.isSubClassOf("SPV_BitEnumAttr") &&
+        !namedAttr.attr.isSubClassOf("SPV_I32EnumAttr"))
       continue;
+    EnumAttr enumAttr(namedAttr.attr.getDef().getValueAsDef("enum"));
 
-    for (const EnumAttrCase &enumerant : enumAttr->getAllCases())
+    for (const EnumAttrCase &enumerant : enumAttr.getAllCases())
       for (const Availability &caseAvail :
            getAvailabilities(enumerant.getDef()))
         availClasses.try_emplace(caseAvail.getClass(), caseAvail);
@@ -1298,16 +1328,17 @@ static void emitAvailabilityImpl(const Operator &srcOp, raw_ostream &os) {
 
     // Update with enum attributes' specific availability spec.
     for (const NamedAttribute &namedAttr : srcOp.getAttributes()) {
-      const auto *enumAttr = llvm::dyn_cast<EnumAttr>(&namedAttr.attr);
-      if (!enumAttr)
+      if (!namedAttr.attr.isSubClassOf("SPV_BitEnumAttr") &&
+          !namedAttr.attr.isSubClassOf("SPV_I32EnumAttr"))
         continue;
+      EnumAttr enumAttr(namedAttr.attr.getDef().getValueAsDef("enum"));
 
       // (enumerant, availability specification) pairs for this availability
       // class.
       SmallVector<std::pair<EnumAttrCase, Availability>, 1> caseSpecs;
 
       // Collect all cases' availability specs.
-      for (const EnumAttrCase &enumerant : enumAttr->getAllCases())
+      for (const EnumAttrCase &enumerant : enumAttr.getAllCases())
         for (const Availability &caseAvail :
              getAvailabilities(enumerant.getDef()))
           if (availClassName == caseAvail.getClass())
@@ -1318,19 +1349,19 @@ static void emitAvailabilityImpl(const Operator &srcOp, raw_ostream &os) {
       if (caseSpecs.empty())
         continue;
 
-      if (enumAttr->isBitEnum()) {
+      if (enumAttr.isBitEnum()) {
         // For BitEnumAttr, we need to iterate over each bit to query its
         // availability spec.
         os << formatv("  for (unsigned i = 0; "
                       "i < std::numeric_limits<{0}>::digits; ++i) {{\n",
-                      enumAttr->getUnderlyingType());
+                      enumAttr.getUnderlyingType());
         os << formatv("    {0}::{1} tblgen_attrVal = this->{2}() & "
                       "static_cast<{0}::{1}>(1 << i);\n",
-                      enumAttr->getCppNamespace(), enumAttr->getEnumClassName(),
+                      enumAttr.getCppNamespace(), enumAttr.getEnumClassName(),
                       namedAttr.name);
         os << formatv(
             "    if (static_cast<{0}>(tblgen_attrVal) == 0) continue;\n",
-            enumAttr->getUnderlyingType());
+            enumAttr.getUnderlyingType());
       } else {
         // For IntEnumAttr, we just need to query the value as a whole.
         os << "  {\n";
@@ -1338,7 +1369,7 @@ static void emitAvailabilityImpl(const Operator &srcOp, raw_ostream &os) {
                       namedAttr.name);
       }
       os << formatv("    auto tblgen_instance = {0}::{1}(tblgen_attrVal);\n",
-                    enumAttr->getCppNamespace(), avail.getQueryFnName());
+                    enumAttr.getCppNamespace(), avail.getQueryFnName());
       os << "    if (tblgen_instance) "
          // TODO` here once ODS supports
          // dialect-specific contents so that we can use not implementing the
@@ -1385,7 +1416,8 @@ static bool emitCapabilityImplication(const RecordKeeper &recordKeeper,
                                       raw_ostream &os) {
   llvm::emitSourceFileHeader("SPIR-V Capability Implication", os);
 
-  EnumAttr enumAttr(recordKeeper.getDef("SPV_CapabilityAttr"));
+  EnumAttr enumAttr(
+      recordKeeper.getDef("SPV_CapabilityAttr")->getValueAsDef("enum"));
 
   os << "ArrayRef<spirv::Capability> "
         "spirv::getDirectImpliedCapabilities(spirv::Capability cap) {\n"

diff  --git a/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp b/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp
index f17bc53c5bd8b..f7a1db0749c57 100644
--- a/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp
+++ b/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp
@@ -14,6 +14,7 @@
 #include "mlir/Target/SPIRV/Serialization.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
 #include "mlir/IR/Builders.h"
@@ -46,11 +47,10 @@ class SerializationTest : public ::testing::Test {
     OperationState state(UnknownLoc::get(&context),
                          spirv::ModuleOp::getOperationName());
     state.addAttribute("addressing_model",
-                       builder.getI32IntegerAttr(static_cast<uint32_t>(
-                           spirv::AddressingModel::Logical)));
-    state.addAttribute("memory_model",
-                       builder.getI32IntegerAttr(
-                           static_cast<uint32_t>(spirv::MemoryModel::GLSL450)));
+                       builder.getAttr<spirv::AddressingModelAttr>(
+                           spirv::AddressingModel::Logical));
+    state.addAttribute("memory_model", builder.getAttr<spirv::MemoryModelAttr>(
+                                           spirv::MemoryModel::GLSL450));
     state.addAttribute("vce_triple",
                        spirv::VerCapExtAttr::get(
                            spirv::Version::V_1_0, ArrayRef<spirv::Capability>(),

diff  --git a/mlir/utils/spirv/gen_spirv_dialect.py b/mlir/utils/spirv/gen_spirv_dialect.py
index 7b0a61c2379b9..c0b145e17cedb 100755
--- a/mlir/utils/spirv/gen_spirv_dialect.py
+++ b/mlir/utils/spirv/gen_spirv_dialect.py
@@ -437,10 +437,13 @@ def get_case_symbol(kind_name, case_name):
   # Generate the enum attribute definition
   kind_category = 'Bit' if is_bit_enum else 'I32'
   enum_attr = '''def SPV_{name}Attr :
-    SPV_{category}EnumAttr<"{name}", "valid SPIR-V {name}", [
+    SPV_{category}EnumAttr<"{name}", "valid SPIR-V {name}", "{snake_name}", [
 {cases}
     ]>;'''.format(
-          name=kind_name, category=kind_category, cases=case_names)
+          name=kind_name,
+          snake_name=snake_casify(kind_name),
+          category=kind_category,
+          cases=case_names)
   return kind_name, case_defs + '\n\n' + enum_attr
 
 
@@ -473,7 +476,8 @@ def gen_opcode(instructions):
   ]
   opcode_list = ',\n'.join(opcode_list)
   enum_attr = 'def SPV_OpcodeAttr :\n'\
-              '    SPV_I32EnumAttr<"{name}", "valid SPIR-V instructions", [\n'\
+              '    SPV_I32EnumAttr<"{name}", "valid SPIR-V instructions", '\
+              '"opcode", [\n'\
               '{lst}\n'\
               '    ]>;'.format(name='Opcode', lst=opcode_list)
   return opcode_str + '\n\n' + enum_attr
@@ -630,9 +634,7 @@ def update_td_enum_attrs(path, operand_kinds, filter_list):
 
 def snake_casify(name):
   """Turns the given name to follow snake_case convention."""
-  name = re.sub('\W+', '', name).split()
-  name = [s.lower() for s in name]
-  return '_'.join(name)
+  return re.sub(r'(?<!^)(?=[A-Z])', '_', name).lower()
 
 
 def map_spec_operand_to_ods_argument(operand):


        


More information about the Mlir-commits mailing list