[Mlir-commits] [mlir] [mlir][XeGPUToXeVM][XeVMToLLVM] Update the default cache-control values. (PR #190954)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Apr 8 03:55:47 PDT 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-llvm

Author: Md Abdullah Shahneous Bari (mshahneo)

<details>
<summary>Changes</summary>

In XeGPU, a user can provide cache-control values for only some cache levels; a value does not have to be provided for every level. XeVM, however, requires every cache level to have a cache-control value. To bridge this gap, the XeGPUToXeVM conversion pass assigns a default value to each cache level that does not have a user-provided cache-control value.

This PR updates this process in the following way:
- If no cache-control hint is provided for any level, use the system default (USE_DEFAULT).
- If at least one level has a cache-control hint, use a fixed default value for each remaining level (CACHED for loads; UNCACHED for L1 and WRITE_BACK for L3 on stores).

When the system default (USE_DEFAULT) is used, no cache-control metadata annotation is emitted.

---
Full diff: https://github.com/llvm/llvm-project/pull/190954.diff


5 Files Affected:

- (modified) mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td (+9-2) 
- (modified) mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp (+11-3) 
- (modified) mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp (+26) 
- (modified) mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir (+34-2) 
- (modified) mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir (+29) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td
index 898d65e4e7ec2..f39d5e069d26d 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td
@@ -80,6 +80,8 @@ def XeVM_1DBlockElemType : AnyTypeOf<[I8, I16, I32, I64]>;
 //   ir - invalidated after read
 //===----------------------------------------------------------------------===//
 
+def LoadCacheControl_Use_Default
+    : I32EnumAttrCase<"USE_DEFAULT", 0, "Use_Default">;
 def LoadCacheControl_L1uc_L2uc_L3uc
     : I32EnumAttrCase<"L1UC_L2UC_L3UC", 1, "L1uc_L2uc_L3uc">;
 def LoadCacheControl_L1uc_L2uc_L3c
@@ -110,7 +112,8 @@ def LoadCacheControlInvalidateRead
 def XeVM_LoadCacheControl
     : I32EnumAttr<
           "LoadCacheControl", "XeVM load ops cache control",
-          [LoadCacheControl_L1uc_L2uc_L3uc, LoadCacheControl_L1uc_L2uc_L3c,
+          [LoadCacheControl_Use_Default,
+           LoadCacheControl_L1uc_L2uc_L3uc, LoadCacheControl_L1uc_L2uc_L3c,
            LoadCacheControl_L1uc_L2c_L3uc, LoadCacheControl_L1uc_L2c_L3c,
            LoadCacheControl_L1c_L2uc_L3uc, LoadCacheControl_L1c_L2uc_L3c,
            LoadCacheControl_L1c_L2c_L3uc, LoadCacheControl_L1c_L2c_L3c,
@@ -136,6 +139,9 @@ def XeVM_LoadCacheControlAttr
 //   s - streaming
 //===----------------------------------------------------------------------===//
 
+
+def StoreCacheControl_Use_Default
+    : I32EnumAttrCase<"USE_DEFAULT", 0, "Use_Default">;
 def StoreCacheControl_L1uc_L2uc_L3uc
     : I32EnumAttrCase<"L1UC_L2UC_L3UC", 1, "L1uc_L2uc_L3uc">;
 def StoreCacheControl_L1uc_L2uc_L3wb
@@ -170,7 +176,8 @@ def StoreCacheControl_L1wb_L2uc_L3wb
 def XeVM_StoreCacheControl
     : I32EnumAttr<
           "StoreCacheControl", "XeVM store ops cache control",
-          [StoreCacheControl_L1uc_L2uc_L3uc, StoreCacheControl_L1uc_L2uc_L3wb,
+          [StoreCacheControl_Use_Default,
+           StoreCacheControl_L1uc_L2uc_L3uc, StoreCacheControl_L1uc_L2uc_L3wb,
            StoreCacheControl_L1uc_L2wb_L3uc, StoreCacheControl_L1uc_L2wb_L3wb,
            StoreCacheControl_L1wt_L2uc_L3uc, StoreCacheControl_L1wt_L2uc_L3wb,
            StoreCacheControl_L1wt_L2wb_L3uc, StoreCacheControl_L1wt_L2wb_L3wb,
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index daee02990ee94..4816ec693ae77 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -93,8 +93,12 @@ static VectorType encodeVectorTypeTo(VectorType currentVecType,
 static xevm::LoadCacheControl
 translateLoadXeGPUCacheHint(std::optional<xegpu::CachePolicy> L1hint,
                             std::optional<xegpu::CachePolicy> L3hint) {
-  auto L1hintVal = L1hint.value_or(xegpu::CachePolicy::UNCACHED);
-  auto L3hintVal = L3hint.value_or(xegpu::CachePolicy::UNCACHED);
+  // If no hints are provided, use the default cache control.
+  if (!L1hint && !L3hint)
+    return xevm::LoadCacheControl::USE_DEFAULT;
+  // If only one of the hints is provided, use the default for the other level.
+  auto L1hintVal = L1hint.value_or(xegpu::CachePolicy::CACHED);
+  auto L3hintVal = L3hint.value_or(xegpu::CachePolicy::CACHED);
   switch (L1hintVal) {
   case xegpu::CachePolicy::CACHED:
     if (L3hintVal == xegpu::CachePolicy::CACHED)
@@ -127,8 +131,12 @@ translateLoadXeGPUCacheHint(std::optional<xegpu::CachePolicy> L1hint,
 static xevm::StoreCacheControl
 translateStoreXeGPUCacheHint(std::optional<xegpu::CachePolicy> L1hint,
                              std::optional<xegpu::CachePolicy> L3hint) {
+  // If no hints are provided, use the default cache control.
+  if (!L1hint && !L3hint)
+    return xevm::StoreCacheControl::USE_DEFAULT;
+  // If only one of the hints is provided, use the default for the other level.
   auto L1hintVal = L1hint.value_or(xegpu::CachePolicy::UNCACHED);
-  auto L3hintVal = L3hint.value_or(xegpu::CachePolicy::UNCACHED);
+  auto L3hintVal = L3hint.value_or(xegpu::CachePolicy::WRITE_BACK);
   switch (L1hintVal) {
   case xegpu::CachePolicy::UNCACHED:
     if (L3hintVal == xegpu::CachePolicy::UNCACHED)
diff --git a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp
index e6acc0525fdd5..cb86d987edb54 100644
--- a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp
+++ b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp
@@ -102,6 +102,9 @@ std::string mangle(StringRef baseName, ArrayRef<Type> types,
 static int32_t getL1CacheControl(LoadCacheControl cc) {
   int32_t control = 0;
   switch (cc) {
+  case LoadCacheControl::USE_DEFAULT:
+    control = -1;
+    break;
   case LoadCacheControl::L1C_L2UC_L3UC:
   case LoadCacheControl::L1C_L2UC_L3C:
   case LoadCacheControl::L1C_L2C_L3UC:
@@ -126,6 +129,9 @@ static int32_t getL1CacheControl(LoadCacheControl cc) {
 static int32_t getL1CacheControl(StoreCacheControl cc) {
   int32_t control = 0;
   switch (cc) {
+  case StoreCacheControl::USE_DEFAULT:
+    control = -1;
+    break;
   case StoreCacheControl::L1WT_L2UC_L3UC:
   case StoreCacheControl::L1WT_L2UC_L3WB:
   case StoreCacheControl::L1WT_L2WB_L3UC:
@@ -152,6 +158,9 @@ static int32_t getL1CacheControl(StoreCacheControl cc) {
 static int32_t getL3CacheControl(LoadCacheControl cc) {
   int32_t control = 0;
   switch (cc) {
+  case LoadCacheControl::USE_DEFAULT:
+    control = -1;
+    break;
   case LoadCacheControl::L1UC_L2UC_L3C:
   case LoadCacheControl::L1UC_L2C_L3C:
   case LoadCacheControl::L1C_L2UC_L3C:
@@ -172,6 +181,9 @@ static int32_t getL3CacheControl(LoadCacheControl cc) {
 static int32_t getL3CacheControl(StoreCacheControl cc) {
   int32_t control = 0;
   switch (cc) {
+  case StoreCacheControl::USE_DEFAULT:
+    control = -1;
+    break;
   case StoreCacheControl::L1UC_L2UC_L3WB:
   case StoreCacheControl::L1UC_L2WB_L3WB:
   case StoreCacheControl::L1WT_L2UC_L3WB:
@@ -246,6 +258,7 @@ static std::optional<ArrayAttr>
 getCacheControlMetadata(ConversionPatternRewriter &rewriter, OpType op) {
   if (!getCacheControl(op))
     return {};
+
   constexpr int32_t decorationCacheControlArity{3};
   constexpr int32_t loadCacheControlKey{6442};
   constexpr int32_t storeCacheControlKey{6443};
@@ -254,6 +267,19 @@ getCacheControlMetadata(ConversionPatternRewriter &rewriter, OpType op) {
                           std::is_same_v<OpType, LLVM::LoadOp> ||
                           std::is_same_v<OpType, BlockLoadOp> ||
                           std::is_same_v<OpType, PrefetchOp>;
+
+  // If the cache control is USE_DEFAULT, then we don’t emit any metadata.
+  // Assert that if one of the L1 or L3 cache control values is USE_DEFAULT
+  // (represented as -1), then both must be USE_DEFAULT; otherwise there is a
+  // bug.
+  assert(((getL1CacheControl<OpType>(op) == -1) ==
+          (getL3CacheControl<OpType>(op) == -1)) &&
+         "If one of L1 or L3 cache control is USE_DEFAULT, both must be "
+         "USE_DEFAULT");
+
+  if (getL1CacheControl<OpType>(op) == -1 &&
+      getL3CacheControl<OpType>(op) == -1)
+    return {};
   const int32_t controlKey{isLoad ? loadCacheControlKey : storeCacheControlKey};
   SmallVector<int32_t, decorationCacheControlArity> decorationsL1{
       controlKey, 0, getL1CacheControl<OpType>(op)};
diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir
index 4c73c9c238b6e..60d91bc1e7ceb 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir
@@ -12,7 +12,7 @@ gpu.module @load_store_check {
         %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
 
         //CHECK: %[[LD_LOADED_I32:.*]] = xevm.blockload2d %{{.*}}, %[[W_P_BYTES]], %[[H]], %[[W_P_BYTES]], %[[ZERO]], %[[ZERO]]
-        //CHECK-SAME: <{cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>, elem_size_in_bits = 32 : i32,
+        //CHECK-SAME: <{cache_control = #xevm.load_cache_control<L1c_L2uc_L3c>, elem_size_in_bits = 32 : i32,
         //CHECK-SAME:   pack_register = false, tile_height = 8 : i32, tile_width = 16 : i32, transpose = false,
         //CHECK-SAME:   v_blocks = 1 : i32}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi32>
         %loaded = xegpu.load_nd %src_tdesc[0, 0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
@@ -26,10 +26,42 @@ gpu.module @load_store_check {
         %dst_tdesc = xegpu.create_nd_tdesc %dstte : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<memory_space = global>>
 
         //CHECK: xevm.blockstore2d %{{.*}}, %[[W_P_BYTES]], %[[H]], %[[W_P_BYTES]], %[[ZERO]], %[[ZERO]], %{{.*}} <{
-        //CHECK-SAME:   cache_control = #xevm.store_cache_control<L1wb_L2uc_L3uc>, elem_size_in_bits = 32 : i32,
+        //CHECK-SAME:   cache_control = #xevm.store_cache_control<L1wb_L2uc_L3wb>, elem_size_in_bits = 32 : i32,
         //CHECK-SAME:   tile_height = 8 : i32, tile_width = 16 : i32}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi32>)
         xegpu.store_nd %loaded_modified, %dst_tdesc[0, 0] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>
             : vector<8xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<memory_space = global>>
         gpu.return
     }
+
+    // CHECK-LABEL: gpu.func @load_store_with_partial_cache_hints(
+    gpu.func @load_store_with_partial_cache_hints(%src: memref<8x16xf32, 1>, %dst: memref<8x16xf32, 1>) kernel {
+        // CHECK: %[[W_P_BYTES:.*]] = arith.constant 64 : i32
+        // CHECK: %[[ZERO:.*]] = arith.constant 0 : i32
+        // CHECK: %[[H:.*]] = arith.constant 8 : i32
+        %srcce = memref.memory_space_cast %src : memref<8x16xf32, 1> to memref<8x16xf32>
+        %dstte = memref.memory_space_cast %dst : memref<8x16xf32, 1> to memref<8x16xf32>
+
+        %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+
+        //CHECK: %[[LD_LOADED_I32:.*]] = xevm.blockload2d %{{.*}}, %[[W_P_BYTES]], %[[H]], %[[W_P_BYTES]], %[[ZERO]], %[[ZERO]]
+        //CHECK-SAME: <{cache_control = #xevm.load_cache_control<L1c_L2uc_L3c>, elem_size_in_bits = 32 : i32,
+        //CHECK-SAME:   pack_register = false, tile_height = 8 : i32, tile_width = 16 : i32, transpose = false,
+        //CHECK-SAME:   v_blocks = 1 : i32}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi32>
+        %loaded = xegpu.load_nd %src_tdesc[0, 0] <{l1_hint = #xegpu.cache_hint<cached>}>
+            : !xegpu.tensor_desc<8x16xf32> -> vector<8xf32>
+
+        %tid_x = gpu.thread_id x
+        %tid_x_i32 = arith.index_cast %tid_x : index to i32
+        %tid_x_f32 = arith.sitofp %tid_x_i32 : i32 to f32
+        %loaded_modified = vector.insert %tid_x_f32, %loaded[0] : f32 into vector<8xf32>
+
+        %dst_tdesc = xegpu.create_nd_tdesc %dstte : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<memory_space = global>>
+
+        //CHECK: xevm.blockstore2d %{{.*}}, %[[W_P_BYTES]], %[[H]], %[[W_P_BYTES]], %[[ZERO]], %[[ZERO]], %{{.*}} <{
+        //CHECK-SAME:   cache_control = #xevm.store_cache_control<Use_Default>, elem_size_in_bits = 32 : i32,
+        //CHECK-SAME:   tile_height = 8 : i32, tile_width = 16 : i32}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi32>)
+        xegpu.store_nd %loaded_modified, %dst_tdesc[0, 0] <{}>
+            : vector<8xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<memory_space = global>>
+        gpu.return
+    }
 }
diff --git a/mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir b/mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir
index 6fe7f8411d059..c7981a897b8fe 100644
--- a/mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir
+++ b/mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir
@@ -27,6 +27,35 @@ llvm.func @blockload2d(%a: !llvm.ptr<1>, %base_width_a: i32, %base_height_a: i32
       pack_register=false}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi16>
   llvm.return %loaded_a : vector<8xi16>
 }
+// -----
+// CHECK-LABEL:      llvm.func spir_funccc @_Z41intel_sub_group_2d_block_read_16b_8r16x1cPU3AS1viiiDv2_iPt(
+// CHECK-SAME:   !llvm.ptr<1> {llvm.nonnull, llvm.readonly}, i32, i32, i32, vector<2xi32>,
+// CHECK-SAME:   !llvm.ptr {llvm.nonnull, llvm.writeonly}) attributes {no_unwind, will_return}
+// CHECK:      llvm.func @blockload2d_default_cache_control(%[[ARG0:.*]]: !llvm.ptr<1>,
+// CHECK-SAME:   %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32, %[[ARG5:.*]]: i32)
+
+llvm.func @blockload2d_default_cache_control(%a: !llvm.ptr<1>, %base_width_a: i32, %base_height_a: i32, %base_pitch_a: i32, %x: i32, %y: i32) -> vector<8xi16> {
+  // CHECK: %[[VAR5:.*]] = llvm.mlir.constant(8 : i32) : i32
+  // CHECK: %[[VAR0:.*]] = llvm.mlir.undef : vector<2xi32>
+  // CHECK: %[[VAR1:.*]] = llvm.mlir.constant(0 : i32) : i32
+  // CHECK: %[[VAR2:.*]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: %[[VAR3:.*]] = llvm.insertelement %[[ARG4]], %[[VAR0]][%[[VAR1]] : i32] : vector<2xi32>
+  // CHECK: %[[VAR4:.*]] = llvm.insertelement %[[ARG5]], %[[VAR3]][%[[VAR2]] : i32] : vector<2xi32>
+  // CHECK: %[[VAR6:.*]] = llvm.alloca %[[VAR5]] x i16 : (i32) -> !llvm.ptr
+  // CHECK: llvm.call spir_funccc @_Z41intel_sub_group_2d_block_read_16b_8r16x1cPU3AS1viiiDv2_iPt(
+  // CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[VAR4]], %[[VAR6]])
+  // CHECK-SAME: {function_type = !llvm.func<void (ptr<1>, i32, i32, i32, vector<2xi32>, ptr)>,
+  // CHECK-SAME:   linkage = #llvm.linkage<external>, no_unwind, sym_name =
+  // CHECK-SAME:   "_Z41intel_sub_group_2d_block_read_16b_8r16x1cPU3AS1viiiDv2_iPt", visibility_ = 0 : i64,
+  // CHECK-SAME:   will_return} :
+  // CHECK-SAME: (!llvm.ptr<1> {llvm.nonnull, llvm.readonly}, i32, i32, i32, vector<2xi32>,
+  // CHECK-SAME:  !llvm.ptr {llvm.nonnull, llvm.writeonly}) -> ()
+  // CHECK: %[[VAR7:.*]] = llvm.load %[[VAR6]] : !llvm.ptr -> vector<8xi16>
+  %loaded_a = xevm.blockload2d %a, %base_width_a, %base_height_a, %base_pitch_a, %x, %y
+    <{elem_size_in_bits=16 : i32, tile_width=16 : i32, tile_height=8 : i32, v_blocks=1 : i32, transpose=false,
+      pack_register=false, cache_control=#xevm.load_cache_control<Use_Default>}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi16>
+  llvm.return %loaded_a : vector<8xi16>
+}
 
 // -----
 // CHECK-LABEL: llvm.func spir_funccc @_Z41intel_sub_group_2d_block_read_16b_8r16x1cPU3AS1viiiDv2_iPt(

``````````

</details>


https://github.com/llvm/llvm-project/pull/190954


More information about the Mlir-commits mailing list