[Mlir-commits] [mlir] a29fffc - [mlir][spirv] Migrate to use specialized enum attributes
Lei Zhang
llvmlistbot at llvm.org
Tue Aug 9 11:15:14 PDT 2022
Author: Lei Zhang
Date: 2022-08-09T14:14:54-04:00
New Revision: a29fffc4752a10494a74bb9f8db9b885c8aa5af1
URL: https://github.com/llvm/llvm-project/commit/a29fffc4752a10494a74bb9f8db9b885c8aa5af1
DIFF: https://github.com/llvm/llvm-project/commit/a29fffc4752a10494a74bb9f8db9b885c8aa5af1.diff
LOG: [mlir][spirv] Migrate to use specialized enum attributes
Previously we were using IntegerAttr to back all SPIR-V enum
attributes. Therefore all such attributes were shown as
IntegerAttr in IRs, which is barely readable and breaks
round-trippability of the IR. This commit changes to use
`EnumAttr` as the base directly so that we can have separate
attribute definitions and better IR printing.
Reviewed By: kuhar
Differential Revision: https://reviews.llvm.org/D131311
Added:
Modified:
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAvailability.td
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBarrierOps.td
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
mlir/include/mlir/TableGen/Attribute.h
mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp
mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
mlir/lib/TableGen/Attribute.cpp
mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
mlir/test/Conversion/GPUToSPIRV/simple.mlir
mlir/test/Conversion/LinalgToSPIRV/linalg-to-spirv.mlir
mlir/test/Conversion/MemRefToSPIRV/map-storage-class.mlir
mlir/test/Dialect/SPIRV/IR/atomic-ops.mlir
mlir/test/Dialect/SPIRV/IR/availability.mlir
mlir/test/Dialect/SPIRV/IR/barrier-ops.mlir
mlir/test/Dialect/SPIRV/IR/group-ops.mlir
mlir/test/Dialect/SPIRV/IR/memory-ops.mlir
mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
mlir/test/Dialect/SPIRV/IR/target-and-abi.mlir
mlir/test/Dialect/SPIRV/IR/target-env.mlir
mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
mlir/test/Target/SPIRV/barrier-ops.mlir
mlir/test/Target/SPIRV/group-ops.mlir
mlir/test/Target/SPIRV/non-uniform-ops.mlir
mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
mlir/unittests/Dialect/SPIRV/SerializationTest.cpp
mlir/utils/spirv/gen_spirv_dialect.py
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAvailability.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAvailability.td
index b4ac754c7bfbc..82310cf00bb57 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAvailability.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAvailability.td
@@ -56,7 +56,7 @@ class Availability {
string instance = ?;
}
-class MinVersionBase<string name, I32EnumAttr scheme, I32EnumAttrCase min>
+class MinVersionBase<string name, EnumAttr scheme, I32EnumAttrCase min>
: Availability {
let interfaceName = name;
@@ -69,13 +69,13 @@ class MinVersionBase<string name, I32EnumAttr scheme, I32EnumAttrCase min>
"std::max(*$overall, $instance)); "
"} else { $overall = $instance; }}";
let initializer = "::llvm::None";
- let instanceType = scheme.cppNamespace # "::" # scheme.className;
+ let instanceType = scheme.cppNamespace # "::" # scheme.enum.className;
- let instance = scheme.cppNamespace # "::" # scheme.className # "::" #
+ let instance = scheme.cppNamespace # "::" # scheme.enum.className # "::" #
min.symbol;
}
-class MaxVersionBase<string name, I32EnumAttr scheme, I32EnumAttrCase max>
+class MaxVersionBase<string name, EnumAttr scheme, I32EnumAttrCase max>
: Availability {
let interfaceName = name;
@@ -88,9 +88,9 @@ class MaxVersionBase<string name, I32EnumAttr scheme, I32EnumAttrCase max>
"std::min(*$overall, $instance)); "
"} else { $overall = $instance; }}";
let initializer = "::llvm::None";
- let instanceType = scheme.cppNamespace # "::" # scheme.className;
+ let instanceType = scheme.cppNamespace # "::" # scheme.enum.className;
- let instance = scheme.cppNamespace # "::" # scheme.className # "::" #
+ let instance = scheme.cppNamespace # "::" # scheme.enum.className # "::" #
max.symbol;
}
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBarrierOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBarrierOps.td
index e00d8fa0c842c..09e547df4705e 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBarrierOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBarrierOps.td
@@ -77,8 +77,6 @@ def SPV_ControlBarrierOp : SPV_Op<"ControlBarrier", []> {
let results = (outs);
- let autogenSerialization = 0;
-
let assemblyFormat = [{
$execution_scope `,` $memory_scope `,` $memory_semantics attr-dict
}];
@@ -129,8 +127,6 @@ def SPV_MemoryBarrierOp : SPV_Op<"MemoryBarrier", []> {
let results = (outs);
- let autogenSerialization = 0;
-
let assemblyFormat = "$memory_scope `,` $memory_semantics attr-dict";
}
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 394ce1e917dff..0cfb3ead3642c 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -82,43 +82,31 @@ def SPIRV_Dialect : Dialect {
// Utility definitions
//===----------------------------------------------------------------------===//
-// A predicate that checks whether `$_self` is a known enum case for the
-// enum class with `name`.
-class SPV_IsKnownEnumCaseFor<string name> :
- CPred<"::mlir::spirv::symbolize" # name # "("
- "$_self.cast<IntegerAttr>().getValue().getZExtValue()).has_value()">;
-
// Wrapper over base BitEnumAttr to set common fields.
-class SPV_BitEnumAttr<string name, string description,
- list<BitEnumAttrCaseBase> cases> :
- I32BitEnumAttr<name, description, cases> {
- let predicate = And<[
- I32Attr.predicate,
- SPV_IsKnownEnumCaseFor<name>,
- ]>;
+class SPV_BitEnum<string name, string description,
+ list<BitEnumAttrCaseBase> cases>
+ : I32BitEnumAttr<name, description, cases> {
+ let genSpecializedAttr = 0;
let cppNamespace = "::mlir::spirv";
}
-
-// Wrapper over base I32EnumAttr to set common fields.
-class SPV_I32EnumAttr<string name, string description,
- list<I32EnumAttrCase> cases> :
- I32EnumAttr<name, description, cases> {
- let predicate = And<[
- I32Attr.predicate,
- SPV_IsKnownEnumCaseFor<name>,
- ]>;
- let cppNamespace = "::mlir::spirv";
+class SPV_BitEnumAttr<string name, string description, string mnemonic,
+ list<BitEnumAttrCaseBase> cases> :
+ EnumAttr<SPIRV_Dialect, SPV_BitEnum<name, description, cases>, mnemonic> {
+ let assemblyFormat = "`<` $value `>`";
}
// Wrapper over base I32EnumAttr to set common fields.
-class SPV_Enum<string name, string description, list<I32EnumAttrCase> cases>
+class SPV_I32Enum<string name, string description,
+ list<I32EnumAttrCase> cases>
: I32EnumAttr<name, description, cases> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::spirv";
}
-class SPV_EnumAttr<string name, string description, string mnemonic,
+class SPV_I32EnumAttr<string name, string description, string mnemonic,
list<I32EnumAttrCase> cases> :
- EnumAttr<SPIRV_Dialect, SPV_Enum<name, description, cases>, mnemonic>;
+ EnumAttr<SPIRV_Dialect, SPV_I32Enum<name, description, cases>, mnemonic> {
+ let assemblyFormat = "`<` $value `>`";
+}
//===----------------------------------------------------------------------===//
// SPIR-V availability definitions
@@ -132,7 +120,8 @@ def SPV_V_1_4 : I32EnumAttrCase<"V_1_4", 4, "v1.4">;
def SPV_V_1_5 : I32EnumAttrCase<"V_1_5", 5, "v1.5">;
def SPV_V_1_6 : I32EnumAttrCase<"V_1_6", 6, "v1.6">;
-def SPV_VersionAttr : SPV_I32EnumAttr<"Version", "valid SPIR-V version", [
+def SPV_VersionAttr : SPV_I32EnumAttr<
+ "Version", "valid SPIR-V version", "version", [
SPV_V_1_0, SPV_V_1_1, SPV_V_1_2, SPV_V_1_3, SPV_V_1_4, SPV_V_1_5,
SPV_V_1_6]>;
@@ -284,7 +273,7 @@ def SPV_DT_Other : I32EnumAttrCase<"Other", 3>;
// Information missing.
def SPV_DT_Unknown : I32EnumAttrCase<"Unknown", 4>;
-def SPV_DeviceTypeAttr : SPV_EnumAttr<
+def SPV_DeviceTypeAttr : SPV_I32EnumAttr<
"DeviceType", "valid SPIR-V device types", "device_type", [
SPV_DT_Other, SPV_DT_IntegratedGPU, SPV_DT_DiscreteGPU,
SPV_DT_CPU, SPV_DT_Unknown
@@ -300,7 +289,7 @@ def SPV_V_Qualcomm : I32EnumAttrCase<"Qualcomm", 6>;
def SPV_V_SwiftShader : I32EnumAttrCase<"SwiftShader", 7>;
def SPV_V_Unknown : I32EnumAttrCase<"Unknown", 0xff>;
-def SPV_VendorAttr : SPV_EnumAttr<
+def SPV_VendorAttr : SPV_I32EnumAttr<
"Vendor", "recognized SPIR-V vendor strings", "vendor", [
SPV_V_AMD, SPV_V_Apple, SPV_V_ARM, SPV_V_Imagination,
SPV_V_Intel, SPV_V_NVIDIA, SPV_V_Qualcomm, SPV_V_SwiftShader,
@@ -418,7 +407,7 @@ def SPV_NV_ray_tracing_motion_blur : I32EnumAttrCase<"SPV_NV_ray_tracing_m
def SPV_NVX_multiview_per_view_attributes : I32EnumAttrCase<"SPV_NVX_multiview_per_view_attributes", 5015>;
def SPV_ExtensionAttr :
- SPV_EnumAttr<"Extension", "supported SPIR-V extensions", "ext", [
+ SPV_I32EnumAttr<"Extension", "supported SPIR-V extensions", "ext", [
SPV_KHR_16bit_storage, SPV_KHR_8bit_storage, SPV_KHR_device_group,
SPV_KHR_float_controls, SPV_KHR_physical_storage_buffer, SPV_KHR_multiview,
SPV_KHR_no_integer_wrap_decoration, SPV_KHR_post_depth_coverage,
@@ -1402,7 +1391,7 @@ def SPV_C_ShaderStereoViewNV : I32EnumAttrCase<"ShaderS
}
def SPV_CapabilityAttr :
- SPV_I32EnumAttr<"Capability", "valid SPIR-V Capability", [
+ SPV_I32EnumAttr<"Capability", "valid SPIR-V Capability", "capability", [
SPV_C_Matrix, SPV_C_Addresses, SPV_C_Linkage, SPV_C_Kernel, SPV_C_Float16,
SPV_C_Float64, SPV_C_Int64, SPV_C_Groups, SPV_C_Int16, SPV_C_Int8,
SPV_C_Sampled1D, SPV_C_SampledBuffer, SPV_C_GroupNonUniform, SPV_C_ShaderLayer,
@@ -1514,7 +1503,7 @@ def SPV_AM_PhysicalStorageBuffer64 : I32EnumAttrCase<"PhysicalStorageBuffer64",
}
def SPV_AddressingModelAttr :
- SPV_I32EnumAttr<"AddressingModel", "valid SPIR-V AddressingModel", [
+ SPV_I32EnumAttr<"AddressingModel", "valid SPIR-V AddressingModel", "addressing_model", [
SPV_AM_Logical, SPV_AM_Physical32, SPV_AM_Physical64,
SPV_AM_PhysicalStorageBuffer64
]>;
@@ -2049,7 +2038,7 @@ def SPV_BI_CullMaskKHR : I32EnumAttrCase<"CullMaskKHR", 6021> {
}
def SPV_BuiltInAttr :
- SPV_I32EnumAttr<"BuiltIn", "valid SPIR-V BuiltIn", [
+ SPV_I32EnumAttr<"BuiltIn", "valid SPIR-V BuiltIn", "built_in", [
SPV_BI_Position, SPV_BI_PointSize, SPV_BI_ClipDistance, SPV_BI_CullDistance,
SPV_BI_VertexId, SPV_BI_InstanceId, SPV_BI_PrimitiveId, SPV_BI_InvocationId,
SPV_BI_Layer, SPV_BI_ViewportIndex, SPV_BI_TessLevelOuter,
@@ -2610,7 +2599,7 @@ def SPV_D_MediaBlockIOINTEL : I32EnumAttrCase<"MediaBlockIOINTE
}
def SPV_DecorationAttr :
- SPV_I32EnumAttr<"Decoration", "valid SPIR-V Decoration", [
+ SPV_I32EnumAttr<"Decoration", "valid SPIR-V Decoration", "decoration", [
SPV_D_RelaxedPrecision, SPV_D_SpecId, SPV_D_Block, SPV_D_BufferBlock,
SPV_D_RowMajor, SPV_D_ColMajor, SPV_D_ArrayStride, SPV_D_MatrixStride,
SPV_D_GLSLShared, SPV_D_GLSLPacked, SPV_D_CPacked, SPV_D_BuiltIn,
@@ -2679,7 +2668,7 @@ def SPV_D_SubpassData : I32EnumAttrCase<"SubpassData", 6> {
}
def SPV_DimAttr :
- SPV_I32EnumAttr<"Dim", "valid SPIR-V Dim", [
+ SPV_I32EnumAttr<"Dim", "valid SPIR-V Dim", "dim", [
SPV_D_1D, SPV_D_2D, SPV_D_3D, SPV_D_Cube, SPV_D_Rect, SPV_D_Buffer,
SPV_D_SubpassData
]>;
@@ -3093,7 +3082,7 @@ def SPV_EM_NamedBarrierCountINTEL : I32EnumAttrCase<"NamedBarrierCount
}
def SPV_ExecutionModeAttr :
- SPV_I32EnumAttr<"ExecutionMode", "valid SPIR-V ExecutionMode", [
+ SPV_I32EnumAttr<"ExecutionMode", "valid SPIR-V ExecutionMode", "execution_mode", [
SPV_EM_Invocations, SPV_EM_SpacingEqual, SPV_EM_SpacingFractionalEven,
SPV_EM_SpacingFractionalOdd, SPV_EM_VertexOrderCw, SPV_EM_VertexOrderCcw,
SPV_EM_PixelCenterInteger, SPV_EM_OriginUpperLeft, SPV_EM_OriginLowerLeft,
@@ -3203,7 +3192,7 @@ def SPV_EM_CallableKHR : I32EnumAttrCase<"CallableKHR", 5318> {
}
def SPV_ExecutionModelAttr :
- SPV_I32EnumAttr<"ExecutionModel", "valid SPIR-V ExecutionModel", [
+ SPV_I32EnumAttr<"ExecutionModel", "valid SPIR-V ExecutionModel", "execution_model", [
SPV_EM_Vertex, SPV_EM_TessellationControl, SPV_EM_TessellationEvaluation,
SPV_EM_Geometry, SPV_EM_Fragment, SPV_EM_GLCompute, SPV_EM_Kernel,
SPV_EM_TaskNV, SPV_EM_MeshNV, SPV_EM_RayGenerationKHR, SPV_EM_IntersectionKHR,
@@ -3222,7 +3211,7 @@ def SPV_FC_OptNoneINTEL : I32BitEnumAttrCaseBit<"OptNoneINTEL", 16> {
}
def SPV_FunctionControlAttr :
- SPV_BitEnumAttr<"FunctionControl", "valid SPIR-V FunctionControl", [
+ SPV_BitEnumAttr<"FunctionControl", "valid SPIR-V FunctionControl", "function_control", [
SPV_FC_None, SPV_FC_Inline, SPV_FC_DontInline, SPV_FC_Pure, SPV_FC_Const,
SPV_FC_OptNoneINTEL
]>;
@@ -3268,7 +3257,7 @@ def SPV_GO_PartitionedExclusiveScanNV : I32EnumAttrCase<"PartitionedExclusiveSca
}
def SPV_GroupOperationAttr :
- SPV_I32EnumAttr<"GroupOperation", "valid SPIR-V GroupOperation", [
+ SPV_I32EnumAttr<"GroupOperation", "valid SPIR-V GroupOperation", "group_operation", [
SPV_GO_Reduce, SPV_GO_InclusiveScan, SPV_GO_ExclusiveScan,
SPV_GO_ClusteredReduce, SPV_GO_PartitionedReduceNV,
SPV_GO_PartitionedInclusiveScanNV, SPV_GO_PartitionedExclusiveScanNV
@@ -3482,7 +3471,7 @@ def SPV_IF_R64i : I32EnumAttrCase<"R64i", 41> {
}
def SPV_ImageFormatAttr :
- SPV_I32EnumAttr<"ImageFormat", "valid SPIR-V ImageFormat", [
+ SPV_I32EnumAttr<"ImageFormat", "valid SPIR-V ImageFormat", "image_format", [
SPV_IF_Unknown, SPV_IF_Rgba32f, SPV_IF_Rgba16f, SPV_IF_R32f, SPV_IF_Rgba8,
SPV_IF_Rgba8Snorm, SPV_IF_Rg32f, SPV_IF_Rg16f, SPV_IF_R11fG11fB10f,
SPV_IF_R16f, SPV_IF_Rgba16, SPV_IF_Rgb10A2, SPV_IF_Rg16, SPV_IF_Rg8,
@@ -3561,7 +3550,7 @@ def SPV_IO_Nontemporal : I32BitEnumAttrCaseBit<"Nontemporal", 14> {
}
def SPV_ImageOperandsAttr :
- SPV_BitEnumAttr<"ImageOperands", "valid SPIR-V ImageOperands", [
+ SPV_BitEnumAttr<"ImageOperands", "valid SPIR-V ImageOperands", "image_operands", [
SPV_IO_None, SPV_IO_Bias, SPV_IO_Lod, SPV_IO_Grad, SPV_IO_ConstOffset,
SPV_IO_Offset, SPV_IO_ConstOffsets, SPV_IO_Sample, SPV_IO_MinLod,
SPV_IO_MakeTexelAvailable, SPV_IO_MakeTexelVisible, SPV_IO_NonPrivateTexel,
@@ -3587,7 +3576,7 @@ def SPV_LT_LinkOnceODR : I32EnumAttrCase<"LinkOnceODR", 2> {
}
def SPV_LinkageTypeAttr :
- SPV_I32EnumAttr<"LinkageType", "valid SPIR-V LinkageType", [
+ SPV_I32EnumAttr<"LinkageType", "valid SPIR-V LinkageType", "linkage_type", [
SPV_LT_Export, SPV_LT_Import, SPV_LT_LinkOnceODR
]>;
@@ -3679,7 +3668,7 @@ def SPV_LC_NoFusionINTEL : I32BitEnumAttrCaseBit<"NoFusionINTEL", 23
}
def SPV_LoopControlAttr :
- SPV_BitEnumAttr<"LoopControl", "valid SPIR-V LoopControl", [
+ SPV_BitEnumAttr<"LoopControl", "valid SPIR-V LoopControl", "loop_control", [
SPV_LC_None, SPV_LC_Unroll, SPV_LC_DontUnroll, SPV_LC_DependencyInfinite,
SPV_LC_DependencyLength, SPV_LC_MinIterations, SPV_LC_MaxIterations,
SPV_LC_IterationMultiple, SPV_LC_PeelCount, SPV_LC_PartialCount,
@@ -3725,7 +3714,7 @@ def SPV_MA_NoAliasINTELMask : I32BitEnumAttrCaseBit<"NoAliasINTELMask", 17>
}
def SPV_MemoryAccessAttr :
- SPV_BitEnumAttr<"MemoryAccess", "valid SPIR-V MemoryAccess", [
+ SPV_BitEnumAttr<"MemoryAccess", "valid SPIR-V MemoryAccess", "memory_access", [
SPV_MA_None, SPV_MA_Volatile, SPV_MA_Aligned, SPV_MA_Nontemporal,
SPV_MA_MakePointerAvailable, SPV_MA_MakePointerVisible,
SPV_MA_NonPrivatePointer, SPV_MA_AliasScopeINTELMask, SPV_MA_NoAliasINTELMask
@@ -3754,7 +3743,7 @@ def SPV_MM_Vulkan : I32EnumAttrCase<"Vulkan", 3> {
}
def SPV_MemoryModelAttr :
- SPV_I32EnumAttr<"MemoryModel", "valid SPIR-V MemoryModel", [
+ SPV_I32EnumAttr<"MemoryModel", "valid SPIR-V MemoryModel", "memory_model", [
SPV_MM_Simple, SPV_MM_GLSL450, SPV_MM_OpenCL, SPV_MM_Vulkan
]>;
@@ -3803,7 +3792,7 @@ def SPV_MS_Volatile : I32BitEnumAttrCaseBit<"Volatile", 15> {
}
def SPV_MemorySemanticsAttr :
- SPV_BitEnumAttr<"MemorySemantics", "valid SPIR-V MemorySemantics", [
+ SPV_BitEnumAttr<"MemorySemantics", "valid SPIR-V MemorySemantics", "memory_semantics", [
SPV_MS_None, SPV_MS_Acquire, SPV_MS_Release, SPV_MS_AcquireRelease,
SPV_MS_SequentiallyConsistent, SPV_MS_UniformMemory, SPV_MS_SubgroupMemory,
SPV_MS_WorkgroupMemory, SPV_MS_CrossWorkgroupMemory,
@@ -3829,7 +3818,7 @@ def SPV_S_ShaderCallKHR : I32EnumAttrCase<"ShaderCallKHR", 6> {
}
def SPV_ScopeAttr :
- SPV_I32EnumAttr<"Scope", "valid SPIR-V Scope", [
+ SPV_I32EnumAttr<"Scope", "valid SPIR-V Scope", "scope", [
SPV_S_CrossDevice, SPV_S_Device, SPV_S_Workgroup, SPV_S_Subgroup,
SPV_S_Invocation, SPV_S_QueueFamily, SPV_S_ShaderCallKHR
]>;
@@ -3839,7 +3828,7 @@ def SPV_SC_Flatten : I32BitEnumAttrCaseBit<"Flatten", 0>;
def SPV_SC_DontFlatten : I32BitEnumAttrCaseBit<"DontFlatten", 1>;
def SPV_SelectionControlAttr :
- SPV_BitEnumAttr<"SelectionControl", "valid SPIR-V SelectionControl", [
+ SPV_BitEnumAttr<"SelectionControl", "valid SPIR-V SelectionControl", "selection_control", [
SPV_SC_None, SPV_SC_Flatten, SPV_SC_DontFlatten
]>;
@@ -3947,7 +3936,7 @@ def SPV_SC_HostOnlyINTEL : I32EnumAttrCase<"HostOnlyINTEL", 5937> {
}
def SPV_StorageClassAttr :
- SPV_I32EnumAttr<"StorageClass", "valid SPIR-V StorageClass", [
+ SPV_I32EnumAttr<"StorageClass", "valid SPIR-V StorageClass", "storage_class", [
SPV_SC_UniformConstant, SPV_SC_Input, SPV_SC_Uniform, SPV_SC_Output,
SPV_SC_Workgroup, SPV_SC_CrossWorkgroup, SPV_SC_Private, SPV_SC_Function,
SPV_SC_Generic, SPV_SC_PushConstant, SPV_SC_AtomicCounter, SPV_SC_Image,
@@ -3965,34 +3954,32 @@ def SPV_IDI_NoDepth : I32EnumAttrCase<"NoDepth", 0>;
def SPV_IDI_IsDepth : I32EnumAttrCase<"IsDepth", 1>;
def SPV_IDI_DepthUnknown : I32EnumAttrCase<"DepthUnknown", 2>;
-def SPV_DepthAttr :
- SPV_I32EnumAttr<"ImageDepthInfo", "valid SPIR-V Image Depth specification",
- [SPV_IDI_NoDepth, SPV_IDI_IsDepth, SPV_IDI_DepthUnknown]>;
+def SPV_DepthAttr : SPV_I32EnumAttr<
+ "ImageDepthInfo", "valid SPIR-V Image Depth specification",
+ "image_depth_info", [SPV_IDI_NoDepth, SPV_IDI_IsDepth, SPV_IDI_DepthUnknown]>;
def SPV_IAI_NonArrayed : I32EnumAttrCase<"NonArrayed", 0>;
def SPV_IAI_Arrayed : I32EnumAttrCase<"Arrayed", 1>;
-def SPV_ArrayedAttr :
- SPV_I32EnumAttr<
- "ImageArrayedInfo", "valid SPIR-V Image Arrayed specification",
- [SPV_IAI_NonArrayed, SPV_IAI_Arrayed]>;
+def SPV_ArrayedAttr : SPV_I32EnumAttr<
+ "ImageArrayedInfo", "valid SPIR-V Image Arrayed specification",
+ "image_arrayed_info", [SPV_IAI_NonArrayed, SPV_IAI_Arrayed]>;
def SPV_ISI_SingleSampled : I32EnumAttrCase<"SingleSampled", 0>;
def SPV_ISI_MultiSampled : I32EnumAttrCase<"MultiSampled", 1>;
-def SPV_SamplingAttr:
- SPV_I32EnumAttr<
- "ImageSamplingInfo", "valid SPIR-V Image Sampling specification",
- [SPV_ISI_SingleSampled, SPV_ISI_MultiSampled]>;
+def SPV_SamplingAttr: SPV_I32EnumAttr<
+ "ImageSamplingInfo", "valid SPIR-V Image Sampling specification",
+ "image_sampling_info", [SPV_ISI_SingleSampled, SPV_ISI_MultiSampled]>;
def SPV_ISUI_SamplerUnknown : I32EnumAttrCase<"SamplerUnknown", 0>;
def SPV_ISUI_NeedSampler : I32EnumAttrCase<"NeedSampler", 1>;
def SPV_ISUI_NoSampler : I32EnumAttrCase<"NoSampler", 2>;
-def SPV_SamplerUseAttr:
- SPV_I32EnumAttr<
- "ImageSamplerUseInfo", "valid SPIR-V Sampler Use specification",
- [SPV_ISUI_SamplerUnknown, SPV_ISUI_NeedSampler, SPV_ISUI_NoSampler]>;
+def SPV_SamplerUseAttr: SPV_I32EnumAttr<
+ "ImageSamplerUseInfo", "valid SPIR-V Sampler Use specification",
+ "image_sampler_use_info",
+ [SPV_ISUI_SamplerUnknown, SPV_ISUI_NeedSampler, SPV_ISUI_NoSampler]>;
//===----------------------------------------------------------------------===//
// SPIR-V attribute definitions
@@ -4326,7 +4313,7 @@ def SPV_OC_OpAssumeTrueKHR : I32EnumAttrCase<"OpAssumeTrueKHR", 5630
def SPV_OC_OpAtomicFAddEXT : I32EnumAttrCase<"OpAtomicFAddEXT", 6035>;
def SPV_OpcodeAttr :
- SPV_I32EnumAttr<"Opcode", "valid SPIR-V instructions", [
+ SPV_I32EnumAttr<"Opcode", "valid SPIR-V instructions", "opcode", [
SPV_OC_OpNop, SPV_OC_OpUndef, SPV_OC_OpSourceContinued, SPV_OC_OpSource,
SPV_OC_OpSourceExtension, SPV_OC_OpName, SPV_OC_OpMemberName, SPV_OC_OpString,
SPV_OC_OpLine, SPV_OC_OpExtension, SPV_OC_OpExtInstImport, SPV_OC_OpExtInst,
diff --git a/mlir/include/mlir/TableGen/Attribute.h b/mlir/include/mlir/TableGen/Attribute.h
index 6b8ceb1771f44..f2098bdb015ee 100644
--- a/mlir/include/mlir/TableGen/Attribute.h
+++ b/mlir/include/mlir/TableGen/Attribute.h
@@ -113,6 +113,9 @@ class Attribute : public AttrConstraint {
// Returns the dialect for the attribute if defined.
Dialect getDialect() const;
+
+ // Returns the TableGen definition this Attribute was constructed from.
+ const llvm::Record &getDef() const;
};
// Wrapper class providing helper methods for accessing MLIR constant attribute
diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp
index 2d8504093e3ce..0ba6708e027b7 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp
@@ -15,8 +15,8 @@
#include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h"
#include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
-#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/StringExtras.h"
diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
index 7f6c50c892e87..2ccdfce32cf46 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
@@ -15,6 +15,7 @@
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/Utils/LayoutUtils.h"
#include "mlir/IR/BuiltinOps.h"
@@ -643,15 +644,15 @@ class ExecutionModePattern
// this entry point's execution mode. We set it to be:
// __spv__{SPIR-V module name}_{function name}_execution_mode_info_{mode}
ModuleOp module = op->getParentOfType<ModuleOp>();
- IntegerAttr executionModeAttr = op.execution_modeAttr();
+ spirv::ExecutionModeAttr executionModeAttr = op.execution_modeAttr();
std::string moduleName;
if (module.getName().has_value())
moduleName = "_" + module.getName().value().str();
else
moduleName = "";
- std::string executionModeInfoName =
- llvm::formatv("__spv_{0}_{1}_execution_mode_info_{2}", moduleName,
- op.fn().str(), executionModeAttr.getValue());
+ std::string executionModeInfoName = llvm::formatv(
+ "__spv_{0}_{1}_execution_mode_info_{2}", moduleName, op.fn().str(),
+ static_cast<uint32_t>(executionModeAttr.getValue()));
MLIRContext *context = rewriter.getContext();
OpBuilder::InsertionGuard guard(rewriter);
@@ -684,8 +685,10 @@ class ExecutionModePattern
// Initialize the struct and set the execution mode value.
rewriter.setInsertionPoint(block, block->begin());
Value structValue = rewriter.create<LLVM::UndefOp>(loc, structType);
- Value executionMode =
- rewriter.create<LLVM::ConstantOp>(loc, llvmI32Type, executionModeAttr);
+ Value executionMode = rewriter.create<LLVM::ConstantOp>(
+ loc, llvmI32Type,
+ rewriter.getI32IntegerAttr(
+ static_cast<uint32_t>(executionModeAttr.getValue())));
structValue = rewriter.create<LLVM::InsertValueOp>(
loc, structType, structValue, executionMode,
ArrayAttr::get(context,
@@ -1391,8 +1394,8 @@ class VectorShufflePattern
auto llvmI32Type = IntegerType::get(context, 32);
Value targetOp = rewriter.create<LLVM::UndefOp>(loc, dstType);
for (unsigned i = 0; i < componentsArray.size(); i++) {
- if (componentsArray[i].isa<IntegerAttr>())
- op.emitError("unable to support non-constant component");
+ if (!componentsArray[i].isa<IntegerAttr>())
+ return op.emitError("unable to support non-constant component");
int indexVal = componentsArray[i].cast<IntegerAttr>().getInt();
if (indexVal == -1)
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 9a0771ec6dc05..f7758f904675f 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -15,6 +15,7 @@
#include "mlir/Dialect/SPIRV/IR/ParserUtils.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOpTraits.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
@@ -174,19 +175,16 @@ parseEnumStrAttr(EnumClass &value, OpAsmParser &parser,
NamedAttrList attr;
auto loc = parser.getCurrentLocation();
if (parser.parseAttribute(attrVal, parser.getBuilder().getNoneType(),
- attrName, attr)) {
+ attrName, attr))
return failure();
- }
- if (!attrVal.isa<StringAttr>()) {
+ if (!attrVal.isa<StringAttr>())
return parser.emitError(loc, "expected ")
<< attrName << " attribute specified as string";
- }
auto attrOptional =
spirv::symbolizeEnum<EnumClass>(attrVal.cast<StringAttr>().getValue());
- if (!attrOptional) {
+ if (!attrOptional)
return parser.emitError(loc, "invalid ")
<< attrName << " attribute specification: " << attrVal;
- }
value = *attrOptional;
return success();
}
@@ -194,50 +192,52 @@ parseEnumStrAttr(EnumClass &value, OpAsmParser &parser,
/// Parses the next string attribute in `parser` as an enumerant of the given
/// `EnumClass` and inserts the enumerant into `state` as an 32-bit integer
/// attribute with the enum class's name as attribute name.
-template <typename EnumClass>
+template <typename EnumAttrClass,
+ typename EnumClass = typename EnumAttrClass::ValueType>
static ParseResult
parseEnumStrAttr(EnumClass &value, OpAsmParser &parser, OperationState &state,
StringRef attrName = spirv::attributeName<EnumClass>()) {
- if (parseEnumStrAttr(value, parser)) {
+ if (parseEnumStrAttr(value, parser))
return failure();
- }
- state.addAttribute(attrName, parser.getBuilder().getI32IntegerAttr(
- llvm::bit_cast<int32_t>(value)));
+ state.addAttribute(attrName,
+ parser.getBuilder().getAttr<EnumAttrClass>(value));
return success();
}
/// Parses the next keyword in `parser` as an enumerant of the given `EnumClass`
/// and inserts the enumerant into `state` as an 32-bit integer attribute with
/// the enum class's name as attribute name.
-template <typename EnumClass>
+template <typename EnumAttrClass,
+ typename EnumClass = typename EnumAttrClass::ValueType>
static ParseResult
parseEnumKeywordAttr(EnumClass &value, OpAsmParser &parser,
OperationState &state,
StringRef attrName = spirv::attributeName<EnumClass>()) {
- if (parseEnumKeywordAttr(value, parser)) {
+ if (parseEnumKeywordAttr(value, parser))
return failure();
- }
- state.addAttribute(attrName, parser.getBuilder().getI32IntegerAttr(
- llvm::bit_cast<int32_t>(value)));
+ state.addAttribute(attrName,
+ parser.getBuilder().getAttr<EnumAttrClass>(value));
return success();
}
/// Parses Function, Selection and Loop control attributes. If no control is
/// specified, "None" is used as a default.
-template <typename EnumClass>
+template <typename EnumAttrClass, typename EnumClass>
static ParseResult
parseControlAttribute(OpAsmParser &parser, OperationState &state,
StringRef attrName = spirv::attributeName<EnumClass>()) {
if (succeeded(parser.parseOptionalKeyword(kControl))) {
EnumClass control;
- if (parser.parseLParen() || parseEnumKeywordAttr(control, parser, state) ||
+ if (parser.parseLParen() ||
+ parseEnumKeywordAttr<EnumAttrClass>(control, parser, state) ||
parser.parseRParen())
return failure();
return success();
}
// Set control to "None" otherwise.
Builder builder = parser.getBuilder();
- state.addAttribute(attrName, builder.getI32IntegerAttr(0));
+ state.addAttribute(attrName,
+ builder.getAttr<EnumAttrClass>(static_cast<EnumClass>(0)));
return success();
}
@@ -256,10 +256,9 @@ static ParseResult parseMemoryAccessAttributes(OpAsmParser &parser,
}
spirv::MemoryAccess memoryAccessAttr;
- if (parseEnumStrAttr(memoryAccessAttr, parser, state,
- kMemoryAccessAttrName)) {
+ if (parseEnumStrAttr<spirv::MemoryAccessAttr>(memoryAccessAttr, parser, state,
+ kMemoryAccessAttrName))
return failure();
- }
if (spirv::bitEnumContains(memoryAccessAttr, spirv::MemoryAccess::Aligned)) {
// Parse integer attribute for alignment.
@@ -287,10 +286,9 @@ static ParseResult parseSourceMemoryAccessAttributes(OpAsmParser &parser,
}
spirv::MemoryAccess memoryAccessAttr;
- if (parseEnumStrAttr(memoryAccessAttr, parser, state,
- kSourceMemoryAccessAttrName)) {
+ if (parseEnumStrAttr<spirv::MemoryAccessAttr>(memoryAccessAttr, parser, state,
+ kSourceMemoryAccessAttrName))
return failure();
- }
if (spirv::bitEnumContains(memoryAccessAttr, spirv::MemoryAccess::Aligned)) {
// Parse integer attribute for alignment.
@@ -479,15 +477,15 @@ static LogicalResult verifyMemoryAccessAttribute(MemoryOpTy memoryOp) {
return success();
}
- auto memAccessVal = memAccessAttr.template cast<IntegerAttr>();
- auto memAccess = spirv::symbolizeMemoryAccess(memAccessVal.getInt());
+ auto memAccess = memAccessAttr.template cast<spirv::MemoryAccessAttr>();
if (!memAccess) {
return memoryOp.emitOpError("invalid memory access specifier: ")
- << memAccessVal;
+ << memAccessAttr;
}
- if (spirv::bitEnumContains(*memAccess, spirv::MemoryAccess::Aligned)) {
+ if (spirv::bitEnumContains(memAccess.getValue(),
+ spirv::MemoryAccess::Aligned)) {
if (!op->getAttr(kAlignmentAttrName)) {
return memoryOp.emitOpError("missing alignment value");
}
@@ -523,15 +521,15 @@ static LogicalResult verifySourceMemoryAccessAttribute(MemoryOpTy memoryOp) {
return success();
}
- auto memAccessVal = memAccessAttr.template cast<IntegerAttr>();
- auto memAccess = spirv::symbolizeMemoryAccess(memAccessVal.getInt());
+ auto memAccess = memAccessAttr.template cast<spirv::MemoryAccessAttr>();
if (!memAccess) {
return memoryOp.emitOpError("invalid memory access specifier: ")
- << memAccessVal;
+ << memAccess;
}
- if (spirv::bitEnumContains(*memAccess, spirv::MemoryAccess::Aligned)) {
+ if (spirv::bitEnumContains(memAccess.getValue(),
+ spirv::MemoryAccess::Aligned)) {
if (!op->getAttr(kSourceAlignmentAttrName)) {
return memoryOp.emitOpError("missing alignment value");
}
@@ -770,8 +768,10 @@ static ParseResult parseAtomicUpdateOp(OpAsmParser &parser,
OpAsmParser::UnresolvedOperand ptrInfo, valueInfo;
Type type;
SMLoc loc;
- if (parseEnumStrAttr(scope, parser, state, kMemoryScopeAttrName) ||
- parseEnumStrAttr(memoryScope, parser, state, kSemanticsAttrName) ||
+ if (parseEnumStrAttr<spirv::ScopeAttr>(scope, parser, state,
+ kMemoryScopeAttrName) ||
+ parseEnumStrAttr<spirv::MemorySemanticsAttr>(memoryScope, parser, state,
+ kSemanticsAttrName) ||
parser.parseOperandList(operandInfo, (hasValue ? 2 : 1)) ||
parser.getCurrentLocation(&loc) || parser.parseColonType(type))
return failure();
@@ -793,14 +793,11 @@ static ParseResult parseAtomicUpdateOp(OpAsmParser &parser,
// Prints an atomic update op.
static void printAtomicUpdateOp(Operation *op, OpAsmPrinter &printer) {
printer << " \"";
- auto scopeAttr = op->getAttrOfType<IntegerAttr>(kMemoryScopeAttrName);
- printer << spirv::stringifyScope(
- static_cast<spirv::Scope>(scopeAttr.getInt()))
- << "\" \"";
- auto memorySemanticsAttr = op->getAttrOfType<IntegerAttr>(kSemanticsAttrName);
- printer << spirv::stringifyMemorySemantics(
- static_cast<spirv::MemorySemantics>(
- memorySemanticsAttr.getInt()))
+ auto scopeAttr = op->getAttrOfType<spirv::ScopeAttr>(kMemoryScopeAttrName);
+ printer << spirv::stringifyScope(scopeAttr.getValue()) << "\" \"";
+ auto memorySemanticsAttr =
+ op->getAttrOfType<spirv::MemorySemanticsAttr>(kSemanticsAttrName);
+ printer << spirv::stringifyMemorySemantics(memorySemanticsAttr.getValue())
<< "\" " << op->getOperands() << " : " << op->getOperand(0).getType();
}
@@ -834,8 +831,9 @@ static LogicalResult verifyAtomicUpdateOp(Operation *op) {
"pointer operand's pointee type ")
<< elementType << ", but found " << valueType;
}
- auto memorySemantics = static_cast<spirv::MemorySemantics>(
- op->getAttrOfType<IntegerAttr>(kSemanticsAttrName).getInt());
+ auto memorySemantics =
+ op->getAttrOfType<spirv::MemorySemanticsAttr>(kSemanticsAttrName)
+ .getValue();
if (failed(verifyMemorySemantics(op, memorySemantics))) {
return failure();
}
@@ -847,10 +845,10 @@ static ParseResult parseGroupNonUniformArithmeticOp(OpAsmParser &parser,
spirv::Scope executionScope;
spirv::GroupOperation groupOperation;
OpAsmParser::UnresolvedOperand valueInfo;
- if (parseEnumStrAttr(executionScope, parser, state,
- kExecutionScopeAttrName) ||
- parseEnumStrAttr(groupOperation, parser, state,
- kGroupOperationAttrName) ||
+ if (parseEnumStrAttr<spirv::ScopeAttr>(executionScope, parser, state,
+ kExecutionScopeAttrName) ||
+ parseEnumStrAttr<spirv::GroupOperationAttr>(groupOperation, parser, state,
+ kGroupOperationAttrName) ||
parser.parseOperand(valueInfo))
return failure();
@@ -880,15 +878,17 @@ static ParseResult parseGroupNonUniformArithmeticOp(OpAsmParser &parser,
static void printGroupNonUniformArithmeticOp(Operation *groupOp,
OpAsmPrinter &printer) {
- printer << " \""
- << stringifyScope(static_cast<spirv::Scope>(
- groupOp->getAttrOfType<IntegerAttr>(kExecutionScopeAttrName)
- .getInt()))
- << "\" \""
- << stringifyGroupOperation(static_cast<spirv::GroupOperation>(
- groupOp->getAttrOfType<IntegerAttr>(kGroupOperationAttrName)
- .getInt()))
- << "\" " << groupOp->getOperand(0);
+ printer
+ << " \""
+ << stringifyScope(
+ groupOp->getAttrOfType<spirv::ScopeAttr>(kExecutionScopeAttrName)
+ .getValue())
+ << "\" \""
+ << stringifyGroupOperation(groupOp
+ ->getAttrOfType<spirv::GroupOperationAttr>(
+ kGroupOperationAttrName)
+ .getValue())
+ << "\" " << groupOp->getOperand(0);
if (groupOp->getNumOperands() > 1)
printer << " " << kClusterSize << '(' << groupOp->getOperand(1) << ')';
@@ -896,14 +896,16 @@ static void printGroupNonUniformArithmeticOp(Operation *groupOp,
}
static LogicalResult verifyGroupNonUniformArithmeticOp(Operation *groupOp) {
- spirv::Scope scope = static_cast<spirv::Scope>(
- groupOp->getAttrOfType<IntegerAttr>(kExecutionScopeAttrName).getInt());
+ spirv::Scope scope =
+ groupOp->getAttrOfType<spirv::ScopeAttr>(kExecutionScopeAttrName)
+ .getValue();
if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
return groupOp->emitOpError(
"execution scope must be 'Workgroup' or 'Subgroup'");
- spirv::GroupOperation operation = static_cast<spirv::GroupOperation>(
- groupOp->getAttrOfType<IntegerAttr>(kGroupOperationAttrName).getInt());
+ spirv::GroupOperation operation =
+ groupOp->getAttrOfType<spirv::GroupOperationAttr>(kGroupOperationAttrName)
+ .getValue();
if (operation == spirv::GroupOperation::ClusteredReduce &&
groupOp->getNumOperands() == 1)
return groupOp->emitOpError("cluster size operand must be provided for "
@@ -1145,11 +1147,12 @@ static ParseResult parseAtomicCompareExchangeImpl(OpAsmParser &parser,
spirv::MemorySemantics equalSemantics, unequalSemantics;
SmallVector<OpAsmParser::UnresolvedOperand, 3> operandInfo;
Type type;
- if (parseEnumStrAttr(memoryScope, parser, state, kMemoryScopeAttrName) ||
- parseEnumStrAttr(equalSemantics, parser, state,
- kEqualSemanticsAttrName) ||
- parseEnumStrAttr(unequalSemantics, parser, state,
- kUnequalSemanticsAttrName) ||
+ if (parseEnumStrAttr<spirv::ScopeAttr>(memoryScope, parser, state,
+ kMemoryScopeAttrName) ||
+ parseEnumStrAttr<spirv::MemorySemanticsAttr>(
+ equalSemantics, parser, state, kEqualSemanticsAttrName) ||
+ parseEnumStrAttr<spirv::MemorySemanticsAttr>(
+ unequalSemantics, parser, state, kUnequalSemanticsAttrName) ||
parser.parseOperandList(operandInfo, 3))
return failure();
@@ -1267,8 +1270,10 @@ ParseResult spirv::AtomicExchangeOp::parse(OpAsmParser &parser,
spirv::MemorySemantics semantics;
SmallVector<OpAsmParser::UnresolvedOperand, 2> operandInfo;
Type type;
- if (parseEnumStrAttr(memoryScope, parser, result, kMemoryScopeAttrName) ||
- parseEnumStrAttr(semantics, parser, result, kSemanticsAttrName) ||
+ if (parseEnumStrAttr<spirv::ScopeAttr>(memoryScope, parser, result,
+ kMemoryScopeAttrName) ||
+ parseEnumStrAttr<spirv::MemorySemanticsAttr>(semantics, parser, result,
+ kSemanticsAttrName) ||
parser.parseOperandList(operandInfo, 2))
return failure();
@@ -2075,7 +2080,7 @@ ParseResult spirv::EntryPointOp::parse(OpAsmParser &parser,
SmallVector<Attribute, 4> interfaceVars;
FlatSymbolRefAttr fn;
- if (parseEnumStrAttr(execModel, parser, result) ||
+ if (parseEnumStrAttr<spirv::ExecutionModelAttr>(execModel, parser, result) ||
parser.parseAttribute(fn, Type(), kFnNameAttrName, result.attributes)) {
return failure();
}
@@ -2132,7 +2137,7 @@ ParseResult spirv::ExecutionModeOp::parse(OpAsmParser &parser,
spirv::ExecutionMode execMode;
Attribute fn;
if (parser.parseAttribute(fn, kFnNameAttrName, result.attributes) ||
- parseEnumStrAttr(execMode, parser, result)) {
+ parseEnumStrAttr<spirv::ExecutionModeAttr>(execMode, parser, result)) {
return failure();
}
@@ -2220,7 +2225,7 @@ ParseResult spirv::FuncOp::parse(OpAsmParser &parser, OperationState &result) {
// Parse the optional function control keyword.
spirv::FunctionControl fnControl;
- if (parseEnumStrAttr(fnControl, parser, result))
+ if (parseEnumStrAttr<spirv::FunctionControlAttr>(fnControl, parser, result))
return failure();
// If additional attributes are present, parse them.
@@ -2308,7 +2313,7 @@ void spirv::FuncOp::build(OpBuilder &builder, OperationState &state,
builder.getStringAttr(name));
state.addAttribute(getTypeAttrName(), TypeAttr::get(type));
state.addAttribute(spirv::attributeName<spirv::FunctionControl>(),
- builder.getI32IntegerAttr(static_cast<uint32_t>(control)));
+ builder.getAttr<spirv::FunctionControlAttr>(control));
state.attributes.append(attrs.begin(), attrs.end());
state.addRegion();
}
@@ -2997,14 +3002,14 @@ LogicalResult spirv::LoadOp::verify() {
//===----------------------------------------------------------------------===//
void spirv::LoopOp::build(OpBuilder &builder, OperationState &state) {
- state.addAttribute("loop_control",
- builder.getI32IntegerAttr(
- static_cast<uint32_t>(spirv::LoopControl::None)));
+ state.addAttribute("loop_control", builder.getAttr<spirv::LoopControlAttr>(
+ spirv::LoopControl::None));
state.addRegion();
}
ParseResult spirv::LoopOp::parse(OpAsmParser &parser, OperationState &result) {
- if (parseControlAttribute<spirv::LoopControl>(parser, result))
+ if (parseControlAttribute<spirv::LoopControlAttr, spirv::LoopControl>(parser,
+ result))
return failure();
return parser.parseRegion(*result.addRegion(), /*arguments=*/{});
}
@@ -3195,9 +3200,9 @@ void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
Optional<StringRef> name) {
state.addAttribute(
"addressing_model",
- builder.getI32IntegerAttr(static_cast<int32_t>(addressingModel)));
- state.addAttribute("memory_model", builder.getI32IntegerAttr(
- static_cast<int32_t>(memoryModel)));
+ builder.getAttr<spirv::AddressingModelAttr>(addressingModel));
+ state.addAttribute("memory_model",
+ builder.getAttr<spirv::MemoryModelAttr>(memoryModel));
OpBuilder::InsertionGuard guard(builder);
builder.createBlock(state.addRegion());
if (vceTriple)
@@ -3219,8 +3224,10 @@ ParseResult spirv::ModuleOp::parse(OpAsmParser &parser,
// Parse attributes
spirv::AddressingModel addrModel;
spirv::MemoryModel memoryModel;
- if (::parseEnumKeywordAttr(addrModel, parser, result) ||
- ::parseEnumKeywordAttr(memoryModel, parser, result))
+ if (::parseEnumKeywordAttr<spirv::AddressingModelAttr>(addrModel, parser,
+ result) ||
+ ::parseEnumKeywordAttr<spirv::MemoryModelAttr>(memoryModel, parser,
+ result))
return failure();
if (succeeded(parser.parseOptionalKeyword("requires"))) {
@@ -3401,7 +3408,8 @@ LogicalResult spirv::SelectOp::verify() {
ParseResult spirv::SelectionOp::parse(OpAsmParser &parser,
OperationState &result) {
- if (parseControlAttribute<spirv::SelectionControl>(parser, result))
+ if (parseControlAttribute<spirv::SelectionControlAttr,
+ spirv::SelectionControl>(parser, result))
return failure();
return parser.parseRegion(*result.addRegion(), /*arguments=*/{});
}
@@ -3666,8 +3674,8 @@ ParseResult spirv::VariableOp::parse(OpAsmParser &parser,
return failure();
}
- auto attr = parser.getBuilder().getI32IntegerAttr(
- llvm::bit_cast<int32_t>(ptrType.getStorageClass()));
+ auto attr = parser.getBuilder().getAttr<spirv::StorageClassAttr>(
+ ptrType.getStorageClass());
result.addAttribute(spirv::attributeName<spirv::StorageClass>(), attr);
return success();
diff --git a/mlir/lib/TableGen/Attribute.cpp b/mlir/lib/TableGen/Attribute.cpp
index a8240a39ffb92..0b2841c296d3f 100644
--- a/mlir/lib/TableGen/Attribute.cpp
+++ b/mlir/lib/TableGen/Attribute.cpp
@@ -132,6 +132,8 @@ Dialect Attribute::getDialect() const {
return Dialect(nullptr);
}
+const llvm::Record &Attribute::getDef() const { return *def; }
+
ConstantAttr::ConstantAttr(const DefInit *init) : def(init->getDef()) {
assert(def->isSubClassOf("ConstantAttr") &&
"must be subclass of TableGen 'ConstantAttr' class");
diff --git a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
index 9566b9ed1bbe8..c6787d79ffe7b 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
@@ -12,6 +12,7 @@
#include "Deserializer.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Location.h"
@@ -406,35 +407,6 @@ Deserializer::processOp<spirv::ExecutionModeOp>(ArrayRef<uint32_t> words) {
return success();
}
-template <>
-LogicalResult
-Deserializer::processOp<spirv::ControlBarrierOp>(ArrayRef<uint32_t> operands) {
- if (operands.size() != 3) {
- return emitError(
- unknownLoc,
- "OpControlBarrier must have execution scope <id>, memory scope <id> "
- "and memory semantics <id>");
- }
-
- SmallVector<IntegerAttr, 3> argAttrs;
- for (auto operand : operands) {
- auto argAttr = getConstantInt(operand);
- if (!argAttr) {
- return emitError(unknownLoc,
- "expected 32-bit integer constant from <id> ")
- << operand << " for OpControlBarrier";
- }
- argAttrs.push_back(argAttr);
- }
-
- opBuilder.create<spirv::ControlBarrierOp>(
- unknownLoc, argAttrs[0].cast<spirv::ScopeAttr>(),
- argAttrs[1].cast<spirv::ScopeAttr>(),
- argAttrs[2].cast<spirv::MemorySemanticsAttr>());
-
- return success();
-}
-
template <>
LogicalResult
Deserializer::processOp<spirv::FunctionCallOp>(ArrayRef<uint32_t> operands) {
@@ -477,31 +449,6 @@ Deserializer::processOp<spirv::FunctionCallOp>(ArrayRef<uint32_t> operands) {
return success();
}
-template <>
-LogicalResult
-Deserializer::processOp<spirv::MemoryBarrierOp>(ArrayRef<uint32_t> operands) {
- if (operands.size() != 2) {
- return emitError(unknownLoc, "OpMemoryBarrier must have memory scope <id> "
- "and memory semantics <id>");
- }
-
- SmallVector<IntegerAttr, 2> argAttrs;
- for (auto operand : operands) {
- auto argAttr = getConstantInt(operand);
- if (!argAttr) {
- return emitError(unknownLoc,
- "expected 32-bit integer constant from <id> ")
- << operand << " for OpMemoryBarrier";
- }
- argAttrs.push_back(argAttr);
- }
-
- opBuilder.create<spirv::MemoryBarrierOp>(
- unknownLoc, argAttrs[0].cast<spirv::ScopeAttr>(),
- argAttrs[1].cast<spirv::MemorySemanticsAttr>());
- return success();
-}
-
template <>
LogicalResult
Deserializer::processOp<spirv::CopyMemoryOp>(ArrayRef<uint32_t> words) {
@@ -538,8 +485,9 @@ Deserializer::processOp<spirv::CopyMemoryOp>(ArrayRef<uint32_t> words) {
if (wordIndex < words.size()) {
auto attrValue = words[wordIndex++];
- attributes.push_back(opBuilder.getNamedAttr(
- "memory_access", opBuilder.getI32IntegerAttr(attrValue)));
+ auto attr = opBuilder.getAttr<spirv::MemoryAccessAttr>(
+ static_cast<spirv::MemoryAccess>(attrValue));
+ attributes.push_back(opBuilder.getNamedAttr("memory_access", attr));
isAlignedAttr = (attrValue == 2);
}
@@ -549,9 +497,10 @@ Deserializer::processOp<spirv::CopyMemoryOp>(ArrayRef<uint32_t> words) {
}
if (wordIndex < words.size()) {
- attributes.push_back(opBuilder.getNamedAttr(
- "source_memory_access",
- opBuilder.getI32IntegerAttr(words[wordIndex++])));
+ auto attrValue = words[wordIndex++];
+ auto attr = opBuilder.getAttr<spirv::MemoryAccessAttr>(
+ static_cast<spirv::MemoryAccess>(attrValue));
+ attributes.push_back(opBuilder.getNamedAttr("source_memory_access", attr));
}
if (wordIndex < words.size()) {
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 383bdc4c905a9..e4cfc4b380e46 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -216,10 +216,11 @@ spirv::Deserializer::processMemoryModel(ArrayRef<uint32_t> operands) {
(*module)->setAttr(
"addressing_model",
- opBuilder.getI32IntegerAttr(llvm::bit_cast<int32_t>(operands.front())));
- (*module)->setAttr(
- "memory_model",
- opBuilder.getI32IntegerAttr(llvm::bit_cast<int32_t>(operands.back())));
+ opBuilder.getAttr<spirv::AddressingModelAttr>(
+ static_cast<spirv::AddressingModel>(operands.front())));
+ (*module)->setAttr("memory_model",
+ opBuilder.getAttr<spirv::MemoryModelAttr>(
+ static_cast<spirv::MemoryModel>(operands.back())));
return success();
}
diff --git a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
index ad1724f9269f0..22fff80440048 100644
--- a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
@@ -13,6 +13,7 @@
#include "Serializer.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/IR/RegionGraphTraits.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Target/SPIRV/SPIRVBinaryUtils.h"
@@ -277,8 +278,8 @@ LogicalResult Serializer::processVariableOp(spirv::VariableOp op) {
operands.push_back(resultID);
auto attr = op->getAttr(spirv::attributeName<spirv::StorageClass>());
if (attr) {
- operands.push_back(static_cast<uint32_t>(
- attr.cast<IntegerAttr>().getValue().getZExtValue()));
+ operands.push_back(
+ static_cast<uint32_t>(attr.cast<spirv::StorageClassAttr>().getValue()));
}
elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());
for (auto arg : op.getODSOperands(0)) {
@@ -565,27 +566,6 @@ Serializer::processOp<spirv::EntryPointOp>(spirv::EntryPointOp op) {
return success();
}
-template <>
-LogicalResult
-Serializer::processOp<spirv::ControlBarrierOp>(spirv::ControlBarrierOp op) {
- StringRef argNames[] = {"execution_scope", "memory_scope",
- "memory_semantics"};
- SmallVector<uint32_t, 3> operands;
-
- for (auto argName : argNames) {
- auto argIntAttr = op->getAttrOfType<IntegerAttr>(argName);
- auto operand = prepareConstantInt(op.getLoc(), argIntAttr);
- if (!operand) {
- return failure();
- }
- operands.push_back(operand);
- }
-
- encodeInstructionInto(functionBody, spirv::Opcode::OpControlBarrier,
- operands);
- return success();
-}
-
template <>
LogicalResult
Serializer::processOp<spirv::ExecutionModeOp>(spirv::ExecutionModeOp op) {
@@ -615,25 +595,6 @@ Serializer::processOp<spirv::ExecutionModeOp>(spirv::ExecutionModeOp op) {
return success();
}
-template <>
-LogicalResult
-Serializer::processOp<spirv::MemoryBarrierOp>(spirv::MemoryBarrierOp op) {
- StringRef argNames[] = {"memory_scope", "memory_semantics"};
- SmallVector<uint32_t, 2> operands;
-
- for (auto argName : argNames) {
- auto argIntAttr = op->getAttrOfType<IntegerAttr>(argName);
- auto operand = prepareConstantInt(op.getLoc(), argIntAttr);
- if (!operand) {
- return failure();
- }
- operands.push_back(operand);
- }
-
- encodeInstructionInto(functionBody, spirv::Opcode::OpMemoryBarrier, operands);
- return success();
-}
-
template <>
LogicalResult
Serializer::processOp<spirv::FunctionCallOp>(spirv::FunctionCallOp op) {
@@ -674,8 +635,8 @@ Serializer::processOp<spirv::CopyMemoryOp>(spirv::CopyMemoryOp op) {
}
if (auto attr = op->getAttr("memory_access")) {
- operands.push_back(static_cast<uint32_t>(
- attr.cast<IntegerAttr>().getValue().getZExtValue()));
+ operands.push_back(
+ static_cast<uint32_t>(attr.cast<spirv::MemoryAccessAttr>().getValue()));
}
elidedAttrs.push_back("memory_access");
@@ -688,8 +649,8 @@ Serializer::processOp<spirv::CopyMemoryOp>(spirv::CopyMemoryOp op) {
elidedAttrs.push_back("alignment");
if (auto attr = op->getAttr("source_memory_access")) {
- operands.push_back(static_cast<uint32_t>(
- attr.cast<IntegerAttr>().getValue().getZExtValue()));
+ operands.push_back(
+ static_cast<uint32_t>(attr.cast<spirv::MemoryAccessAttr>().getValue()));
}
elidedAttrs.push_back("source_memory_access");
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index 47542a41bbb71..b8be9433e94e8 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -14,6 +14,7 @@
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Target/SPIRV/SPIRVBinaryUtils.h"
@@ -23,6 +24,7 @@
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/ADT/bit.h"
#include "llvm/Support/Debug.h"
+#include <cstdint>
#define DEBUG_TYPE "spirv-serialization"
@@ -192,8 +194,11 @@ void Serializer::processExtension() {
}
void Serializer::processMemoryModel() {
- uint32_t mm = module->getAttrOfType<IntegerAttr>("memory_model").getInt();
- uint32_t am = module->getAttrOfType<IntegerAttr>("addressing_model").getInt();
+ auto mm = static_cast<uint32_t>(
+ module->getAttrOfType<spirv::MemoryModelAttr>("memory_model").getValue());
+ auto am = static_cast<uint32_t>(
+ module->getAttrOfType<spirv::AddressingModelAttr>("addressing_model")
+ .getValue());
encodeInstructionInto(memoryModel, spirv::Opcode::OpMemoryModel, {am, mm});
}
diff --git a/mlir/test/Conversion/GPUToSPIRV/simple.mlir b/mlir/test/Conversion/GPUToSPIRV/simple.mlir
index 5d1e9168bf607..c6b45ba491905 100644
--- a/mlir/test/Conversion/GPUToSPIRV/simple.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/simple.mlir
@@ -112,7 +112,7 @@ module attributes {gpu.container_module} {
// CHECK-LABEL: spv.func @barrier
gpu.func @barrier(%arg0 : f32, %arg1 : memref<12xf32>) kernel
attributes {spv.entry_point_abi = #spv.entry_point_abi<local_size = dense<[32, 4, 1]>: vector<3xi32>>} {
- // CHECK: spv.ControlBarrier Workgroup, Workgroup, "AcquireRelease|WorkgroupMemory"
+ // CHECK: spv.ControlBarrier <Workgroup>, <Workgroup>, <AcquireRelease|WorkgroupMemory>
gpu.barrier
gpu.return
}
diff --git a/mlir/test/Conversion/LinalgToSPIRV/linalg-to-spirv.mlir b/mlir/test/Conversion/LinalgToSPIRV/linalg-to-spirv.mlir
index 1a7ef8e0b32c8..3b0af88a299e6 100644
--- a/mlir/test/Conversion/LinalgToSPIRV/linalg-to-spirv.mlir
+++ b/mlir/test/Conversion/LinalgToSPIRV/linalg-to-spirv.mlir
@@ -32,7 +32,7 @@ module attributes {
// CHECK: %[[ADD:.+]] = spv.GroupNonUniformIAdd "Subgroup" "Reduce" %[[VAL]] : i32
// CHECK: %[[OUTPTR:.+]] = spv.AccessChain %[[OUTPUT]][%[[ZERO]], %[[ZERO]]]
-// CHECK: %[[ELECT:.+]] = spv.GroupNonUniformElect Subgroup : i1
+// CHECK: %[[ELECT:.+]] = spv.GroupNonUniformElect <Subgroup> : i1
// CHECK: spv.mlir.selection {
// CHECK: spv.BranchConditional %[[ELECT]], ^bb1, ^bb2
diff --git a/mlir/test/Conversion/MemRefToSPIRV/map-storage-class.mlir b/mlir/test/Conversion/MemRefToSPIRV/map-storage-class.mlir
index 5f474208ec47f..1e3908618e986 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/map-storage-class.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/map-storage-class.mlir
@@ -1,32 +1,30 @@
// RUN: mlir-opt -split-input-file -allow-unregistered-dialect -map-memref-spirv-storage-class='client-api=vulkan' -verify-diagnostics %s -o - | FileCheck %s --check-prefix=VULKAN
// Vulkan Mappings:
-// 0 -> StorageBuffer (12)
-// 1 -> Generic (8)
-// 3 -> Workgroup (4)
-// 4 -> Uniform (2)
-// TODO: create a StorageClass wrapper class so we can print the symbolc
-// storage class (instead of the backing IntegerAttr) and be able to
-// round trip the IR.
+// 0 -> StorageBuffer
+// 1 -> Generic
+// 2 -> [null]
+// 3 -> Workgroup
+// 4 -> Uniform
// VULKAN-LABEL: func @operand_result
func.func @operand_result() {
- // VULKAN: memref<f32, 12 : i32>
+ // VULKAN: memref<f32, #spv.storage_class<StorageBuffer>>
%0 = "dialect.memref_producer"() : () -> (memref<f32>)
- // VULKAN: memref<4xi32, 8 : i32>
+ // VULKAN: memref<4xi32, #spv.storage_class<Generic>>
%1 = "dialect.memref_producer"() : () -> (memref<4xi32, 1>)
- // VULKAN: memref<?x4xf16, 4 : i32>
+ // VULKAN: memref<?x4xf16, #spv.storage_class<Workgroup>>
%2 = "dialect.memref_producer"() : () -> (memref<?x4xf16, 3>)
- // VULKAN: memref<*xf16, 2 : i32>
+ // VULKAN: memref<*xf16, #spv.storage_class<Uniform>>
%3 = "dialect.memref_producer"() : () -> (memref<*xf16, 4>)
"dialect.memref_consumer"(%0) : (memref<f32>) -> ()
- // VULKAN: memref<4xi32, 8 : i32>
+ // VULKAN: memref<4xi32, #spv.storage_class<Generic>>
"dialect.memref_consumer"(%1) : (memref<4xi32, 1>) -> ()
- // VULKAN: memref<?x4xf16, 4 : i32>
+ // VULKAN: memref<?x4xf16, #spv.storage_class<Workgroup>>
"dialect.memref_consumer"(%2) : (memref<?x4xf16, 3>) -> ()
- // VULKAN: memref<*xf16, 2 : i32>
+ // VULKAN: memref<*xf16, #spv.storage_class<Uniform>>
"dialect.memref_consumer"(%3) : (memref<*xf16, 4>) -> ()
return
@@ -36,7 +34,7 @@ func.func @operand_result() {
// VULKAN-LABEL: func @type_attribute
func.func @type_attribute() {
- // VULKAN: attr = memref<i32, 8 : i32>
+ // VULKAN: attr = memref<i32, #spv.storage_class<Generic>>
"dialect.memref_producer"() { attr = memref<i32, 1> } : () -> ()
return
}
@@ -45,9 +43,9 @@ func.func @type_attribute() {
// VULKAN-LABEL: func @function_io
func.func @function_io
- // VULKAN-SAME: (%{{.+}}: memref<f64, 8 : i32>, %{{.+}}: memref<4xi32, 4 : i32>)
+ // VULKAN-SAME: (%{{.+}}: memref<f64, #spv.storage_class<Generic>>, %{{.+}}: memref<4xi32, #spv.storage_class<Workgroup>>)
(%arg0: memref<f64, 1>, %arg1: memref<4xi32, 3>)
- // VULKAN-SAME: -> (memref<f64, 8 : i32>, memref<4xi32, 4 : i32>)
+ // VULKAN-SAME: -> (memref<f64, #spv.storage_class<Generic>>, memref<4xi32, #spv.storage_class<Workgroup>>)
-> (memref<f64, 1>, memref<4xi32, 3>) {
return %arg0, %arg1: memref<f64, 1>, memref<4xi32, 3>
}
@@ -57,8 +55,8 @@ func.func @function_io
// VULKAN: func @region
func.func @region(%cond: i1, %arg0: memref<f32, 1>) {
scf.if %cond {
- // VULKAN: "dialect.memref_consumer"(%{{.+}}) {attr = memref<i64, 4 : i32>}
- // VULKAN-SAME: (memref<f32, 8 : i32>) -> memref<f32, 8 : i32>
+ // VULKAN: "dialect.memref_consumer"(%{{.+}}) {attr = memref<i64, #spv.storage_class<Workgroup>>}
+ // VULKAN-SAME: (memref<f32, #spv.storage_class<Generic>>) -> memref<f32, #spv.storage_class<Generic>>
%0 = "dialect.memref_consumer"(%arg0) { attr = memref<i64, 3> } : (memref<f32, 1>) -> (memref<f32, 1>)
}
return
diff --git a/mlir/test/Dialect/SPIRV/IR/atomic-ops.mlir b/mlir/test/Dialect/SPIRV/IR/atomic-ops.mlir
index ed7e8bc72c8ac..9889422fa31b8 100644
--- a/mlir/test/Dialect/SPIRV/IR/atomic-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/atomic-ops.mlir
@@ -14,7 +14,7 @@ func.func @atomic_and(%ptr : !spv.ptr<i32, StorageBuffer>, %value : i32) -> i32
func.func @atomic_and(%ptr : !spv.ptr<f32, StorageBuffer>, %value : i32) -> i32 {
// expected-error @+1 {{pointer operand must point to an integer value, found 'f32'}}
- %0 = "spv.AtomicAnd"(%ptr, %value) {memory_scope = 4: i32, semantics = 0x4 : i32} : (!spv.ptr<f32, StorageBuffer>, i32) -> (i32)
+ %0 = "spv.AtomicAnd"(%ptr, %value) {memory_scope = #spv.scope<Workgroup>, semantics = #spv.memory_semantics<AcquireRelease>} : (!spv.ptr<f32, StorageBuffer>, i32) -> (i32)
return %0 : i32
}
@@ -23,7 +23,7 @@ func.func @atomic_and(%ptr : !spv.ptr<f32, StorageBuffer>, %value : i32) -> i32
func.func @atomic_and(%ptr : !spv.ptr<i32, StorageBuffer>, %value : i64) -> i64 {
// expected-error @+1 {{expected value to have the same type as the pointer operand's pointee type 'i32', but found 'i64'}}
- %0 = "spv.AtomicAnd"(%ptr, %value) {memory_scope = 2: i32, semantics = 0x8 : i32} : (!spv.ptr<i32, StorageBuffer>, i64) -> (i64)
+ %0 = "spv.AtomicAnd"(%ptr, %value) {memory_scope = #spv.scope<Workgroup>, semantics = #spv.memory_semantics<AcquireRelease>} : (!spv.ptr<i32, StorageBuffer>, i64) -> (i64)
return %0 : i64
}
@@ -51,7 +51,7 @@ func.func @atomic_compare_exchange(%ptr: !spv.ptr<i32, Workgroup>, %value: i32,
func.func @atomic_compare_exchange(%ptr: !spv.ptr<i32, Workgroup>, %value: i64, %comparator: i32) -> i32 {
// expected-error @+1 {{value operand must have the same type as the op result, but found 'i64' vs 'i32'}}
- %0 = "spv.AtomicCompareExchange"(%ptr, %value, %comparator) {memory_scope = 4: i32, equal_semantics = 0x4: i32, unequal_semantics = 0x2:i32} : (!spv.ptr<i32, Workgroup>, i64, i32) -> (i32)
+ %0 = "spv.AtomicCompareExchange"(%ptr, %value, %comparator) {memory_scope = #spv.scope<Workgroup>, equal_semantics = #spv.memory_semantics<AcquireRelease>, unequal_semantics = #spv.memory_semantics<AcquireRelease>} : (!spv.ptr<i32, Workgroup>, i64, i32) -> (i32)
return %0: i32
}
@@ -59,7 +59,7 @@ func.func @atomic_compare_exchange(%ptr: !spv.ptr<i32, Workgroup>, %value: i64,
func.func @atomic_compare_exchange(%ptr: !spv.ptr<i32, Workgroup>, %value: i32, %comparator: i16) -> i32 {
// expected-error @+1 {{comparator operand must have the same type as the op result, but found 'i16' vs 'i32'}}
- %0 = "spv.AtomicCompareExchange"(%ptr, %value, %comparator) {memory_scope = 4: i32, equal_semantics = 0x4: i32, unequal_semantics = 0x2:i32} : (!spv.ptr<i32, Workgroup>, i32, i16) -> (i32)
+ %0 = "spv.AtomicCompareExchange"(%ptr, %value, %comparator) {memory_scope = #spv.scope<Workgroup>, equal_semantics = #spv.memory_semantics<AcquireRelease>, unequal_semantics = #spv.memory_semantics<AcquireRelease>} : (!spv.ptr<i32, Workgroup>, i32, i16) -> (i32)
return %0: i32
}
@@ -67,7 +67,7 @@ func.func @atomic_compare_exchange(%ptr: !spv.ptr<i32, Workgroup>, %value: i32,
func.func @atomic_compare_exchange(%ptr: !spv.ptr<i64, Workgroup>, %value: i32, %comparator: i32) -> i32 {
// expected-error @+1 {{pointer operand's pointee type must have the same as the op result type, but found 'i64' vs 'i32'}}
- %0 = "spv.AtomicCompareExchange"(%ptr, %value, %comparator) {memory_scope = 4: i32, equal_semantics = 0x4: i32, unequal_semantics = 0x2:i32} : (!spv.ptr<i64, Workgroup>, i32, i32) -> (i32)
+ %0 = "spv.AtomicCompareExchange"(%ptr, %value, %comparator) {memory_scope = #spv.scope<Workgroup>, equal_semantics = #spv.memory_semantics<AcquireRelease>, unequal_semantics = #spv.memory_semantics<AcquireRelease>} : (!spv.ptr<i64, Workgroup>, i32, i32) -> (i32)
return %0: i32
}
@@ -87,7 +87,7 @@ func.func @atomic_compare_exchange_weak(%ptr: !spv.ptr<i32, Workgroup>, %value:
func.func @atomic_compare_exchange_weak(%ptr: !spv.ptr<i32, Workgroup>, %value: i64, %comparator: i32) -> i32 {
// expected-error @+1 {{value operand must have the same type as the op result, but found 'i64' vs 'i32'}}
- %0 = "spv.AtomicCompareExchangeWeak"(%ptr, %value, %comparator) {memory_scope = 4: i32, equal_semantics = 0x4: i32, unequal_semantics = 0x2:i32} : (!spv.ptr<i32, Workgroup>, i64, i32) -> (i32)
+ %0 = "spv.AtomicCompareExchangeWeak"(%ptr, %value, %comparator) {memory_scope = #spv.scope<Workgroup>, equal_semantics = #spv.memory_semantics<AcquireRelease>, unequal_semantics = #spv.memory_semantics<AcquireRelease>} : (!spv.ptr<i32, Workgroup>, i64, i32) -> (i32)
return %0: i32
}
@@ -95,7 +95,7 @@ func.func @atomic_compare_exchange_weak(%ptr: !spv.ptr<i32, Workgroup>, %value:
func.func @atomic_compare_exchange_weak(%ptr: !spv.ptr<i32, Workgroup>, %value: i32, %comparator: i16) -> i32 {
// expected-error @+1 {{comparator operand must have the same type as the op result, but found 'i16' vs 'i32'}}
- %0 = "spv.AtomicCompareExchangeWeak"(%ptr, %value, %comparator) {memory_scope = 4: i32, equal_semantics = 0x4: i32, unequal_semantics = 0x2:i32} : (!spv.ptr<i32, Workgroup>, i32, i16) -> (i32)
+ %0 = "spv.AtomicCompareExchangeWeak"(%ptr, %value, %comparator) {memory_scope = #spv.scope<Workgroup>, equal_semantics = #spv.memory_semantics<AcquireRelease>, unequal_semantics = #spv.memory_semantics<AcquireRelease>} : (!spv.ptr<i32, Workgroup>, i32, i16) -> (i32)
return %0: i32
}
@@ -103,7 +103,7 @@ func.func @atomic_compare_exchange_weak(%ptr: !spv.ptr<i32, Workgroup>, %value:
func.func @atomic_compare_exchange_weak(%ptr: !spv.ptr<i64, Workgroup>, %value: i32, %comparator: i32) -> i32 {
// expected-error @+1 {{pointer operand's pointee type must have the same as the op result type, but found 'i64' vs 'i32'}}
- %0 = "spv.AtomicCompareExchangeWeak"(%ptr, %value, %comparator) {memory_scope = 4: i32, equal_semantics = 0x4: i32, unequal_semantics = 0x2:i32} : (!spv.ptr<i64, Workgroup>, i32, i32) -> (i32)
+ %0 = "spv.AtomicCompareExchangeWeak"(%ptr, %value, %comparator) {memory_scope = #spv.scope<Workgroup>, equal_semantics = #spv.memory_semantics<AcquireRelease>, unequal_semantics = #spv.memory_semantics<AcquireRelease>} : (!spv.ptr<i64, Workgroup>, i32, i32) -> (i32)
return %0: i32
}
@@ -123,7 +123,7 @@ func.func @atomic_exchange(%ptr: !spv.ptr<i32, Workgroup>, %value: i32) -> i32 {
func.func @atomic_exchange(%ptr: !spv.ptr<i32, Workgroup>, %value: i64) -> i32 {
// expected-error @+1 {{value operand must have the same type as the op result, but found 'i64' vs 'i32'}}
- %0 = "spv.AtomicExchange"(%ptr, %value) {memory_scope = 4: i32, semantics = 0x4: i32} : (!spv.ptr<i32, Workgroup>, i64) -> (i32)
+ %0 = "spv.AtomicExchange"(%ptr, %value) {memory_scope = #spv.scope<Workgroup>, semantics = #spv.memory_semantics<AcquireRelease>} : (!spv.ptr<i32, Workgroup>, i64) -> (i32)
return %0: i32
}
@@ -131,7 +131,7 @@ func.func @atomic_exchange(%ptr: !spv.ptr<i32, Workgroup>, %value: i64) -> i32 {
func.func @atomic_exchange(%ptr: !spv.ptr<i64, Workgroup>, %value: i32) -> i32 {
// expected-error @+1 {{pointer operand's pointee type must have the same as the op result type, but found 'i64' vs 'i32'}}
- %0 = "spv.AtomicExchange"(%ptr, %value) {memory_scope = 4: i32, semantics = 0x4: i32} : (!spv.ptr<i64, Workgroup>, i32) -> (i32)
+ %0 = "spv.AtomicExchange"(%ptr, %value) {memory_scope = #spv.scope<Workgroup>, semantics = #spv.memory_semantics<AcquireRelease>} : (!spv.ptr<i64, Workgroup>, i32) -> (i32)
return %0: i32
}
@@ -253,7 +253,7 @@ func.func @atomic_fadd(%ptr : !spv.ptr<f32, StorageBuffer>, %value : f32) -> f32
func.func @atomic_fadd(%ptr : !spv.ptr<i32, StorageBuffer>, %value : f32) -> f32 {
// expected-error @+1 {{pointer operand must point to an float value, found 'i32'}}
- %0 = "spv.AtomicFAddEXT"(%ptr, %value) {memory_scope = 4: i32, semantics = 0x4 : i32} : (!spv.ptr<i32, StorageBuffer>, f32) -> (f32)
+ %0 = "spv.AtomicFAddEXT"(%ptr, %value) {memory_scope = #spv.scope<Workgroup>, semantics = #spv.memory_semantics<AcquireRelease>} : (!spv.ptr<i32, StorageBuffer>, f32) -> (f32)
return %0 : f32
}
@@ -261,7 +261,7 @@ func.func @atomic_fadd(%ptr : !spv.ptr<i32, StorageBuffer>, %value : f32) -> f32
func.func @atomic_fadd(%ptr : !spv.ptr<f32, StorageBuffer>, %value : f64) -> f64 {
// expected-error @+1 {{expected value to have the same type as the pointer operand's pointee type 'f32', but found 'f64'}}
- %0 = "spv.AtomicFAddEXT"(%ptr, %value) {memory_scope = 2: i32, semantics = 0x8 : i32} : (!spv.ptr<f32, StorageBuffer>, f64) -> (f64)
+ %0 = "spv.AtomicFAddEXT"(%ptr, %value) {memory_scope = #spv.scope<Device>, semantics = #spv.memory_semantics<AcquireRelease>} : (!spv.ptr<f32, StorageBuffer>, f64) -> (f64)
return %0 : f64
}
diff --git a/mlir/test/Dialect/SPIRV/IR/availability.mlir b/mlir/test/Dialect/SPIRV/IR/availability.mlir
index 089b70cf949bb..810fe53faa953 100644
--- a/mlir/test/Dialect/SPIRV/IR/availability.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/availability.mlir
@@ -26,7 +26,7 @@ func.func @subgroup_ballot(%predicate: i1) -> vector<4xi32> {
// CHECK: max version: v1.6
// CHECK: extensions: [ ]
// CHECK: capabilities: [ [GroupNonUniformBallot] ]
- %0 = spv.GroupNonUniformBallot Workgroup %predicate : vector<4xi32>
+ %0 = spv.GroupNonUniformBallot <Workgroup> %predicate : vector<4xi32>
return %0: vector<4xi32>
}
diff --git a/mlir/test/Dialect/SPIRV/IR/barrier-ops.mlir b/mlir/test/Dialect/SPIRV/IR/barrier-ops.mlir
index 45d0a7430244e..931426f32848b 100644
--- a/mlir/test/Dialect/SPIRV/IR/barrier-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/barrier-ops.mlir
@@ -5,16 +5,17 @@
//===----------------------------------------------------------------------===//
func.func @control_barrier_0() -> () {
- // CHECK: spv.ControlBarrier Workgroup, Device, "Acquire|UniformMemory"
- spv.ControlBarrier Workgroup, Device, "Acquire|UniformMemory"
+ // CHECK: spv.ControlBarrier <Workgroup>, <Device>, <Acquire|UniformMemory>
+ spv.ControlBarrier <Workgroup>, <Device>, <Acquire|UniformMemory>
return
}
// -----
func.func @control_barrier_1() -> () {
- // expected-error @+1 {{expected string or keyword containing one of the following enum values}}
- spv.ControlBarrier Something, Device, "Acquire|UniformMemory"
+ // expected-error @+2 {{to be one of}}
+ // expected-error @+1 {{failed to parse SPV_ScopeAttr}}
+ spv.ControlBarrier <Something>, <Device>, <Acquire|UniformMemory>
return
}
@@ -26,16 +27,16 @@ func.func @control_barrier_1() -> () {
//===----------------------------------------------------------------------===//
func.func @memory_barrier_0() -> () {
- // CHECK: spv.MemoryBarrier Device, "Acquire|UniformMemory"
- spv.MemoryBarrier Device, "Acquire|UniformMemory"
+ // CHECK: spv.MemoryBarrier <Device>, <Acquire|UniformMemory>
+ spv.MemoryBarrier <Device>, <Acquire|UniformMemory>
return
}
// -----
func.func @memory_barrier_1() -> () {
- // CHECK: spv.MemoryBarrier Workgroup, Acquire
- spv.MemoryBarrier Workgroup, Acquire
+ // CHECK: spv.MemoryBarrier <Workgroup>, <Acquire>
+ spv.MemoryBarrier <Workgroup>, <Acquire>
return
}
@@ -43,7 +44,7 @@ func.func @memory_barrier_1() -> () {
func.func @memory_barrier_2() -> () {
// expected-error @+1 {{expected at most one of these four memory constraints to be set: `Acquire`, `Release`,`AcquireRelease` or `SequentiallyConsistent`}}
- spv.MemoryBarrier Device, "Acquire|Release"
+ spv.MemoryBarrier <Device>, <Acquire|Release>
return
}
diff --git a/mlir/test/Dialect/SPIRV/IR/group-ops.mlir b/mlir/test/Dialect/SPIRV/IR/group-ops.mlir
index 103e41016648b..a62f6fffa1616 100644
--- a/mlir/test/Dialect/SPIRV/IR/group-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/group-ops.mlir
@@ -17,24 +17,24 @@ func.func @subgroup_ballot(%predicate: i1) -> vector<4xi32> {
//===----------------------------------------------------------------------===//
func.func @group_broadcast_scalar(%value: f32, %localid: i32 ) -> f32 {
- // CHECK: spv.GroupBroadcast Workgroup %{{.*}}, %{{.*}} : f32, i32
- %0 = spv.GroupBroadcast Workgroup %value, %localid : f32, i32
+ // CHECK: spv.GroupBroadcast <Workgroup> %{{.*}}, %{{.*}} : f32, i32
+ %0 = spv.GroupBroadcast <Workgroup> %value, %localid : f32, i32
return %0: f32
}
// -----
func.func @group_broadcast_scalar_vector(%value: f32, %localid: vector<3xi32> ) -> f32 {
- // CHECK: spv.GroupBroadcast Workgroup %{{.*}}, %{{.*}} : f32, vector<3xi32>
- %0 = spv.GroupBroadcast Workgroup %value, %localid : f32, vector<3xi32>
+ // CHECK: spv.GroupBroadcast <Workgroup> %{{.*}}, %{{.*}} : f32, vector<3xi32>
+ %0 = spv.GroupBroadcast <Workgroup> %value, %localid : f32, vector<3xi32>
return %0: f32
}
// -----
func.func @group_broadcast_vector(%value: vector<4xf32>, %localid: vector<3xi32> ) -> vector<4xf32> {
- // CHECK: spv.GroupBroadcast Subgroup %{{.*}}, %{{.*}} : vector<4xf32>, vector<3xi32>
- %0 = spv.GroupBroadcast Subgroup %value, %localid : vector<4xf32>, vector<3xi32>
+ // CHECK: spv.GroupBroadcast <Subgroup> %{{.*}}, %{{.*}} : vector<4xf32>, vector<3xi32>
+ %0 = spv.GroupBroadcast <Subgroup> %value, %localid : vector<4xf32>, vector<3xi32>
return %0: vector<4xf32>
}
@@ -42,7 +42,7 @@ func.func @group_broadcast_vector(%value: vector<4xf32>, %localid: vector<3xi32>
func.func @group_broadcast_negative_scope(%value: f32, %localid: vector<3xi32> ) -> f32 {
// expected-error @+1 {{execution scope must be 'Workgroup' or 'Subgroup'}}
- %0 = spv.GroupBroadcast Device %value, %localid : f32, vector<3xi32>
+ %0 = spv.GroupBroadcast <Device> %value, %localid : f32, vector<3xi32>
return %0: f32
}
@@ -50,7 +50,7 @@ func.func @group_broadcast_negative_scope(%value: f32, %localid: vector<3xi32> )
func.func @group_broadcast_negative_locid_dtype(%value: f32, %localid: vector<3xf32> ) -> f32 {
// expected-error @+1 {{operand #1 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values}}
- %0 = spv.GroupBroadcast Subgroup %value, %localid : f32, vector<3xf32>
+ %0 = spv.GroupBroadcast <Subgroup> %value, %localid : f32, vector<3xf32>
return %0: f32
}
@@ -58,7 +58,7 @@ func.func @group_broadcast_negative_locid_dtype(%value: f32, %localid: vector<3x
func.func @group_broadcast_negative_locid_vec4(%value: f32, %localid: vector<4xi32> ) -> f32 {
// expected-error @+1 {{localid is a vector and can be with only 2 or 3 components, actual number is 4}}
- %0 = spv.GroupBroadcast Subgroup %value, %localid : f32, vector<4xi32>
+ %0 = spv.GroupBroadcast <Subgroup> %value, %localid : f32, vector<4xi32>
return %0: f32
}
diff --git a/mlir/test/Dialect/SPIRV/IR/memory-ops.mlir b/mlir/test/Dialect/SPIRV/IR/memory-ops.mlir
index 58e72570ed497..f2b3979ef5f6f 100644
--- a/mlir/test/Dialect/SPIRV/IR/memory-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/memory-ops.mlir
@@ -198,7 +198,7 @@ func.func @load_none_access() -> () {
%0 = spv.Variable : !spv.ptr<f32, Function>
// CHECK: spv.Load
// CHECK-SAME: ["None"]
- %1 = "spv.Load"(%0) {memory_access = 0 : i32} : (!spv.ptr<f32, Function>) -> (f32)
+ %1 = "spv.Load"(%0) {memory_access = #spv.memory_access<None>} : (!spv.ptr<f32, Function>) -> (f32)
return
}
@@ -207,7 +207,7 @@ func.func @volatile_load() -> () {
%0 = spv.Variable : !spv.ptr<f32, Function>
// CHECK: spv.Load
// CHECK-SAME: ["Volatile"]
- %1 = "spv.Load"(%0) {memory_access = 1 : i32} : (!spv.ptr<f32, Function>) -> (f32)
+ %1 = "spv.Load"(%0) {memory_access = #spv.memory_access<Volatile>} : (!spv.ptr<f32, Function>) -> (f32)
return
}
@@ -216,7 +216,7 @@ func.func @aligned_load() -> () {
%0 = spv.Variable : !spv.ptr<f32, Function>
// CHECK: spv.Load
// CHECK-SAME: ["Aligned", 4]
- %1 = "spv.Load"(%0) {memory_access = 2 : i32, alignment = 4 : i32} : (!spv.ptr<f32, Function>) -> (f32)
+ %1 = "spv.Load"(%0) {memory_access = #spv.memory_access<Aligned>, alignment = 4 : i32} : (!spv.ptr<f32, Function>) -> (f32)
return
}
@@ -225,7 +225,7 @@ func.func @volatile_aligned_load() -> () {
%0 = spv.Variable : !spv.ptr<f32, Function>
// CHECK: spv.Load
// CHECK-SAME: ["Volatile|Aligned", 4]
- %1 = "spv.Load"(%0) {memory_access = 3 : i32, alignment = 4 : i32} : (!spv.ptr<f32, Function>) -> (f32)
+ %1 = "spv.Load"(%0) {memory_access = #spv.memory_access<Volatile|Aligned>, alignment = 4 : i32} : (!spv.ptr<f32, Function>) -> (f32)
return
}
@@ -588,7 +588,7 @@ func.func @copy_memory_invalid_maa() {
%0 = spv.Variable : !spv.ptr<f32, Function>
%1 = spv.Variable : !spv.ptr<f32, Function>
// expected-error @+1 {{missing alignment value}}
- "spv.CopyMemory"(%0, %1) {memory_access=0x0002 : i32} : (!spv.ptr<f32, Function>, !spv.ptr<f32, Function>) -> ()
+ "spv.CopyMemory"(%0, %1) {memory_access=#spv.memory_access<Aligned>} : (!spv.ptr<f32, Function>, !spv.ptr<f32, Function>) -> ()
spv.Return
}
@@ -598,7 +598,7 @@ func.func @copy_memory_invalid_source_maa() {
%0 = spv.Variable : !spv.ptr<f32, Function>
%1 = spv.Variable : !spv.ptr<f32, Function>
// expected-error @+1 {{invalid alignment specification with non-aligned memory access specification}}
- "spv.CopyMemory"(%0, %1) {source_memory_access=0x0001 : i32, memory_access=0x0002 : i32, source_alignment=8 : i32, alignment=4 : i32} : (!spv.ptr<f32, Function>, !spv.ptr<f32, Function>) -> ()
+ "spv.CopyMemory"(%0, %1) {source_memory_access=#spv.memory_access<Volatile>, memory_access=#spv.memory_access<Aligned>, source_alignment=8 : i32, alignment=4 : i32} : (!spv.ptr<f32, Function>, !spv.ptr<f32, Function>) -> ()
spv.Return
}
@@ -608,7 +608,7 @@ func.func @copy_memory_invalid_source_maa2() {
%0 = spv.Variable : !spv.ptr<f32, Function>
%1 = spv.Variable : !spv.ptr<f32, Function>
// expected-error @+1 {{missing alignment value}}
- "spv.CopyMemory"(%0, %1) {source_memory_access=0x0002 : i32, memory_access=0x0002 : i32, alignment=4 : i32} : (!spv.ptr<f32, Function>, !spv.ptr<f32, Function>) -> ()
+ "spv.CopyMemory"(%0, %1) {source_memory_access=#spv.memory_access<Aligned>, memory_access=#spv.memory_access<Aligned>, alignment=4 : i32} : (!spv.ptr<f32, Function>, !spv.ptr<f32, Function>) -> ()
spv.Return
}
@@ -619,16 +619,16 @@ func.func @copy_memory_print_maa() {
%1 = spv.Variable : !spv.ptr<f32, Function>
// CHECK: spv.CopyMemory "Function" %{{.*}}, "Function" %{{.*}} ["Volatile"] : f32
- "spv.CopyMemory"(%0, %1) {memory_access=0x0001 : i32} : (!spv.ptr<f32, Function>, !spv.ptr<f32, Function>) -> ()
+ "spv.CopyMemory"(%0, %1) {memory_access=#spv.memory_access<Volatile>} : (!spv.ptr<f32, Function>, !spv.ptr<f32, Function>) -> ()
// CHECK: spv.CopyMemory "Function" %{{.*}}, "Function" %{{.*}} ["Aligned", 4] : f32
- "spv.CopyMemory"(%0, %1) {memory_access=0x0002 : i32, alignment=4 : i32} : (!spv.ptr<f32, Function>, !spv.ptr<f32, Function>) -> ()
+ "spv.CopyMemory"(%0, %1) {memory_access=#spv.memory_access<Aligned>, alignment=4 : i32} : (!spv.ptr<f32, Function>, !spv.ptr<f32, Function>) -> ()
// CHECK: spv.CopyMemory "Function" %{{.*}}, "Function" %{{.*}} ["Aligned", 4], ["Volatile"] : f32
- "spv.CopyMemory"(%0, %1) {source_memory_access=0x0001 : i32, memory_access=0x0002 : i32, alignment=4 : i32} : (!spv.ptr<f32, Function>, !spv.ptr<f32, Function>) -> ()
+ "spv.CopyMemory"(%0, %1) {source_memory_access=#spv.memory_access<Volatile>, memory_access=#spv.memory_access<Aligned>, alignment=4 : i32} : (!spv.ptr<f32, Function>, !spv.ptr<f32, Function>) -> ()
// CHECK: spv.CopyMemory "Function" %{{.*}}, "Function" %{{.*}} ["Aligned", 4], ["Aligned", 8] : f32
- "spv.CopyMemory"(%0, %1) {source_memory_access=0x0002 : i32, memory_access=0x0002 : i32, source_alignment=8 : i32, alignment=4 : i32} : (!spv.ptr<f32, Function>, !spv.ptr<f32, Function>) -> ()
+ "spv.CopyMemory"(%0, %1) {source_memory_access=#spv.memory_access<Aligned>, memory_access=#spv.memory_access<Aligned>, source_alignment=8 : i32, alignment=4 : i32} : (!spv.ptr<f32, Function>, !spv.ptr<f32, Function>) -> ()
spv.Return
}
diff --git a/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir b/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
index 512693e40afef..8c7f6f168c2d4 100644
--- a/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
@@ -5,8 +5,8 @@
//===----------------------------------------------------------------------===//
func.func @group_non_uniform_ballot(%predicate: i1) -> vector<4xi32> {
- // CHECK: %{{.*}} = spv.GroupNonUniformBallot Workgroup %{{.*}}: vector<4xi32>
- %0 = spv.GroupNonUniformBallot Workgroup %predicate : vector<4xi32>
+ // CHECK: %{{.*}} = spv.GroupNonUniformBallot <Workgroup> %{{.*}}: vector<4xi32>
+ %0 = spv.GroupNonUniformBallot <Workgroup> %predicate : vector<4xi32>
return %0: vector<4xi32>
}
@@ -14,7 +14,7 @@ func.func @group_non_uniform_ballot(%predicate: i1) -> vector<4xi32> {
func.func @group_non_uniform_ballot(%predicate: i1) -> vector<4xi32> {
// expected-error @+1 {{execution scope must be 'Workgroup' or 'Subgroup'}}
- %0 = spv.GroupNonUniformBallot Device %predicate : vector<4xi32>
+ %0 = spv.GroupNonUniformBallot <Device> %predicate : vector<4xi32>
return %0: vector<4xi32>
}
@@ -22,7 +22,7 @@ func.func @group_non_uniform_ballot(%predicate: i1) -> vector<4xi32> {
func.func @group_non_uniform_ballot(%predicate: i1) -> vector<4xsi32> {
// expected-error @+1 {{op result #0 must be vector of 8/16/32/64-bit signless/unsigned integer values of length 4, but got 'vector<4xsi32>'}}
- %0 = spv.GroupNonUniformBallot Workgroup %predicate : vector<4xsi32>
+ %0 = spv.GroupNonUniformBallot <Workgroup> %predicate : vector<4xsi32>
return %0: vector<4xsi32>
}
@@ -34,8 +34,8 @@ func.func @group_non_uniform_ballot(%predicate: i1) -> vector<4xsi32> {
func.func @group_non_uniform_broadcast_scalar(%value: f32) -> f32 {
%one = spv.Constant 1 : i32
- // CHECK: spv.GroupNonUniformBroadcast Workgroup %{{.*}}, %{{.*}} : f32, i32
- %0 = spv.GroupNonUniformBroadcast Workgroup %value, %one : f32, i32
+ // CHECK: spv.GroupNonUniformBroadcast <Workgroup> %{{.*}}, %{{.*}} : f32, i32
+ %0 = spv.GroupNonUniformBroadcast <Workgroup> %value, %one : f32, i32
return %0: f32
}
@@ -43,8 +43,8 @@ func.func @group_non_uniform_broadcast_scalar(%value: f32) -> f32 {
func.func @group_non_uniform_broadcast_vector(%value: vector<4xf32>) -> vector<4xf32> {
%one = spv.Constant 1 : i32
- // CHECK: spv.GroupNonUniformBroadcast Subgroup %{{.*}}, %{{.*}} : vector<4xf32>, i32
- %0 = spv.GroupNonUniformBroadcast Subgroup %value, %one : vector<4xf32>, i32
+ // CHECK: spv.GroupNonUniformBroadcast <Subgroup> %{{.*}}, %{{.*}} : vector<4xf32>, i32
+ %0 = spv.GroupNonUniformBroadcast <Subgroup> %value, %one : vector<4xf32>, i32
return %0: vector<4xf32>
}
@@ -53,7 +53,7 @@ func.func @group_non_uniform_broadcast_vector(%value: vector<4xf32>) -> vector<4
func.func @group_non_uniform_broadcast_negative_scope(%value: f32, %localid: i32 ) -> f32 {
%one = spv.Constant 1 : i32
// expected-error @+1 {{execution scope must be 'Workgroup' or 'Subgroup'}}
- %0 = spv.GroupNonUniformBroadcast Device %value, %one : f32, i32
+ %0 = spv.GroupNonUniformBroadcast <Device> %value, %one : f32, i32
return %0: f32
}
@@ -61,7 +61,7 @@ func.func @group_non_uniform_broadcast_negative_scope(%value: f32, %localid: i32
func.func @group_non_uniform_broadcast_negative_non_const(%value: f32, %localid: i32) -> f32 {
// expected-error @+1 {{id must be the result of a constant op}}
- %0 = spv.GroupNonUniformBroadcast Subgroup %value, %localid : f32, i32
+ %0 = spv.GroupNonUniformBroadcast <Subgroup> %value, %localid : f32, i32
return %0: f32
}
@@ -73,8 +73,8 @@ func.func @group_non_uniform_broadcast_negative_non_const(%value: f32, %localid:
// CHECK-LABEL: @group_non_uniform_elect
func.func @group_non_uniform_elect() -> i1 {
- // CHECK: %{{.+}} = spv.GroupNonUniformElect Workgroup : i1
- %0 = spv.GroupNonUniformElect Workgroup : i1
+ // CHECK: %{{.+}} = spv.GroupNonUniformElect <Workgroup> : i1
+ %0 = spv.GroupNonUniformElect <Workgroup> : i1
return %0: i1
}
@@ -82,7 +82,7 @@ func.func @group_non_uniform_elect() -> i1 {
func.func @group_non_uniform_elect() -> i1 {
// expected-error @+1 {{execution scope must be 'Workgroup' or 'Subgroup'}}
- %0 = spv.GroupNonUniformElect CrossDevice : i1
+ %0 = spv.GroupNonUniformElect <CrossDevice> : i1
return %0: i1
}
diff --git a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
index 8c64e4570dece..793f3361854bf 100644
--- a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
@@ -819,7 +819,7 @@ spv.module Logical GLSL450 {
%0 = spv.Variable : !spv.ptr<i32, Function>
// expected-error @+1 {{invalid enclosed op}}
- %1 = spv.SpecConstantOperation wraps "spv.Load"(%0) {memory_access = 0 : i32} : (!spv.ptr<i32, Function>) -> i32
+ %1 = spv.SpecConstantOperation wraps "spv.Load"(%0) {memory_access = #spv.memory_access<None>} : (!spv.ptr<i32, Function>) -> i32
spv.Return
}
}
diff --git a/mlir/test/Dialect/SPIRV/IR/target-and-abi.mlir b/mlir/test/Dialect/SPIRV/IR/target-and-abi.mlir
index 56b67d660601b..cbe390dca35f4 100644
--- a/mlir/test/Dialect/SPIRV/IR/target-and-abi.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/target-and-abi.mlir
@@ -165,11 +165,11 @@ func.func @target_env_cooperative_matrix() attributes{
// CHECK-SAME: #spv.coop_matrix_props<
// CHECK-SAME: m_size = 8, n_size = 8, k_size = 32,
// CHECK-SAME: a_type = i8, b_type = i8, c_type = i32,
- // CHECK-SAME: result_type = i32, scope = 3 : i32>
+ // CHECK-SAME: result_type = i32, scope = <Subgroup>>
// CHECK-SAME: #spv.coop_matrix_props<
// CHECK-SAME: m_size = 8, n_size = 8, k_size = 16,
// CHECK-SAME: a_type = f16, b_type = f16, c_type = f16,
- // CHECK-SAME: result_type = f16, scope = 3 : i32>
+ // CHECK-SAME: result_type = f16, scope = <Subgroup>>
spv.target_env = #spv.target_env<
#spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class,
SPV_NV_cooperative_matrix]>,
@@ -182,7 +182,7 @@ func.func @target_env_cooperative_matrix() attributes{
b_type = i8,
c_type = i32,
result_type = i32,
- scope = 3 : i32
+ scope = #spv.scope<Subgroup>
>, #spv.coop_matrix_props<
m_size = 8,
n_size = 8,
@@ -191,7 +191,7 @@ func.func @target_env_cooperative_matrix() attributes{
b_type = f16,
c_type = f16,
result_type = f16,
- scope = 3 : i32
+ scope = #spv.scope<Subgroup>
>]
>>
} { return }
diff --git a/mlir/test/Dialect/SPIRV/IR/target-env.mlir b/mlir/test/Dialect/SPIRV/IR/target-env.mlir
index e6e5b97c18fb9..8a72b76ef1340 100644
--- a/mlir/test/Dialect/SPIRV/IR/target-env.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/target-env.mlir
@@ -59,7 +59,7 @@ func.func @cmp_exchange_weak_unsupported_version(%ptr: !spv.ptr<i32, Workgroup>,
func.func @group_non_uniform_ballot_suitable_version(%predicate: i1) -> vector<4xi32> attributes {
spv.target_env = #spv.target_env<#spv.vce<v1.4, [GroupNonUniformBallot], []>, #spv.resource_limits<>>
} {
- // CHECK: spv.GroupNonUniformBallot Workgroup
+ // CHECK: spv.GroupNonUniformBallot <Workgroup>
%0 = "test.convert_to_group_non_uniform_ballot_op"(%predicate): (i1) -> (vector<4xi32>)
return %0: vector<4xi32>
}
diff --git a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
index eeb7a0b209457..0c00058842594 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
@@ -27,7 +27,7 @@ spv.module Logical GLSL450 attributes {
#spv.vce<v1.5, [Shader, GroupNonUniformBallot], []>, #spv.resource_limits<>>
} {
spv.func @group_non_uniform_ballot(%predicate : i1) -> vector<4xi32> "None" {
- %0 = spv.GroupNonUniformBallot Workgroup %predicate : vector<4xi32>
+ %0 = spv.GroupNonUniformBallot <Workgroup> %predicate : vector<4xi32>
spv.ReturnValue %0: vector<4xi32>
}
}
diff --git a/mlir/test/Target/SPIRV/barrier-ops.mlir b/mlir/test/Target/SPIRV/barrier-ops.mlir
index a3b80e06a35e3..d3700f0c92dfc 100644
--- a/mlir/test/Target/SPIRV/barrier-ops.mlir
+++ b/mlir/test/Target/SPIRV/barrier-ops.mlir
@@ -2,23 +2,23 @@
spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
spv.func @memory_barrier_0() -> () "None" {
- // CHECK: spv.MemoryBarrier Device, "Release|UniformMemory"
- spv.MemoryBarrier Device, "Release|UniformMemory"
+ // CHECK: spv.MemoryBarrier <Device>, <Release|UniformMemory>
+ spv.MemoryBarrier <Device>, <Release|UniformMemory>
spv.Return
}
spv.func @memory_barrier_1() -> () "None" {
- // CHECK: spv.MemoryBarrier Subgroup, "AcquireRelease|SubgroupMemory"
- spv.MemoryBarrier Subgroup, "AcquireRelease|SubgroupMemory"
+ // CHECK: spv.MemoryBarrier <Subgroup>, <AcquireRelease|SubgroupMemory>
+ spv.MemoryBarrier <Subgroup>, <AcquireRelease|SubgroupMemory>
spv.Return
}
spv.func @control_barrier_0() -> () "None" {
- // CHECK: spv.ControlBarrier Device, Workgroup, "Release|UniformMemory"
- spv.ControlBarrier Device, Workgroup, "Release|UniformMemory"
+ // CHECK: spv.ControlBarrier <Device>, <Workgroup>, <Release|UniformMemory>
+ spv.ControlBarrier <Device>, <Workgroup>, <Release|UniformMemory>
spv.Return
}
spv.func @control_barrier_1() -> () "None" {
- // CHECK: spv.ControlBarrier Workgroup, Invocation, "AcquireRelease|UniformMemory"
- spv.ControlBarrier Workgroup, Invocation, "AcquireRelease|UniformMemory"
+ // CHECK: spv.ControlBarrier <Workgroup>, <Invocation>, <AcquireRelease|UniformMemory>
+ spv.ControlBarrier <Workgroup>, <Invocation>, <AcquireRelease|UniformMemory>
spv.Return
}
}
diff --git a/mlir/test/Target/SPIRV/group-ops.mlir b/mlir/test/Target/SPIRV/group-ops.mlir
index 6442e00492e00..27d917b3d49ac 100644
--- a/mlir/test/Target/SPIRV/group-ops.mlir
+++ b/mlir/test/Target/SPIRV/group-ops.mlir
@@ -9,14 +9,14 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
}
// CHECK-LABEL: @group_broadcast_1
spv.func @group_broadcast_1(%value: f32, %localid: i32 ) -> f32 "None" {
- // CHECK: spv.GroupBroadcast Workgroup %{{.*}}, %{{.*}} : f32, i32
- %0 = spv.GroupBroadcast Workgroup %value, %localid : f32, i32
+ // CHECK: spv.GroupBroadcast <Workgroup> %{{.*}}, %{{.*}} : f32, i32
+ %0 = spv.GroupBroadcast <Workgroup> %value, %localid : f32, i32
spv.ReturnValue %0: f32
}
// CHECK-LABEL: @group_broadcast_2
spv.func @group_broadcast_2(%value: f32, %localid: vector<3xi32> ) -> f32 "None" {
- // CHECK: spv.GroupBroadcast Workgroup %{{.*}}, %{{.*}} : f32, vector<3xi32>
- %0 = spv.GroupBroadcast Workgroup %value, %localid : f32, vector<3xi32>
+ // CHECK: spv.GroupBroadcast <Workgroup> %{{.*}}, %{{.*}} : f32, vector<3xi32>
+ %0 = spv.GroupBroadcast <Workgroup> %value, %localid : f32, vector<3xi32>
spv.ReturnValue %0: f32
}
// CHECK-LABEL: @subgroup_block_read_intel
diff --git a/mlir/test/Target/SPIRV/non-uniform-ops.mlir b/mlir/test/Target/SPIRV/non-uniform-ops.mlir
index e0d576f1b52cf..d429b20e378bd 100644
--- a/mlir/test/Target/SPIRV/non-uniform-ops.mlir
+++ b/mlir/test/Target/SPIRV/non-uniform-ops.mlir
@@ -3,23 +3,23 @@
spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
// CHECK-LABEL: @group_non_uniform_ballot
spv.func @group_non_uniform_ballot(%predicate: i1) -> vector<4xi32> "None" {
- // CHECK: %{{.*}} = spv.GroupNonUniformBallot Workgroup %{{.*}}: vector<4xi32>
- %0 = spv.GroupNonUniformBallot Workgroup %predicate : vector<4xi32>
+ // CHECK: %{{.*}} = spv.GroupNonUniformBallot <Workgroup> %{{.*}}: vector<4xi32>
+ %0 = spv.GroupNonUniformBallot <Workgroup> %predicate : vector<4xi32>
spv.ReturnValue %0: vector<4xi32>
}
// CHECK-LABEL: @group_non_uniform_broadcast
spv.func @group_non_uniform_broadcast(%value: f32) -> f32 "None" {
%one = spv.Constant 1 : i32
- // CHECK: spv.GroupNonUniformBroadcast Subgroup %{{.*}}, %{{.*}} : f32, i32
- %0 = spv.GroupNonUniformBroadcast Subgroup %value, %one : f32, i32
+ // CHECK: spv.GroupNonUniformBroadcast <Subgroup> %{{.*}}, %{{.*}} : f32, i32
+ %0 = spv.GroupNonUniformBroadcast <Subgroup> %value, %one : f32, i32
spv.ReturnValue %0: f32
}
// CHECK-LABEL: @group_non_uniform_elect
spv.func @group_non_uniform_elect() -> i1 "None" {
- // CHECK: %{{.+}} = spv.GroupNonUniformElect Workgroup : i1
- %0 = spv.GroupNonUniformElect Workgroup : i1
+ // CHECK: %{{.+}} = spv.GroupNonUniformElect <Workgroup> : i1
+ %0 = spv.GroupNonUniformElect <Workgroup> : i1
spv.ReturnValue %0: i1
}
diff --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
index b1fc3c177bcbf..1db8c96bdc99c 100644
--- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
+++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
@@ -519,10 +519,24 @@ static void emitAttributeSerialization(const Attribute &attr,
<< formatv("if (auto attr = {0}->getAttr(\"{1}\")) {{\n", opVar, attrName);
if (attr.getAttrDefName() == "SPV_ScopeAttr" ||
attr.getAttrDefName() == "SPV_MemorySemanticsAttr") {
+ // These two enums are encoded as <id>s of constants in the SPIR-V blob, but
+ // the SPIR-V dialect uses the constant value directly as the attribute, so
+ // they need to be handled separately from normal enum attributes.
+ EnumAttr baseEnum(attr.getDef().getValueAsDef("enum"));
os << tabs
<< formatv(" {0}.push_back(prepareConstantInt({1}.getLoc(), "
- "attr.cast<IntegerAttr>()));\n",
- operandList, opVar);
+ "Builder({1}).getI32IntegerAttr(static_cast<uint32_t>("
+ "attr.cast<{2}::{3}Attr>().getValue()))));\n",
+ operandList, opVar, baseEnum.getCppNamespace(),
+ baseEnum.getEnumClassName());
+ } else if (attr.isSubClassOf("SPV_BitEnumAttr") ||
+ attr.isSubClassOf("SPV_I32EnumAttr")) {
+ EnumAttr baseEnum(attr.getDef().getValueAsDef("enum"));
+ os << tabs
+ << formatv(" {0}.push_back(static_cast<uint32_t>("
+ "attr.cast<{1}::{2}Attr>().getValue()));\n",
+ operandList, baseEnum.getCppNamespace(),
+ baseEnum.getEnumClassName());
} else if (attr.getAttrDefName() == "I32ArrayAttr") {
// Serialize all the elements of the array
os << tabs << " for (auto attrElem : attr.cast<ArrayAttr>()) {\n";
@@ -531,7 +545,7 @@ static void emitAttributeSerialization(const Attribute &attr,
"attrElem.cast<IntegerAttr>().getValue().getZExtValue()));\n",
operandList);
os << tabs << " }\n";
- } else if (attr.isEnumAttr() || attr.getAttrDefName() == "I32Attr") {
+ } else if (attr.getAttrDefName() == "I32Attr") {
os << tabs
<< formatv(" {0}.push_back(static_cast<uint32_t>("
"attr.cast<IntegerAttr>().getValue().getZExtValue()));\n",
@@ -797,10 +811,25 @@ static void emitAttributeDeserialization(const Attribute &attr,
raw_ostream &os) {
if (attr.getAttrDefName() == "SPV_ScopeAttr" ||
attr.getAttrDefName() == "SPV_MemorySemanticsAttr") {
+ // These two enums are encoded as <id>s of constants in the SPIR-V blob, but
+ // the SPIR-V dialect uses the constant value directly as the attribute, so
+ // they need to be handled separately from normal enum attributes.
+ EnumAttr baseEnum(attr.getDef().getValueAsDef("enum"));
os << tabs
<< formatv("{0}.push_back(opBuilder.getNamedAttr(\"{1}\", "
- "getConstantInt({2}[{3}++])));\n",
- attrList, attrName, words, wordIndex);
+ "opBuilder.getAttr<{2}::{3}Attr>(static_cast<{2}::{3}>("
+ "getConstantInt({4}[{5}++]).getValue().getZExtValue()))));\n",
+ attrList, attrName, baseEnum.getCppNamespace(),
+ baseEnum.getEnumClassName(), words, wordIndex);
+ } else if (attr.isSubClassOf("SPV_BitEnumAttr") ||
+ attr.isSubClassOf("SPV_I32EnumAttr")) {
+ EnumAttr baseEnum(attr.getDef().getValueAsDef("enum"));
+ os << tabs
+ << formatv(" {0}.push_back(opBuilder.getNamedAttr(\"{1}\", "
+ "opBuilder.getAttr<{2}::{3}Attr>("
+ "static_cast<{2}::{3}>({4}[{5}++]))));\n",
+ attrList, attrName, baseEnum.getCppNamespace(),
+ baseEnum.getEnumClassName(), words, wordIndex);
} else if (attr.getAttrDefName() == "I32ArrayAttr") {
os << tabs << "SmallVector<Attribute, 4> attrListElems;\n";
os << tabs << formatv("while ({0} < {1}.size()) {{\n", wordIndex, words);
@@ -815,7 +844,7 @@ static void emitAttributeDeserialization(const Attribute &attr,
<< formatv("{0}.push_back(opBuilder.getNamedAttr(\"{1}\", "
"opBuilder.getArrayAttr(attrListElems)));\n",
attrList, attrName);
- } else if (attr.isEnumAttr() || attr.getAttrDefName() == "I32Attr") {
+ } else if (attr.getAttrDefName() == "I32Attr") {
os << tabs
<< formatv("{0}.push_back(opBuilder.getNamedAttr(\"{1}\", "
"opBuilder.getI32IntegerAttr({2}[{3}++])));\n",
@@ -1257,11 +1286,12 @@ static void emitAvailabilityImpl(const Operator &srcOp, raw_ostream &os) {
for (const Availability &avail : opAvailabilities)
availClasses.try_emplace(avail.getClass(), avail);
for (const NamedAttribute &namedAttr : srcOp.getAttributes()) {
- const auto *enumAttr = llvm::dyn_cast<EnumAttr>(&namedAttr.attr);
- if (!enumAttr)
+ if (!namedAttr.attr.isSubClassOf("SPV_BitEnumAttr") &&
+ !namedAttr.attr.isSubClassOf("SPV_I32EnumAttr"))
continue;
+ EnumAttr enumAttr(namedAttr.attr.getDef().getValueAsDef("enum"));
- for (const EnumAttrCase &enumerant : enumAttr->getAllCases())
+ for (const EnumAttrCase &enumerant : enumAttr.getAllCases())
for (const Availability &caseAvail :
getAvailabilities(enumerant.getDef()))
availClasses.try_emplace(caseAvail.getClass(), caseAvail);
@@ -1298,16 +1328,17 @@ static void emitAvailabilityImpl(const Operator &srcOp, raw_ostream &os) {
// Update with enum attributes' specific availability spec.
for (const NamedAttribute &namedAttr : srcOp.getAttributes()) {
- const auto *enumAttr = llvm::dyn_cast<EnumAttr>(&namedAttr.attr);
- if (!enumAttr)
+ if (!namedAttr.attr.isSubClassOf("SPV_BitEnumAttr") &&
+ !namedAttr.attr.isSubClassOf("SPV_I32EnumAttr"))
continue;
+ EnumAttr enumAttr(namedAttr.attr.getDef().getValueAsDef("enum"));
// (enumerant, availability specification) pairs for this availability
// class.
SmallVector<std::pair<EnumAttrCase, Availability>, 1> caseSpecs;
// Collect all cases' availability specs.
- for (const EnumAttrCase &enumerant : enumAttr->getAllCases())
+ for (const EnumAttrCase &enumerant : enumAttr.getAllCases())
for (const Availability &caseAvail :
getAvailabilities(enumerant.getDef()))
if (availClassName == caseAvail.getClass())
@@ -1318,19 +1349,19 @@ static void emitAvailabilityImpl(const Operator &srcOp, raw_ostream &os) {
if (caseSpecs.empty())
continue;
- if (enumAttr->isBitEnum()) {
+ if (enumAttr.isBitEnum()) {
// For BitEnumAttr, we need to iterate over each bit to query its
// availability spec.
os << formatv(" for (unsigned i = 0; "
"i < std::numeric_limits<{0}>::digits; ++i) {{\n",
- enumAttr->getUnderlyingType());
+ enumAttr.getUnderlyingType());
os << formatv(" {0}::{1} tblgen_attrVal = this->{2}() & "
"static_cast<{0}::{1}>(1 << i);\n",
- enumAttr->getCppNamespace(), enumAttr->getEnumClassName(),
+ enumAttr.getCppNamespace(), enumAttr.getEnumClassName(),
namedAttr.name);
os << formatv(
" if (static_cast<{0}>(tblgen_attrVal) == 0) continue;\n",
- enumAttr->getUnderlyingType());
+ enumAttr.getUnderlyingType());
} else {
// For IntEnumAttr, we just need to query the value as a whole.
os << " {\n";
@@ -1338,7 +1369,7 @@ static void emitAvailabilityImpl(const Operator &srcOp, raw_ostream &os) {
namedAttr.name);
}
os << formatv(" auto tblgen_instance = {0}::{1}(tblgen_attrVal);\n",
- enumAttr->getCppNamespace(), avail.getQueryFnName());
+ enumAttr.getCppNamespace(), avail.getQueryFnName());
os << " if (tblgen_instance) "
// TODO` here once ODS supports
// dialect-specific contents so that we can use not implementing the
@@ -1385,7 +1416,8 @@ static bool emitCapabilityImplication(const RecordKeeper &recordKeeper,
raw_ostream &os) {
llvm::emitSourceFileHeader("SPIR-V Capability Implication", os);
- EnumAttr enumAttr(recordKeeper.getDef("SPV_CapabilityAttr"));
+ EnumAttr enumAttr(
+ recordKeeper.getDef("SPV_CapabilityAttr")->getValueAsDef("enum"));
os << "ArrayRef<spirv::Capability> "
"spirv::getDirectImpliedCapabilities(spirv::Capability cap) {\n"
diff --git a/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp b/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp
index f17bc53c5bd8b..f7a1db0749c57 100644
--- a/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp
+++ b/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp
@@ -14,6 +14,7 @@
#include "mlir/Target/SPIRV/Serialization.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/IR/Builders.h"
@@ -46,11 +47,10 @@ class SerializationTest : public ::testing::Test {
OperationState state(UnknownLoc::get(&context),
spirv::ModuleOp::getOperationName());
state.addAttribute("addressing_model",
- builder.getI32IntegerAttr(static_cast<uint32_t>(
- spirv::AddressingModel::Logical)));
- state.addAttribute("memory_model",
- builder.getI32IntegerAttr(
- static_cast<uint32_t>(spirv::MemoryModel::GLSL450)));
+ builder.getAttr<spirv::AddressingModelAttr>(
+ spirv::AddressingModel::Logical));
+ state.addAttribute("memory_model", builder.getAttr<spirv::MemoryModelAttr>(
+ spirv::MemoryModel::GLSL450));
state.addAttribute("vce_triple",
spirv::VerCapExtAttr::get(
spirv::Version::V_1_0, ArrayRef<spirv::Capability>(),
diff --git a/mlir/utils/spirv/gen_spirv_dialect.py b/mlir/utils/spirv/gen_spirv_dialect.py
index 7b0a61c2379b9..c0b145e17cedb 100755
--- a/mlir/utils/spirv/gen_spirv_dialect.py
+++ b/mlir/utils/spirv/gen_spirv_dialect.py
@@ -437,10 +437,13 @@ def get_case_symbol(kind_name, case_name):
# Generate the enum attribute definition
kind_category = 'Bit' if is_bit_enum else 'I32'
enum_attr = '''def SPV_{name}Attr :
- SPV_{category}EnumAttr<"{name}", "valid SPIR-V {name}", [
+ SPV_{category}EnumAttr<"{name}", "valid SPIR-V {name}", "{snake_name}", [
{cases}
]>;'''.format(
- name=kind_name, category=kind_category, cases=case_names)
+ name=kind_name,
+ snake_name=snake_casify(kind_name),
+ category=kind_category,
+ cases=case_names)
return kind_name, case_defs + '\n\n' + enum_attr
@@ -473,7 +476,8 @@ def gen_opcode(instructions):
]
opcode_list = ',\n'.join(opcode_list)
enum_attr = 'def SPV_OpcodeAttr :\n'\
- ' SPV_I32EnumAttr<"{name}", "valid SPIR-V instructions", [\n'\
+ ' SPV_I32EnumAttr<"{name}", "valid SPIR-V instructions", '\
+ '"opcode", [\n'\
'{lst}\n'\
' ]>;'.format(name='Opcode', lst=opcode_list)
return opcode_str + '\n\n' + enum_attr
@@ -630,9 +634,7 @@ def update_td_enum_attrs(path, operand_kinds, filter_list):
def snake_casify(name):
"""Turns the given name to follow snake_case convention."""
- name = re.sub('\W+', '', name).split()
- name = [s.lower() for s in name]
- return '_'.join(name)
+ return re.sub(r'(?<!^)(?=[A-Z])', '_', name).lower()
def map_spec_operand_to_ods_argument(operand):
More information about the Mlir-commits
mailing list