[Mlir-commits] [mlir] 4f78f85 - [MLIR][SPIRV] Add definition and (de)serialization for cache controls (#115461)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Nov 18 00:42:35 PST 2024
Author: Victor Perez
Date: 2024-11-18T09:42:31+01:00
New Revision: 4f78f8519056953d26102c7426fbb028caf13bc9
URL: https://github.com/llvm/llvm-project/commit/4f78f8519056953d26102c7426fbb028caf13bc9
DIFF: https://github.com/llvm/llvm-project/commit/4f78f8519056953d26102c7426fbb028caf13bc9.diff
LOG: [MLIR][SPIRV] Add definition and (de)serialization for cache controls (#115461)
[SPV_INTEL_cache_controls](https://htmlpreview.github.io/?https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/INTEL/SPV_INTEL_cache_controls.html)
defines decorations for load and store cache control. Add support for
this extension in the SPIR-V dialect.
As several `CacheControlLoadINTEL` and `CacheControlStoreINTEL` decorations may be
applied to the same value, they are represented as array attributes.
(De)serialization converts between this array representation and one decoration per array element.
---------
Signed-off-by: Victor Perez <victor.perez at codeplay.com>
Added:
Modified:
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.td
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir
mlir/test/Target/SPIRV/decorations.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.td
index f2a12f68d481b8..1bc3c63646fdd6 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.td
@@ -97,6 +97,20 @@ def SPIRV_CooperativeMatrixPropertiesNVAttr :
let assemblyFormat = "`<` struct(params) `>`";
}
+// Attribute pairing a cache level with a LoadCacheControl value, i.e. the two
+// literal operands of one CacheControlLoadINTEL decoration. Printed as
+// #spirv.cache_control_load_intel<cache_level = N, load_cache_control = X>;
+// the `cache_control_load_intel` named attribute holds an array of these.
+def SPIRV_CacheControlLoadINTELAttr :
+ SPIRV_Attr<"CacheControlLoadINTEL", "cache_control_load_intel"> {
+ let parameters = (ins "unsigned":$cache_level,
+ "mlir::spirv::LoadCacheControl":$load_cache_control);
+ let assemblyFormat = "`<` struct(params) `>`";
+}
+
+// Store-side counterpart: a cache level plus a StoreCacheControl value, i.e.
+// the two literal operands of one CacheControlStoreINTEL decoration. The
+// `cache_control_store_intel` named attribute holds an array of these.
+def SPIRV_CacheControlStoreINTELAttr :
+ SPIRV_Attr<"CacheControlStoreINTEL", "cache_control_store_intel"> {
+ let parameters = (ins "unsigned":$cache_level,
+ "mlir::spirv::StoreCacheControl":$store_cache_control);
+ let assemblyFormat = "`<` struct(params) `>`";
+}
+
def SPIRV_CooperativeMatrixPropertiesNVArrayAttr :
TypedArrayAttrBase<SPIRV_CooperativeMatrixPropertiesNVAttr,
"CooperativeMatrixPropertiesNV array attribute">;
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 3b7da9b44a08fb..252d9319fccc5a 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -400,6 +400,7 @@ def SPV_INTEL_fp_fast_math_mode : I32EnumAttrCase<"SPV_INTEL_fp
def SPV_INTEL_memory_access_aliasing : I32EnumAttrCase<"SPV_INTEL_memory_access_aliasing", 4028>;
def SPV_INTEL_split_barrier : I32EnumAttrCase<"SPV_INTEL_split_barrier", 4029>;
def SPV_INTEL_bfloat16_conversion : I32EnumAttrCase<"SPV_INTEL_bfloat16_conversion", 4031>;
+// SPV_INTEL_cache_controls: the extension that makes the CacheControlsINTEL
+// capability available.
+def SPV_INTEL_cache_controls : I32EnumAttrCase<"SPV_INTEL_cache_controls", 4032>;
def SPV_NV_compute_shader_derivatives : I32EnumAttrCase<"SPV_NV_compute_shader_derivatives", 5000>;
def SPV_NV_cooperative_matrix : I32EnumAttrCase<"SPV_NV_cooperative_matrix", 5001>;
@@ -459,7 +460,8 @@ def SPIRV_ExtensionAttr :
SPV_INTEL_fpga_reg, SPV_INTEL_long_constant_composite, SPV_INTEL_optnone,
SPV_INTEL_debug_module, SPV_INTEL_fp_fast_math_mode,
SPV_INTEL_memory_access_aliasing, SPV_INTEL_split_barrier,
- SPV_INTEL_bfloat16_conversion, SPV_NV_compute_shader_derivatives, SPV_NV_cooperative_matrix,
+ SPV_INTEL_bfloat16_conversion, SPV_INTEL_cache_controls,
+ SPV_NV_compute_shader_derivatives, SPV_NV_cooperative_matrix,
SPV_NV_fragment_shader_barycentric, SPV_NV_geometry_shader_passthrough,
SPV_NV_mesh_shader, SPV_NV_ray_tracing, SPV_NV_sample_mask_override_coverage,
SPV_NV_shader_image_footprint, SPV_NV_shader_sm_builtins,
@@ -1415,6 +1417,12 @@ def SPIRV_C_Bfloat16ConversionINTEL : I32EnumAttrCase<"B
];
}
+// Capability required by the CacheControlLoadINTEL and CacheControlStoreINTEL
+// decorations; only available with the SPV_INTEL_cache_controls extension.
+def SPIRV_C_CacheControlsINTEL : I32EnumAttrCase<"CacheControlsINTEL", 6441> {
+ list<Availability> availability = [
+ Extension<[SPV_INTEL_cache_controls]>
+ ];
+}
+
def SPIRV_CapabilityAttr :
SPIRV_I32EnumAttr<"Capability", "valid SPIR-V Capability", "capability", [
SPIRV_C_Matrix, SPIRV_C_Addresses, SPIRV_C_Linkage, SPIRV_C_Kernel, SPIRV_C_Float16,
@@ -1507,7 +1515,8 @@ def SPIRV_CapabilityAttr :
SPIRV_C_UniformTexelBufferArrayNonUniformIndexing,
SPIRV_C_StorageTexelBufferArrayNonUniformIndexing,
SPIRV_C_ShaderViewportIndexLayerEXT, SPIRV_C_ShaderViewportMaskNV,
- SPIRV_C_ShaderStereoViewNV, SPIRV_C_Bfloat16ConversionINTEL
+ SPIRV_C_ShaderStereoViewNV, SPIRV_C_Bfloat16ConversionINTEL,
+ SPIRV_C_CacheControlsINTEL
]>;
def SPIRV_AM_Logical : I32EnumAttrCase<"Logical", 0>;
@@ -2623,6 +2632,16 @@ def SPIRV_D_MediaBlockIOINTEL : I32EnumAttrCase<"MediaBlockIOIN
Capability<[SPIRV_C_VectorComputeINTEL]>
];
}
+// Decoration with two literal operands: a cache level and a LoadCacheControl
+// value. Several of these may target the same id; in the dialect they are
+// collected into one array attribute.
+def SPIRV_D_CacheControlLoadINTEL : I32EnumAttrCase<"CacheControlLoadINTEL", 6442> {
+ list<Availability> availability = [
+ Capability<[SPIRV_C_CacheControlsINTEL]>
+ ];
+}
+// Decoration with two literal operands: a cache level and a StoreCacheControl
+// value. Like the load variant, several may target the same id.
+def SPIRV_D_CacheControlStoreINTEL : I32EnumAttrCase<"CacheControlStoreINTEL", 6443> {
+ list<Availability> availability = [
+ Capability<[SPIRV_C_CacheControlsINTEL]>
+ ];
+}
def SPIRV_DecorationAttr :
SPIRV_I32EnumAttr<"Decoration", "valid SPIR-V Decoration", "decoration", [
@@ -2658,7 +2677,8 @@ def SPIRV_DecorationAttr :
SPIRV_D_FuseLoopsInFunctionINTEL, SPIRV_D_AliasScopeINTEL, SPIRV_D_NoAliasINTEL,
SPIRV_D_BufferLocationINTEL, SPIRV_D_IOPipeStorageINTEL,
SPIRV_D_FunctionFloatingPointModeINTEL, SPIRV_D_SingleElementVectorINTEL,
- SPIRV_D_VectorComputeCallableFunctionINTEL, SPIRV_D_MediaBlockIOINTEL
+ SPIRV_D_VectorComputeCallableFunctionINTEL, SPIRV_D_MediaBlockIOINTEL,
+ SPIRV_D_CacheControlLoadINTEL, SPIRV_D_CacheControlStoreINTEL
]>;
def SPIRV_D_1D : I32EnumAttrCase<"Dim1D", 0> {
@@ -4092,6 +4112,32 @@ def SPIRV_KHR_CooperativeMatrixOperandsAttr :
SPIRV_KHR_CMO_Result_Signed, SPIRV_KHR_CMO_AccSat
]>;
+// Values of the LoadCacheControl literal carried by CacheControlLoadINTEL.
+// NOTE(review): case 3's keyword is "InvalidateAfterR", which differs from its
+// def name (SPIRV_INTEL_LCC_InvalidateAfterRead). Existing tests use the short
+// spelling, so confirm it is intended before renaming it.
+def SPIRV_INTEL_LCC_Uncached : I32EnumAttrCase<"Uncached", 0>;
+def SPIRV_INTEL_LCC_Cached : I32EnumAttrCase<"Cached", 1>;
+def SPIRV_INTEL_LCC_Streaming : I32EnumAttrCase<"Streaming", 2>;
+def SPIRV_INTEL_LCC_InvalidateAfterRead : I32EnumAttrCase<"InvalidateAfterR", 3>;
+def SPIRV_INTEL_LCC_ConstCached : I32EnumAttrCase<"ConstCached", 4>;
+
+def SPIRV_INTEL_LoadCacheControlAttr :
+ SPIRV_I32EnumAttr<"LoadCacheControl", "valid SPIR-V LoadCacheControl",
+ "load_cache_control", [
+ SPIRV_INTEL_LCC_Uncached, SPIRV_INTEL_LCC_Cached,
+ SPIRV_INTEL_LCC_Streaming, SPIRV_INTEL_LCC_InvalidateAfterRead,
+ SPIRV_INTEL_LCC_ConstCached
+ ]>;
+
+// Values of the StoreCacheControl literal carried by CacheControlStoreINTEL.
+def SPIRV_INTEL_SCC_Uncached : I32EnumAttrCase<"Uncached", 0>;
+def SPIRV_INTEL_SCC_WriteThrough : I32EnumAttrCase<"WriteThrough", 1>;
+def SPIRV_INTEL_SCC_WriteBack : I32EnumAttrCase<"WriteBack", 2>;
+def SPIRV_INTEL_SCC_Streaming : I32EnumAttrCase<"Streaming", 3>;
+
+def SPIRV_INTEL_StoreCacheControlAttr :
+ SPIRV_I32EnumAttr<"StoreCacheControl", "valid SPIR-V StoreCacheControl",
+ "store_cache_control", [
+ SPIRV_INTEL_SCC_Uncached, SPIRV_INTEL_SCC_WriteThrough,
+ SPIRV_INTEL_SCC_WriteBack, SPIRV_INTEL_SCC_Streaming
+ ]>;
+
//===----------------------------------------------------------------------===//
// SPIR-V attribute definitions
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 462d3e326b6c27..04469f1933819b 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -226,6 +226,28 @@ spirv::Deserializer::processMemoryModel(ArrayRef<uint32_t> operands) {
return success();
}
+/// Deserializes one CacheControlLoadINTEL / CacheControlStoreINTEL decoration
+/// and appends the resulting `AttrTy` to the ArrayAttr stored under `symbol`
+/// for the decorated id `words[0]`.
+///
+/// `words` must contain exactly four words; `words[2]` is the cache level
+/// literal and `words[3]` the cache control literal (cast to `EnumTy`).
+/// Several such decorations may target the same id, so each call extends the
+/// existing array (if any) instead of overwriting it.
+///
+/// `decorationName` and `cacheControlKind` ("load"/"store") are used only in
+/// the diagnostic. Returns failure, with an error emitted at `loc`, if the
+/// word count is wrong.
+template <typename AttrTy, typename EnumAttrTy, typename EnumTy>
+static LogicalResult deserializeCacheControlDecoration(
+    Location loc, OpBuilder &opBuilder,
+    DenseMap<uint32_t, NamedAttrList> &decorations, ArrayRef<uint32_t> words,
+    StringAttr symbol, StringRef decorationName, StringRef cacheControlKind) {
+  if (words.size() != 4) {
+    return emitError(loc, "OpDecoration with ")
+           << decorationName << " needs a cache level integer literal and a "
+           << cacheControlKind << " cache control literal";
+  }
+  unsigned cacheLevel = words[2];
+  // NOTE(review): words[3] is cast without checking it is a valid enumerant.
+  auto cacheControl = static_cast<EnumTy>(words[3]);
+  auto value = opBuilder.getAttr<AttrTy>(cacheLevel, cacheControl);
+  // Look the id up once; earlier decorations on it may already have stored an
+  // array under `symbol`, which is extended rather than replaced.
+  NamedAttrList &attrsForId = decorations[words[0]];
+  SmallVector<Attribute> attrs;
+  if (auto attrList =
+          llvm::dyn_cast_or_null<ArrayAttr>(attrsForId.get(symbol)))
+    llvm::append_range(attrs, attrList);
+  attrs.push_back(value);
+  attrsForId.set(symbol, opBuilder.getArrayAttr(attrs));
+  return success();
+}
+
LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
// TODO: This function should also be auto-generated. For now, since only a
// few decorations are processed/handled in a meaningful manner, going with a
@@ -339,6 +361,24 @@ LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
decorations[words[0]].set(
symbol, opBuilder.getI32IntegerAttr(static_cast<int32_t>(words[2])));
break;
+ case spirv::Decoration::CacheControlLoadINTEL: {
+ LogicalResult res = deserializeCacheControlDecoration<
+ CacheControlLoadINTELAttr, LoadCacheControlAttr, LoadCacheControl>(
+ unknownLoc, opBuilder, decorations, words, symbol, decorationName,
+ "load");
+ if (failed(res))
+ return res;
+ break;
+ }
+ case spirv::Decoration::CacheControlStoreINTEL: {
+ LogicalResult res = deserializeCacheControlDecoration<
+ CacheControlStoreINTELAttr, StoreCacheControlAttr, StoreCacheControl>(
+ unknownLoc, opBuilder, decorations, words, symbol, decorationName,
+ "store");
+ if (failed(res))
+ return res;
+ break;
+ }
default:
return emitError(unknownLoc, "unhandled Decoration : '") << decorationName;
}
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index f355982e9ed884..1f4f5d7f764db3 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -217,10 +217,42 @@ static std::string getDecorationName(StringRef attrName) {
// similar here
if (attrName == "fp_rounding_mode")
return "FPRoundingMode";
+ // convertToCamelFromSnakeCase will not capitalize "INTEL".
+ if (attrName == "cache_control_load_intel")
+ return "CacheControlLoadINTEL";
+ if (attrName == "cache_control_store_intel")
+ return "CacheControlStoreINTEL";
return llvm::convertToCamelFromSnakeCase(attrName, /*capitalizeFirst=*/true);
}
+/// Serializes a named attribute that stands for several decorations of the
+/// same kind on one id (e.g. CacheControlLoadINTEL / CacheControlStoreINTEL).
+///
+/// `attrList` must be a non-empty ArrayAttr whose elements are all `AttrTy`;
+/// `emitter` is then invoked once per element, in array order, to emit one
+/// decoration each. `attrName` is used only in diagnostics.
+///
+/// Returns failure, with an error emitted at `loc`, if `attrList` is not an
+/// array, is empty, or contains an element of another attribute kind. Also
+/// returns failure as soon as `emitter` fails.
+template <typename AttrTy, typename EmitF>
+static LogicalResult processDecorationList(Location loc, Decoration decoration,
+                                           Attribute attrList,
+                                           StringRef attrName, EmitF emitter) {
+  auto arrayAttr = dyn_cast<ArrayAttr>(attrList);
+  if (!arrayAttr) {
+    return emitError(loc, "expecting array attribute of ")
+           << attrName << " for " << stringifyDecoration(decoration);
+  }
+  if (arrayAttr.empty()) {
+    return emitError(loc, "expecting non-empty array attribute of ")
+           << attrName << " for " << stringifyDecoration(decoration);
+  }
+  for (Attribute element : arrayAttr.getValue()) {
+    auto typedElement = dyn_cast<AttrTy>(element);
+    // An array holding a foreign element is reported with the same message as
+    // a non-array attribute (existing tests match on this text).
+    if (!typedElement) {
+      return emitError(loc, "expecting array attribute of ")
+             << attrName << " for " << stringifyDecoration(decoration);
+    }
+    // This named attribute encodes several decorations; emit one per element.
+    if (failed(emitter(typedElement)))
+      return failure();
+  }
+  return success();
+}
+
LogicalResult Serializer::processDecorationAttr(Location loc, uint32_t resultID,
Decoration decoration,
Attribute attr) {
@@ -294,6 +326,26 @@ LogicalResult Serializer::processDecorationAttr(Location loc, uint32_t resultID,
return emitError(loc,
"expected unit attribute or decoration attribute for ")
<< stringifyDecoration(decoration);
+ case spirv::Decoration::CacheControlLoadINTEL:
+ return processDecorationList<CacheControlLoadINTELAttr>(
+ loc, decoration, attr, "CacheControlLoadINTEL",
+ [&](CacheControlLoadINTELAttr attr) {
+ unsigned cacheLevel = attr.getCacheLevel();
+ LoadCacheControl loadCacheControl = attr.getLoadCacheControl();
+ return emitDecoration(
+ resultID, decoration,
+ {cacheLevel, static_cast<uint32_t>(loadCacheControl)});
+ });
+ case spirv::Decoration::CacheControlStoreINTEL:
+ return processDecorationList<CacheControlStoreINTELAttr>(
+ loc, decoration, attr, "CacheControlStoreINTEL",
+ [&](CacheControlStoreINTELAttr attr) {
+ unsigned cacheLevel = attr.getCacheLevel();
+ StoreCacheControl storeCacheControl = attr.getStoreCacheControl();
+ return emitDecoration(
+ resultID, decoration,
+ {cacheLevel, static_cast<uint32_t>(storeCacheControl)});
+ });
default:
return emitError(loc, "unhandled decoration ")
<< stringifyDecoration(decoration);
diff --git a/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir b/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir
index 53a1015de75bcc..66c70e816d4134 100644
--- a/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir
@@ -69,3 +69,21 @@ spirv.func @bf16_to_f32_vec_unsupported(%arg0 : vector<2xi16>) "None" {
%0 = spirv.INTEL.ConvertBF16ToF %arg0 : vector<2xi16> to vector<3xf32>
spirv.Return
}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.INTEL.CacheControls
+//===----------------------------------------------------------------------===//
+
+// Arrays of cache control attributes on spirv.Variable are expected to be
+// printed back exactly as written (load and store variants).
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [CacheControlsINTEL], [SPV_INTEL_cache_controls]> {
+ spirv.func @foo() "None" {
+ // CHECK: spirv.Variable {cache_control_load_intel = [#spirv.cache_control_load_intel<cache_level = 0, load_cache_control = Uncached>, #spirv.cache_control_load_intel<cache_level = 1, load_cache_control = Cached>, #spirv.cache_control_load_intel<cache_level = 2, load_cache_control = InvalidateAfterR>]} : !spirv.ptr<f32, Function>
+ %0 = spirv.Variable {cache_control_load_intel = [#spirv.cache_control_load_intel<cache_level = 0, load_cache_control = Uncached>, #spirv.cache_control_load_intel<cache_level = 1, load_cache_control = Cached>, #spirv.cache_control_load_intel<cache_level = 2, load_cache_control = InvalidateAfterR>]} : !spirv.ptr<f32, Function>
+ // CHECK: spirv.Variable {cache_control_store_intel = [#spirv.cache_control_store_intel<cache_level = 0, store_cache_control = Uncached>, #spirv.cache_control_store_intel<cache_level = 1, store_cache_control = WriteThrough>, #spirv.cache_control_store_intel<cache_level = 2, store_cache_control = WriteBack>]} : !spirv.ptr<f32, Function>
+ %1 = spirv.Variable {cache_control_store_intel = [#spirv.cache_control_store_intel<cache_level = 0, store_cache_control = Uncached>, #spirv.cache_control_store_intel<cache_level = 1, store_cache_control = WriteThrough>, #spirv.cache_control_store_intel<cache_level = 2, store_cache_control = WriteBack>]} : !spirv.ptr<f32, Function>
+ spirv.Return
+ }
+}
+
+// -----
diff --git a/mlir/test/Target/SPIRV/decorations.mlir b/mlir/test/Target/SPIRV/decorations.mlir
index 0a29290b6a6fab..d66ac74dc4ef9b 100644
--- a/mlir/test/Target/SPIRV/decorations.mlir
+++ b/mlir/test/Target/SPIRV/decorations.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-translate -no-implicit-module -split-input-file -test-spirv-roundtrip %s | FileCheck %s
+// RUN: mlir-translate -no-implicit-module -split-input-file -test-spirv-roundtrip -verify-diagnostics %s | FileCheck %s
spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
// CHECK: location = 0 : i32
@@ -107,3 +107,47 @@ spirv.func @fp_rounding_mode(%arg: f32) -> f16 "None" {
spirv.ReturnValue %0 : f16
}
}
+
+// -----
+
+// CHECK-LABEL: spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [CacheControlsINTEL], [SPV_INTEL_cache_controls]> {
+
+// Round-trip of spirv.Variable with several cache control entries: the
+// serializer emits one decoration per array element and the deserializer
+// regroups them into a single array attribute.
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [CacheControlsINTEL], [SPV_INTEL_cache_controls]> {
+ spirv.func @cache_controls() "None" {
+ // CHECK: spirv.Variable {cache_control_load_intel = [#spirv.cache_control_load_intel<cache_level = 0, load_cache_control = Uncached>, #spirv.cache_control_load_intel<cache_level = 1, load_cache_control = Cached>, #spirv.cache_control_load_intel<cache_level = 2, load_cache_control = InvalidateAfterR>]} : !spirv.ptr<f32, Function>
+ %0 = spirv.Variable {cache_control_load_intel = [#spirv.cache_control_load_intel<cache_level = 0, load_cache_control = Uncached>, #spirv.cache_control_load_intel<cache_level = 1, load_cache_control = Cached>, #spirv.cache_control_load_intel<cache_level = 2, load_cache_control = InvalidateAfterR>]} : !spirv.ptr<f32, Function>
+ // CHECK: spirv.Variable {cache_control_store_intel = [#spirv.cache_control_store_intel<cache_level = 0, store_cache_control = Uncached>, #spirv.cache_control_store_intel<cache_level = 1, store_cache_control = WriteThrough>, #spirv.cache_control_store_intel<cache_level = 2, store_cache_control = WriteBack>]} : !spirv.ptr<f32, Function>
+ %1 = spirv.Variable {cache_control_store_intel = [#spirv.cache_control_store_intel<cache_level = 0, store_cache_control = Uncached>, #spirv.cache_control_store_intel<cache_level = 1, store_cache_control = WriteThrough>, #spirv.cache_control_store_intel<cache_level = 2, store_cache_control = WriteBack>]} : !spirv.ptr<f32, Function>
+ spirv.Return
+ }
+}
+
+// -----
+
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [CacheControlsINTEL], [SPV_INTEL_cache_controls]> {
+  spirv.func @cache_controls_invalid_type() "None" {
+    // A bare (non-array) cache control attribute is rejected by the serializer.
+    // expected-error@below {{expecting array attribute of CacheControlLoadINTEL for CacheControlLoadINTEL}}
+    %0 = spirv.Variable {cache_control_load_intel = #spirv.cache_control_load_intel<cache_level = 0, load_cache_control = Uncached>} : !spirv.ptr<f32, Function>
+    spirv.Return
+  }
+}
+
+// -----
+
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [CacheControlsINTEL], [SPV_INTEL_cache_controls]> {
+  spirv.func @cache_controls_invalid_element_type() "None" {
+    // An array mixing a cache control attribute with an integer attribute is
+    // rejected by the serializer.
+    // expected-error@below {{expecting array attribute of CacheControlStoreINTEL for CacheControlStoreINTEL}}
+    %0 = spirv.Variable {cache_control_store_intel = [#spirv.cache_control_store_intel<cache_level = 0, store_cache_control = Uncached>, 0 : i32]} : !spirv.ptr<f32, Function>
+    spirv.Return
+  }
+}
+
+// -----
+
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [CacheControlsINTEL], [SPV_INTEL_cache_controls]> {
+  spirv.func @cache_controls_empty_array() "None" {
+    // An empty cache control array is rejected by the serializer.
+    // expected-error@below {{expecting non-empty array attribute of CacheControlStoreINTEL for CacheControlStoreINTEL}}
+    %0 = spirv.Variable {cache_control_store_intel = []} : !spirv.ptr<f32, Function>
+    spirv.Return
+  }
+}
More information about the Mlir-commits
mailing list