[Mlir-commits] [mlir] [MLIR][XeGPU] Add layout propagation for `xegpu.store_matrix` (PR #174952)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Jan 8 03:50:31 PST 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-gpu

Author: Artem Kroviakov (akroviakov)

<details>
<summary>Changes</summary>

This PR adds layout propagation for the `xegpu.store_matrix` op. The op is not tightly coupled with a specific intrinsic, so the appropriate `inst_data` setting is not entirely clear at this stage. We use `subgroup_block_io` [only at the subgroup and lane levels](https://mlir.llvm.org/docs/Dialects/XeGPU/#xegpustore_matrix-xegpustorematrixop), but `inst_data` is "consumed" before subgroup distribution.
Which uArch info should be used for it? The same as for the [scatter store](https://github.com/llvm/llvm-project/pull/172845), i.e. `min(datavec_bitsize, uarch_val)`?

Pinging @<!-- -->Jianhui-Li as the author of the current `xegpu.store_matrix` op.

---
Full diff: https://github.com/llvm/llvm-project/pull/174952.diff


3 Files Affected:

- (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp (+29) 
- (modified) mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir (+12) 
- (modified) mlir/test/Dialect/XeGPU/propagate-layout.mlir (+17-4) 


``````````diff
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 3ac23f348f8a9..dc83aaec14a95 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -431,6 +431,10 @@ class LayoutInfoPropagation
                         ArrayRef<LayoutInfoLattice *> operands,
                         ArrayRef<const LayoutInfoLattice *> results);
 
+  void visitStoreMatrixOp(xegpu::StoreMatrixOp store,
+                          ArrayRef<LayoutInfoLattice *> operands,
+                          ArrayRef<const LayoutInfoLattice *> results);
+
   bool hasParamsOfLayoutKind(xegpu::DistributeLayoutAttr anchorLayout);
 
 public:
@@ -504,6 +508,9 @@ LogicalResult LayoutInfoPropagation::visitOperation(
       .Case<vector::ShapeCastOp>([&](auto shapeCastOp) {
         visitShapeCastOp(shapeCastOp, operands, results);
       })
+      .Case<xegpu::StoreMatrixOp>([&](auto storeMatrixOp) {
+        visitStoreMatrixOp(storeMatrixOp, operands, results);
+      })
       // All other ops.
       .Default([&](Operation *op) {
         for (const LayoutInfoLattice *resultInfo : results) {
@@ -1100,6 +1107,28 @@ void LayoutInfoPropagation::visitStoreScatterOp(
     propagateIfChanged(operands[3], operands[3]->meet(maskLayout));
 }
 
+void LayoutInfoPropagation::visitStoreMatrixOp(
+    xegpu::StoreMatrixOp storeMatrix, ArrayRef<LayoutInfoLattice *> operands,
+    ArrayRef<const LayoutInfoLattice *> results) {
+  Value operand = storeMatrix.getData();
+  unsigned index =
+      std::distance(storeMatrix.operand_begin(),
+                    llvm::find(storeMatrix->getOperands(), operand));
+
+  auto uArch = getUArch(getChipStr(storeMatrix).value_or(""));
+  const int subgroupSize = uArch->getSubgroupSize();
+  SmallVector<int> instData = {1, 8};
+  LayoutInfo layout;
+  if (layoutKind == LayoutKind::InstData)
+    layout =
+        LayoutInfo(xegpu::LayoutAttr::get(storeMatrix.getContext(), instData));
+  else
+    layout =
+        getDefaultSIMTLayoutInfo(storeMatrix->getContext(), 2, subgroupSize);
+
+  propagateIfChanged(operands[index], operands[index]->meet(layout));
+}
+
 namespace {
 //===----------------------------------------------------------------------===//
 // RunLayoutInfoPropagation
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir b/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
index 5f70831f45e97..ac0b35b43c59d 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
@@ -153,3 +153,15 @@ func.func @scatter_ops_chunksize(%src: memref<256xf16>) {
   return
 }
 }
+
+// -----
+gpu.module @test {
+// CHECK-LABEL: func.func @store_matrix(
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: !xegpu.mem_desc<16x64xf16>) {
+// CHECK: %{{.*}} = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [1, 8]>} dense<0.000000e+00> : vector<16x16xf16>
+func.func @store_matrix(%arg0: !xegpu.mem_desc<16x64xf16>) {
+  %cst = arith.constant dense<0.0000> : vector<16x16xf16>
+  xegpu.store_matrix %cst, %arg0[8, 8]: vector<16x16xf16>, !xegpu.mem_desc<16x64xf16>
+  return
+}
+}
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout.mlir b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
index b88d8e1a78a26..64e36373b7943 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
@@ -32,7 +32,7 @@ func.func @dpas_f16(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: me
 gpu.module @test {
 // CHECK-LABEL: func.func @dpas_i8(
 // CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: vector<8x32xi8>, %[[ARG1:[0-9a-zA-Z]+]]: vector<32x16xi8>, %[[ARG2:[0-9a-zA-Z]+]]: memref<8x16xi32>) {
-// CHECK: %[[T0:.*]] = xegpu.dpas %[[ARG0]], %[[ARG1]] {layout_a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>, layout_b = #xegpu.layout<lane_layout = [1, 16], lane_data = [4, 1]>, layout_cd = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} 
+// CHECK: %[[T0:.*]] = xegpu.dpas %[[ARG0]], %[[ARG1]] {layout_a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>, layout_b = #xegpu.layout<lane_layout = [1, 16], lane_data = [4, 1]>, layout_cd = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
 
 func.func @dpas_i8(%arg0: vector<8x32xi8>, %arg1: vector<32x16xi8>, %arg2: memref<8x16xi32>) {
   %c0 = arith.constant 0 : index
@@ -109,7 +109,7 @@ gpu.module @test {
 // CHECK-NEXT: %[[CST0:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<16xi1>
 // CHECK-NEXT: %[[T2:.*]] = xegpu.create_tdesc %[[ARG1]], %[[CST]] : memref<256xf16>, vector<16xindex> ->
 // CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16 : i64>, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>>
-// CHECK-NEXT: %{{.*}} = xegpu.load %[[T2]], %[[CST0]]  <{layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>}> 
+// CHECK-NEXT: %{{.*}} = xegpu.load %[[T2]], %[[CST0]]  <{layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>}>
 // CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16 : i64>, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>>, vector<16xi1> -> vector<16x16xf16>
 func.func @load_gather_with_chunksize(%arg0: memref<8x16xf16>, %arg1: memref<256xf16>, %arg2: memref<8x16xf32>) {
   %c0 = arith.constant 0 : index
@@ -240,7 +240,7 @@ gpu.module @test {
 // CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<256xf16>) {
 // CHECK: %[[MASK:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [8], lane_data = [1]>} dense<true> : vector<16xi1>
 // CHECK: %[[OFFSETS:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [8], lane_data = [1]>} dense<12> : vector<16xindex>
-// CHECK: %[[LOAD_VEC:.*]] = xegpu.load %[[ARG0]][%[[OFFSETS]]], %[[MASK]] 
+// CHECK: %[[LOAD_VEC:.*]] = xegpu.load %[[ARG0]][%[[OFFSETS]]], %[[MASK]]
 // CHECK-SAME: memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16>
 // CHECK: %[[ADD_RES:.*]] = arith.addf %[[LOAD_VEC]], %[[LOAD_VEC]] {layout_result_0 = #xegpu.layout<lane_layout = [8], lane_data = [1]>} : vector<16xf16>
 // CHECK: xegpu.store %[[ADD_RES]], %[[ARG0]][%[[OFFSETS]]], %[[MASK]]
@@ -697,4 +697,17 @@ func.func @vector_broadcast_scalar_to_vector(%arg0: !xegpu.tensor_desc<16x16xf16
   xegpu.store_nd %6, %arg0  : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16>
   return
 }
-}
\ No newline at end of file
+}
+
+// -----
+gpu.module @test {
+// CHECK-LABEL: func.func @store_matrix(
+// CHECK:         %[[CST:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} dense<0.000000e+00> : vector<16x16xf16>
+// CHECK-NEXT:     xegpu.store_matrix %[[CST]], %arg0[8, 8] : vector<16x16xf16>, !xegpu.mem_desc<16x64xf16>
+
+func.func @store_matrix(%arg0: !xegpu.mem_desc<16x64xf16>) {
+  %cst = arith.constant dense<0.0000> : vector<16x16xf16>
+  xegpu.store_matrix %cst, %arg0[8, 8]: vector<16x16xf16>, !xegpu.mem_desc<16x64xf16>
+  return
+}
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/174952


More information about the Mlir-commits mailing list