[Mlir-commits] [mlir] [MLIR][XeGPU] Add layout propagation for `xegpu.store_matrix` (PR #174952)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Jan 8 03:50:30 PST 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Artem Kroviakov (akroviakov)
<details>
<summary>Changes</summary>
This PR adds layout propagation for the `xegpu.store_matrix` op. The op is not tightly coupled to a specific intrinsic, so the appropriate `inst_data` setting is not entirely clear at this stage. We use `subgroup_block_io` [only at the subgroup and lane levels](https://mlir.llvm.org/docs/Dialects/XeGPU/#xegpustore_matrix-xegpustorematrixop), but `inst_data` is "consumed" before subgroup distribution.
Which uArch information should be used for it? Should it be the same as for the [scatter store](https://github.com/llvm/llvm-project/pull/172845), i.e. `min(datavec_bitsize, uarch_val)`?
Pinging @<!-- -->Jianhui-Li as the author of the current `xegpu.store_matrix` op.
---
Full diff: https://github.com/llvm/llvm-project/pull/174952.diff
3 Files Affected:
- (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp (+29)
- (modified) mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir (+12)
- (modified) mlir/test/Dialect/XeGPU/propagate-layout.mlir (+17-4)
``````````diff
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 3ac23f348f8a9..dc83aaec14a95 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -431,6 +431,10 @@ class LayoutInfoPropagation
ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results);
+ void visitStoreMatrixOp(xegpu::StoreMatrixOp store,
+ ArrayRef<LayoutInfoLattice *> operands,
+ ArrayRef<const LayoutInfoLattice *> results);
+
bool hasParamsOfLayoutKind(xegpu::DistributeLayoutAttr anchorLayout);
public:
@@ -504,6 +508,9 @@ LogicalResult LayoutInfoPropagation::visitOperation(
.Case<vector::ShapeCastOp>([&](auto shapeCastOp) {
visitShapeCastOp(shapeCastOp, operands, results);
})
+ .Case<xegpu::StoreMatrixOp>([&](auto storeMatrixOp) {
+ visitStoreMatrixOp(storeMatrixOp, operands, results);
+ })
// All other ops.
.Default([&](Operation *op) {
for (const LayoutInfoLattice *resultInfo : results) {
@@ -1100,6 +1107,28 @@ void LayoutInfoPropagation::visitStoreScatterOp(
propagateIfChanged(operands[3], operands[3]->meet(maskLayout));
}
+void LayoutInfoPropagation::visitStoreMatrixOp(
+ xegpu::StoreMatrixOp storeMatrix, ArrayRef<LayoutInfoLattice *> operands,
+ ArrayRef<const LayoutInfoLattice *> results) {
+ Value operand = storeMatrix.getData();
+ unsigned index =
+ std::distance(storeMatrix.operand_begin(),
+ llvm::find(storeMatrix->getOperands(), operand));
+
+ auto uArch = getUArch(getChipStr(storeMatrix).value_or(""));
+ const int subgroupSize = uArch->getSubgroupSize();
+ SmallVector<int> instData = {1, 8};
+ LayoutInfo layout;
+ if (layoutKind == LayoutKind::InstData)
+ layout =
+ LayoutInfo(xegpu::LayoutAttr::get(storeMatrix.getContext(), instData));
+ else
+ layout =
+ getDefaultSIMTLayoutInfo(storeMatrix->getContext(), 2, subgroupSize);
+
+ propagateIfChanged(operands[index], operands[index]->meet(layout));
+}
+
namespace {
//===----------------------------------------------------------------------===//
// RunLayoutInfoPropagation
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir b/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
index 5f70831f45e97..ac0b35b43c59d 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
@@ -153,3 +153,15 @@ func.func @scatter_ops_chunksize(%src: memref<256xf16>) {
return
}
}
+
+// -----
+gpu.module @test {
+// CHECK-LABEL: func.func @store_matrix(
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: !xegpu.mem_desc<16x64xf16>) {
+// CHECK: %{{.*}} = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [1, 8]>} dense<0.000000e+00> : vector<16x16xf16>
+func.func @store_matrix(%arg0: !xegpu.mem_desc<16x64xf16>) {
+ %cst = arith.constant dense<0.0000> : vector<16x16xf16>
+ xegpu.store_matrix %cst, %arg0[8, 8]: vector<16x16xf16>, !xegpu.mem_desc<16x64xf16>
+ return
+}
+}
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout.mlir b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
index b88d8e1a78a26..64e36373b7943 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
@@ -32,7 +32,7 @@ func.func @dpas_f16(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: me
gpu.module @test {
// CHECK-LABEL: func.func @dpas_i8(
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: vector<8x32xi8>, %[[ARG1:[0-9a-zA-Z]+]]: vector<32x16xi8>, %[[ARG2:[0-9a-zA-Z]+]]: memref<8x16xi32>) {
-// CHECK: %[[T0:.*]] = xegpu.dpas %[[ARG0]], %[[ARG1]] {layout_a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>, layout_b = #xegpu.layout<lane_layout = [1, 16], lane_data = [4, 1]>, layout_cd = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+// CHECK: %[[T0:.*]] = xegpu.dpas %[[ARG0]], %[[ARG1]] {layout_a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>, layout_b = #xegpu.layout<lane_layout = [1, 16], lane_data = [4, 1]>, layout_cd = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
func.func @dpas_i8(%arg0: vector<8x32xi8>, %arg1: vector<32x16xi8>, %arg2: memref<8x16xi32>) {
%c0 = arith.constant 0 : index
@@ -109,7 +109,7 @@ gpu.module @test {
// CHECK-NEXT: %[[CST0:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<16xi1>
// CHECK-NEXT: %[[T2:.*]] = xegpu.create_tdesc %[[ARG1]], %[[CST]] : memref<256xf16>, vector<16xindex> ->
// CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16 : i64>, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>>
-// CHECK-NEXT: %{{.*}} = xegpu.load %[[T2]], %[[CST0]] <{layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>}>
+// CHECK-NEXT: %{{.*}} = xegpu.load %[[T2]], %[[CST0]] <{layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>}>
// CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16 : i64>, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>>, vector<16xi1> -> vector<16x16xf16>
func.func @load_gather_with_chunksize(%arg0: memref<8x16xf16>, %arg1: memref<256xf16>, %arg2: memref<8x16xf32>) {
%c0 = arith.constant 0 : index
@@ -240,7 +240,7 @@ gpu.module @test {
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<256xf16>) {
// CHECK: %[[MASK:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [8], lane_data = [1]>} dense<true> : vector<16xi1>
// CHECK: %[[OFFSETS:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [8], lane_data = [1]>} dense<12> : vector<16xindex>
-// CHECK: %[[LOAD_VEC:.*]] = xegpu.load %[[ARG0]][%[[OFFSETS]]], %[[MASK]]
+// CHECK: %[[LOAD_VEC:.*]] = xegpu.load %[[ARG0]][%[[OFFSETS]]], %[[MASK]]
// CHECK-SAME: memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16>
// CHECK: %[[ADD_RES:.*]] = arith.addf %[[LOAD_VEC]], %[[LOAD_VEC]] {layout_result_0 = #xegpu.layout<lane_layout = [8], lane_data = [1]>} : vector<16xf16>
// CHECK: xegpu.store %[[ADD_RES]], %[[ARG0]][%[[OFFSETS]]], %[[MASK]]
@@ -697,4 +697,17 @@ func.func @vector_broadcast_scalar_to_vector(%arg0: !xegpu.tensor_desc<16x16xf16
xegpu.store_nd %6, %arg0 : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16>
return
}
-}
\ No newline at end of file
+}
+
+// -----
+gpu.module @test {
+// CHECK-LABEL: func.func @store_matrix(
+// CHECK: %[[CST:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} dense<0.000000e+00> : vector<16x16xf16>
+// CHECK-NEXT: xegpu.store_matrix %[[CST]], %arg0[8, 8] : vector<16x16xf16>, !xegpu.mem_desc<16x64xf16>
+
+func.func @store_matrix(%arg0: !xegpu.mem_desc<16x64xf16>) {
+ %cst = arith.constant dense<0.0000> : vector<16x16xf16>
+ xegpu.store_matrix %cst, %arg0[8, 8]: vector<16x16xf16>, !xegpu.mem_desc<16x64xf16>
+ return
+}
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/174952
More information about the Mlir-commits
mailing list