[Mlir-commits] [mlir] [mlir][XeGPUToXeVM][XeVMToLLVM] Update the default cache-control values. (PR #190954)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Apr 8 03:55:47 PDT 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-llvm
Author: Md Abdullah Shahneous Bari (mshahneo)
<details>
<summary>Changes</summary>
In XeGPU, a user can provide cache-control values for only certain cache levels; values do not have to be provided for every level. However, XeVM requires that each cache level have a cache-control value. To resolve this mismatch, the XeGPUToXeVM conversion pass adds a default value for each cache level that does not have a user-provided cache-control value.
This PR updates this process in the following way:
- If no cache-control hint is provided for any level, use the system default (USE_DEFAULT).
- If at least one level has a cache-control hint, use a fixed default value for each of the remaining levels (for loads: cached; for stores: uncached for L1 and write-back for L3).
If the system default (USE_DEFAULT) is used, no cache-control metadata annotation is necessary.
---
Full diff: https://github.com/llvm/llvm-project/pull/190954.diff
5 Files Affected:
- (modified) mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td (+9-2)
- (modified) mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp (+11-3)
- (modified) mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp (+26)
- (modified) mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir (+34-2)
- (modified) mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir (+29)
``````````diff
diff --git a/mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td
index 898d65e4e7ec2..f39d5e069d26d 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td
@@ -80,6 +80,8 @@ def XeVM_1DBlockElemType : AnyTypeOf<[I8, I16, I32, I64]>;
// ir - invalidated after read
//===----------------------------------------------------------------------===//
+def LoadCacheControl_Use_Default
+ : I32EnumAttrCase<"USE_DEFAULT", 0, "Use_Default">;
def LoadCacheControl_L1uc_L2uc_L3uc
: I32EnumAttrCase<"L1UC_L2UC_L3UC", 1, "L1uc_L2uc_L3uc">;
def LoadCacheControl_L1uc_L2uc_L3c
@@ -110,7 +112,8 @@ def LoadCacheControlInvalidateRead
def XeVM_LoadCacheControl
: I32EnumAttr<
"LoadCacheControl", "XeVM load ops cache control",
- [LoadCacheControl_L1uc_L2uc_L3uc, LoadCacheControl_L1uc_L2uc_L3c,
+ [LoadCacheControl_Use_Default,
+ LoadCacheControl_L1uc_L2uc_L3uc, LoadCacheControl_L1uc_L2uc_L3c,
LoadCacheControl_L1uc_L2c_L3uc, LoadCacheControl_L1uc_L2c_L3c,
LoadCacheControl_L1c_L2uc_L3uc, LoadCacheControl_L1c_L2uc_L3c,
LoadCacheControl_L1c_L2c_L3uc, LoadCacheControl_L1c_L2c_L3c,
@@ -136,6 +139,9 @@ def XeVM_LoadCacheControlAttr
// s - streaming
//===----------------------------------------------------------------------===//
+
+def StoreCacheControl_Use_Default
+ : I32EnumAttrCase<"USE_DEFAULT", 0, "Use_Default">;
def StoreCacheControl_L1uc_L2uc_L3uc
: I32EnumAttrCase<"L1UC_L2UC_L3UC", 1, "L1uc_L2uc_L3uc">;
def StoreCacheControl_L1uc_L2uc_L3wb
@@ -170,7 +176,8 @@ def StoreCacheControl_L1wb_L2uc_L3wb
def XeVM_StoreCacheControl
: I32EnumAttr<
"StoreCacheControl", "XeVM store ops cache control",
- [StoreCacheControl_L1uc_L2uc_L3uc, StoreCacheControl_L1uc_L2uc_L3wb,
+ [StoreCacheControl_Use_Default,
+ StoreCacheControl_L1uc_L2uc_L3uc, StoreCacheControl_L1uc_L2uc_L3wb,
StoreCacheControl_L1uc_L2wb_L3uc, StoreCacheControl_L1uc_L2wb_L3wb,
StoreCacheControl_L1wt_L2uc_L3uc, StoreCacheControl_L1wt_L2uc_L3wb,
StoreCacheControl_L1wt_L2wb_L3uc, StoreCacheControl_L1wt_L2wb_L3wb,
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index daee02990ee94..4816ec693ae77 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -93,8 +93,12 @@ static VectorType encodeVectorTypeTo(VectorType currentVecType,
static xevm::LoadCacheControl
translateLoadXeGPUCacheHint(std::optional<xegpu::CachePolicy> L1hint,
std::optional<xegpu::CachePolicy> L3hint) {
- auto L1hintVal = L1hint.value_or(xegpu::CachePolicy::UNCACHED);
- auto L3hintVal = L3hint.value_or(xegpu::CachePolicy::UNCACHED);
+ // If no hints are provided, use the default cache control.
+ if (!L1hint && !L3hint)
+ return xevm::LoadCacheControl::USE_DEFAULT;
+ // If only one of the hints is provided, use the default for the other level.
+ auto L1hintVal = L1hint.value_or(xegpu::CachePolicy::CACHED);
+ auto L3hintVal = L3hint.value_or(xegpu::CachePolicy::CACHED);
switch (L1hintVal) {
case xegpu::CachePolicy::CACHED:
if (L3hintVal == xegpu::CachePolicy::CACHED)
@@ -127,8 +131,12 @@ translateLoadXeGPUCacheHint(std::optional<xegpu::CachePolicy> L1hint,
static xevm::StoreCacheControl
translateStoreXeGPUCacheHint(std::optional<xegpu::CachePolicy> L1hint,
std::optional<xegpu::CachePolicy> L3hint) {
+ // If no hints are provided, use the default cache control.
+ if (!L1hint && !L3hint)
+ return xevm::StoreCacheControl::USE_DEFAULT;
+ // If only one of the hints is provided, use the default for the other level.
auto L1hintVal = L1hint.value_or(xegpu::CachePolicy::UNCACHED);
- auto L3hintVal = L3hint.value_or(xegpu::CachePolicy::UNCACHED);
+ auto L3hintVal = L3hint.value_or(xegpu::CachePolicy::WRITE_BACK);
switch (L1hintVal) {
case xegpu::CachePolicy::UNCACHED:
if (L3hintVal == xegpu::CachePolicy::UNCACHED)
diff --git a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp
index e6acc0525fdd5..cb86d987edb54 100644
--- a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp
+++ b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp
@@ -102,6 +102,9 @@ std::string mangle(StringRef baseName, ArrayRef<Type> types,
static int32_t getL1CacheControl(LoadCacheControl cc) {
int32_t control = 0;
switch (cc) {
+ case LoadCacheControl::USE_DEFAULT:
+ control = -1;
+ break;
case LoadCacheControl::L1C_L2UC_L3UC:
case LoadCacheControl::L1C_L2UC_L3C:
case LoadCacheControl::L1C_L2C_L3UC:
@@ -126,6 +129,9 @@ static int32_t getL1CacheControl(LoadCacheControl cc) {
static int32_t getL1CacheControl(StoreCacheControl cc) {
int32_t control = 0;
switch (cc) {
+ case StoreCacheControl::USE_DEFAULT:
+ control = -1;
+ break;
case StoreCacheControl::L1WT_L2UC_L3UC:
case StoreCacheControl::L1WT_L2UC_L3WB:
case StoreCacheControl::L1WT_L2WB_L3UC:
@@ -152,6 +158,9 @@ static int32_t getL1CacheControl(StoreCacheControl cc) {
static int32_t getL3CacheControl(LoadCacheControl cc) {
int32_t control = 0;
switch (cc) {
+ case LoadCacheControl::USE_DEFAULT:
+ control = -1;
+ break;
case LoadCacheControl::L1UC_L2UC_L3C:
case LoadCacheControl::L1UC_L2C_L3C:
case LoadCacheControl::L1C_L2UC_L3C:
@@ -172,6 +181,9 @@ static int32_t getL3CacheControl(LoadCacheControl cc) {
static int32_t getL3CacheControl(StoreCacheControl cc) {
int32_t control = 0;
switch (cc) {
+ case StoreCacheControl::USE_DEFAULT:
+ control = -1;
+ break;
case StoreCacheControl::L1UC_L2UC_L3WB:
case StoreCacheControl::L1UC_L2WB_L3WB:
case StoreCacheControl::L1WT_L2UC_L3WB:
@@ -246,6 +258,7 @@ static std::optional<ArrayAttr>
getCacheControlMetadata(ConversionPatternRewriter &rewriter, OpType op) {
if (!getCacheControl(op))
return {};
+
constexpr int32_t decorationCacheControlArity{3};
constexpr int32_t loadCacheControlKey{6442};
constexpr int32_t storeCacheControlKey{6443};
@@ -254,6 +267,19 @@ getCacheControlMetadata(ConversionPatternRewriter &rewriter, OpType op) {
std::is_same_v<OpType, LLVM::LoadOp> ||
std::is_same_v<OpType, BlockLoadOp> ||
std::is_same_v<OpType, PrefetchOp>;
+
+ // If the cache control is USE_DEFAULT, then we don’t emit any metadata.
+ // Assert that if one of the L1 or L3 cache control values is USE_DEFAULT
+ // (represented as -1), then both must be USE_DEFAULT; otherwise there is a
+ // bug.
+ assert(((getL1CacheControl<OpType>(op) == -1) ==
+ (getL3CacheControl<OpType>(op) == -1)) &&
+ "If one of L1 or L3 cache control is USE_DEFAULT, both must be "
+ "USE_DEFAULT");
+
+ if (getL1CacheControl<OpType>(op) == -1 &&
+ getL3CacheControl<OpType>(op) == -1)
+ return {};
const int32_t controlKey{isLoad ? loadCacheControlKey : storeCacheControlKey};
SmallVector<int32_t, decorationCacheControlArity> decorationsL1{
controlKey, 0, getL1CacheControl<OpType>(op)};
diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir
index 4c73c9c238b6e..60d91bc1e7ceb 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir
@@ -12,7 +12,7 @@ gpu.module @load_store_check {
%src_tdesc = xegpu.create_nd_tdesc %srcce : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
//CHECK: %[[LD_LOADED_I32:.*]] = xevm.blockload2d %{{.*}}, %[[W_P_BYTES]], %[[H]], %[[W_P_BYTES]], %[[ZERO]], %[[ZERO]]
- //CHECK-SAME: <{cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>, elem_size_in_bits = 32 : i32,
+ //CHECK-SAME: <{cache_control = #xevm.load_cache_control<L1c_L2uc_L3c>, elem_size_in_bits = 32 : i32,
//CHECK-SAME: pack_register = false, tile_height = 8 : i32, tile_width = 16 : i32, transpose = false,
//CHECK-SAME: v_blocks = 1 : i32}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi32>
%loaded = xegpu.load_nd %src_tdesc[0, 0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
@@ -26,10 +26,42 @@ gpu.module @load_store_check {
%dst_tdesc = xegpu.create_nd_tdesc %dstte : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<memory_space = global>>
//CHECK: xevm.blockstore2d %{{.*}}, %[[W_P_BYTES]], %[[H]], %[[W_P_BYTES]], %[[ZERO]], %[[ZERO]], %{{.*}} <{
- //CHECK-SAME: cache_control = #xevm.store_cache_control<L1wb_L2uc_L3uc>, elem_size_in_bits = 32 : i32,
+ //CHECK-SAME: cache_control = #xevm.store_cache_control<L1wb_L2uc_L3wb>, elem_size_in_bits = 32 : i32,
//CHECK-SAME: tile_height = 8 : i32, tile_width = 16 : i32}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi32>)
xegpu.store_nd %loaded_modified, %dst_tdesc[0, 0] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>
: vector<8xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<memory_space = global>>
gpu.return
}
+
+ // CHECK-LABEL: gpu.func @load_store_with_partial_cache_hints(
+ gpu.func @load_store_with_partial_cache_hints(%src: memref<8x16xf32, 1>, %dst: memref<8x16xf32, 1>) kernel {
+ // CHECK: %[[W_P_BYTES:.*]] = arith.constant 64 : i32
+ // CHECK: %[[ZERO:.*]] = arith.constant 0 : i32
+ // CHECK: %[[H:.*]] = arith.constant 8 : i32
+ %srcce = memref.memory_space_cast %src : memref<8x16xf32, 1> to memref<8x16xf32>
+ %dstte = memref.memory_space_cast %dst : memref<8x16xf32, 1> to memref<8x16xf32>
+
+ %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+
+ //CHECK: %[[LD_LOADED_I32:.*]] = xevm.blockload2d %{{.*}}, %[[W_P_BYTES]], %[[H]], %[[W_P_BYTES]], %[[ZERO]], %[[ZERO]]
+ //CHECK-SAME: <{cache_control = #xevm.load_cache_control<L1c_L2uc_L3c>, elem_size_in_bits = 32 : i32,
+ //CHECK-SAME: pack_register = false, tile_height = 8 : i32, tile_width = 16 : i32, transpose = false,
+ //CHECK-SAME: v_blocks = 1 : i32}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi32>
+ %loaded = xegpu.load_nd %src_tdesc[0, 0] <{l1_hint = #xegpu.cache_hint<cached>}>
+ : !xegpu.tensor_desc<8x16xf32> -> vector<8xf32>
+
+ %tid_x = gpu.thread_id x
+ %tid_x_i32 = arith.index_cast %tid_x : index to i32
+ %tid_x_f32 = arith.sitofp %tid_x_i32 : i32 to f32
+ %loaded_modified = vector.insert %tid_x_f32, %loaded[0] : f32 into vector<8xf32>
+
+ %dst_tdesc = xegpu.create_nd_tdesc %dstte : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<memory_space = global>>
+
+ //CHECK: xevm.blockstore2d %{{.*}}, %[[W_P_BYTES]], %[[H]], %[[W_P_BYTES]], %[[ZERO]], %[[ZERO]], %{{.*}} <{
+ //CHECK-SAME: cache_control = #xevm.store_cache_control<Use_Default>, elem_size_in_bits = 32 : i32,
+ //CHECK-SAME: tile_height = 8 : i32, tile_width = 16 : i32}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi32>)
+ xegpu.store_nd %loaded_modified, %dst_tdesc[0, 0] <{}>
+ : vector<8xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<memory_space = global>>
+ gpu.return
+ }
}
diff --git a/mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir b/mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir
index 6fe7f8411d059..c7981a897b8fe 100644
--- a/mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir
+++ b/mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir
@@ -27,6 +27,35 @@ llvm.func @blockload2d(%a: !llvm.ptr<1>, %base_width_a: i32, %base_height_a: i32
pack_register=false}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi16>
llvm.return %loaded_a : vector<8xi16>
}
+// -----
+// CHECK-LABEL: llvm.func spir_funccc @_Z41intel_sub_group_2d_block_read_16b_8r16x1cPU3AS1viiiDv2_iPt(
+// CHECK-SAME: !llvm.ptr<1> {llvm.nonnull, llvm.readonly}, i32, i32, i32, vector<2xi32>,
+// CHECK-SAME: !llvm.ptr {llvm.nonnull, llvm.writeonly}) attributes {no_unwind, will_return}
+// CHECK: llvm.func @blockload2d_default_cache_control(%[[ARG0:.*]]: !llvm.ptr<1>,
+// CHECK-SAME: %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32, %[[ARG5:.*]]: i32)
+
+llvm.func @blockload2d_default_cache_control(%a: !llvm.ptr<1>, %base_width_a: i32, %base_height_a: i32, %base_pitch_a: i32, %x: i32, %y: i32) -> vector<8xi16> {
+ // CHECK: %[[VAR5:.*]] = llvm.mlir.constant(8 : i32) : i32
+ // CHECK: %[[VAR0:.*]] = llvm.mlir.undef : vector<2xi32>
+ // CHECK: %[[VAR1:.*]] = llvm.mlir.constant(0 : i32) : i32
+ // CHECK: %[[VAR2:.*]] = llvm.mlir.constant(1 : i32) : i32
+ // CHECK: %[[VAR3:.*]] = llvm.insertelement %[[ARG4]], %[[VAR0]][%[[VAR1]] : i32] : vector<2xi32>
+ // CHECK: %[[VAR4:.*]] = llvm.insertelement %[[ARG5]], %[[VAR3]][%[[VAR2]] : i32] : vector<2xi32>
+ // CHECK: %[[VAR6:.*]] = llvm.alloca %[[VAR5]] x i16 : (i32) -> !llvm.ptr
+ // CHECK: llvm.call spir_funccc @_Z41intel_sub_group_2d_block_read_16b_8r16x1cPU3AS1viiiDv2_iPt(
+ // CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[VAR4]], %[[VAR6]])
+ // CHECK-SAME: {function_type = !llvm.func<void (ptr<1>, i32, i32, i32, vector<2xi32>, ptr)>,
+ // CHECK-SAME: linkage = #llvm.linkage<external>, no_unwind, sym_name =
+ // CHECK-SAME: "_Z41intel_sub_group_2d_block_read_16b_8r16x1cPU3AS1viiiDv2_iPt", visibility_ = 0 : i64,
+ // CHECK-SAME: will_return} :
+ // CHECK-SAME: (!llvm.ptr<1> {llvm.nonnull, llvm.readonly}, i32, i32, i32, vector<2xi32>,
+ // CHECK-SAME: !llvm.ptr {llvm.nonnull, llvm.writeonly}) -> ()
+ // CHECK: %[[VAR7:.*]] = llvm.load %[[VAR6]] : !llvm.ptr -> vector<8xi16>
+ %loaded_a = xevm.blockload2d %a, %base_width_a, %base_height_a, %base_pitch_a, %x, %y
+ <{elem_size_in_bits=16 : i32, tile_width=16 : i32, tile_height=8 : i32, v_blocks=1 : i32, transpose=false,
+ pack_register=false, cache_control=#xevm.load_cache_control<Use_Default>}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi16>
+ llvm.return %loaded_a : vector<8xi16>
+}
// -----
// CHECK-LABEL: llvm.func spir_funccc @_Z41intel_sub_group_2d_block_read_16b_8r16x1cPU3AS1viiiDv2_iPt(
``````````
</details>
https://github.com/llvm/llvm-project/pull/190954
More information about the Mlir-commits
mailing list