[Mlir-commits] [mlir] [MLIR][XeGPU] Add layout propagation for `xegpu.store_matrix` (PR #174952)
Artem Kroviakov
llvmlistbot at llvm.org
Wed Jan 21 04:37:23 PST 2026
https://github.com/akroviakov updated https://github.com/llvm/llvm-project/pull/174952
>From 1f17d603ef49ebbbe520523e2433af3f9427ad04 Mon Sep 17 00:00:00 2001
From: Artem Kroviakov <artem.kroviakov at intel.com>
Date: Thu, 8 Jan 2026 11:35:19 +0000
Subject: [PATCH 1/2] [MLIR][XeGPU] Add layout propagation for
`xegpu.store_matrix`
---
.../XeGPU/Transforms/XeGPUPropagateLayout.cpp | 29 +++++++++++++++++++
.../XeGPU/propagate-layout-inst-data.mlir | 12 ++++++++
mlir/test/Dialect/XeGPU/propagate-layout.mlir | 21 +++++++++++---
3 files changed, 58 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 1341fc21e7fd4..50bbe1f0b19fa 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -420,6 +420,10 @@ class LayoutInfoPropagation
ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results);
+ void visitStoreMatrixOp(xegpu::StoreMatrixOp store,
+ ArrayRef<LayoutInfoLattice *> operands,
+ ArrayRef<const LayoutInfoLattice *> results);
+
bool hasParamsOfLayoutKind(xegpu::DistributeLayoutAttr anchorLayout);
public:
@@ -493,6 +497,9 @@ LogicalResult LayoutInfoPropagation::visitOperation(
.Case<vector::ShapeCastOp>([&](auto shapeCastOp) {
visitShapeCastOp(shapeCastOp, operands, results);
})
+ .Case<xegpu::StoreMatrixOp>([&](auto storeMatrixOp) {
+ visitStoreMatrixOp(storeMatrixOp, operands, results);
+ })
// All other ops.
.Default([&](Operation *op) {
for (const LayoutInfoLattice *resultInfo : results) {
@@ -1087,6 +1094,28 @@ void LayoutInfoPropagation::visitStoreScatterOp(
propagateIfChanged(operands[3], operands[3]->meet(maskLayout));
}
+void LayoutInfoPropagation::visitStoreMatrixOp(
+ xegpu::StoreMatrixOp storeMatrix, ArrayRef<LayoutInfoLattice *> operands,
+ ArrayRef<const LayoutInfoLattice *> results) {
+ Value operand = storeMatrix.getData();
+ unsigned index =
+ std::distance(storeMatrix.operand_begin(),
+ llvm::find(storeMatrix->getOperands(), operand));
+
+ auto uArch = getUArch(getChipStr(storeMatrix).value_or(""));
+ const int subgroupSize = uArch->getSubgroupSize();
+ SmallVector<int> instData = {1, 8};
+ LayoutInfo layout;
+ if (layoutKind == LayoutKind::InstData)
+ layout =
+ LayoutInfo(xegpu::LayoutAttr::get(storeMatrix.getContext(), instData));
+ else
+ layout =
+ getDefaultSIMTLayoutInfo(storeMatrix->getContext(), 2, subgroupSize);
+
+ propagateIfChanged(operands[index], operands[index]->meet(layout));
+}
+
namespace {
//===----------------------------------------------------------------------===//
// RunLayoutInfoPropagation
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir b/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
index 5f70831f45e97..ac0b35b43c59d 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
@@ -153,3 +153,15 @@ func.func @scatter_ops_chunksize(%src: memref<256xf16>) {
return
}
}
+
+// -----
+gpu.module @test {
+// CHECK-LABEL: func.func @store_matrix(
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: !xegpu.mem_desc<16x64xf16>) {
+// CHECK: %{{.*}} = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [1, 8]>} dense<0.000000e+00> : vector<16x16xf16>
+func.func @store_matrix(%arg0: !xegpu.mem_desc<16x64xf16>) {
+ %cst = arith.constant dense<0.0000> : vector<16x16xf16>
+ xegpu.store_matrix %cst, %arg0[8, 8]: vector<16x16xf16>, !xegpu.mem_desc<16x64xf16>
+ return
+}
+}
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout.mlir b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
index b88d8e1a78a26..64e36373b7943 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
@@ -32,7 +32,7 @@ func.func @dpas_f16(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: me
gpu.module @test {
// CHECK-LABEL: func.func @dpas_i8(
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: vector<8x32xi8>, %[[ARG1:[0-9a-zA-Z]+]]: vector<32x16xi8>, %[[ARG2:[0-9a-zA-Z]+]]: memref<8x16xi32>) {
-// CHECK: %[[T0:.*]] = xegpu.dpas %[[ARG0]], %[[ARG1]] {layout_a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>, layout_b = #xegpu.layout<lane_layout = [1, 16], lane_data = [4, 1]>, layout_cd = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+// CHECK: %[[T0:.*]] = xegpu.dpas %[[ARG0]], %[[ARG1]] {layout_a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>, layout_b = #xegpu.layout<lane_layout = [1, 16], lane_data = [4, 1]>, layout_cd = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
func.func @dpas_i8(%arg0: vector<8x32xi8>, %arg1: vector<32x16xi8>, %arg2: memref<8x16xi32>) {
%c0 = arith.constant 0 : index
@@ -109,7 +109,7 @@ gpu.module @test {
// CHECK-NEXT: %[[CST0:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<16xi1>
// CHECK-NEXT: %[[T2:.*]] = xegpu.create_tdesc %[[ARG1]], %[[CST]] : memref<256xf16>, vector<16xindex> ->
// CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16 : i64>, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>>
-// CHECK-NEXT: %{{.*}} = xegpu.load %[[T2]], %[[CST0]] <{layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>}>
+// CHECK-NEXT: %{{.*}} = xegpu.load %[[T2]], %[[CST0]] <{layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>}>
// CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16 : i64>, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>>, vector<16xi1> -> vector<16x16xf16>
func.func @load_gather_with_chunksize(%arg0: memref<8x16xf16>, %arg1: memref<256xf16>, %arg2: memref<8x16xf32>) {
%c0 = arith.constant 0 : index
@@ -240,7 +240,7 @@ gpu.module @test {
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<256xf16>) {
// CHECK: %[[MASK:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [8], lane_data = [1]>} dense<true> : vector<16xi1>
// CHECK: %[[OFFSETS:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [8], lane_data = [1]>} dense<12> : vector<16xindex>
-// CHECK: %[[LOAD_VEC:.*]] = xegpu.load %[[ARG0]][%[[OFFSETS]]], %[[MASK]]
+// CHECK: %[[LOAD_VEC:.*]] = xegpu.load %[[ARG0]][%[[OFFSETS]]], %[[MASK]]
// CHECK-SAME: memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16>
// CHECK: %[[ADD_RES:.*]] = arith.addf %[[LOAD_VEC]], %[[LOAD_VEC]] {layout_result_0 = #xegpu.layout<lane_layout = [8], lane_data = [1]>} : vector<16xf16>
// CHECK: xegpu.store %[[ADD_RES]], %[[ARG0]][%[[OFFSETS]]], %[[MASK]]
@@ -697,4 +697,17 @@ func.func @vector_broadcast_scalar_to_vector(%arg0: !xegpu.tensor_desc<16x16xf16
xegpu.store_nd %6, %arg0 : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16>
return
}
-}
\ No newline at end of file
+}
+
+// -----
+gpu.module @test {
+// CHECK-LABEL: func.func @store_matrix(
+// CHECK: %[[CST:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} dense<0.000000e+00> : vector<16x16xf16>
+// CHECK-NEXT: xegpu.store_matrix %[[CST]], %arg0[8, 8] : vector<16x16xf16>, !xegpu.mem_desc<16x64xf16>
+
+func.func @store_matrix(%arg0: !xegpu.mem_desc<16x64xf16>) {
+ %cst = arith.constant dense<0.0000> : vector<16x16xf16>
+ xegpu.store_matrix %cst, %arg0[8, 8]: vector<16x16xf16>, !xegpu.mem_desc<16x64xf16>
+ return
+}
+}
>From 88a49d377a71e217030c0e92e0e62e7a27f825a5 Mon Sep 17 00:00:00 2001
From: Artem Kroviakov <artem.kroviakov at intel.com>
Date: Wed, 21 Jan 2026 12:37:08 +0000
Subject: [PATCH 2/2] Reuse scatter ops restrictions for store matrix
---
.../XeGPU/Transforms/XeGPUPropagateLayout.cpp | 56 ++++++++++---------
.../XeGPU/propagate-layout-inst-data.mlir | 2 +-
mlir/test/Dialect/XeGPU/propagate-layout.mlir | 2 +-
3 files changed, 33 insertions(+), 27 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 50bbe1f0b19fa..37e8173902a80 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -289,9 +289,9 @@ static LayoutInfo getDefaultSIMTLayoutInfo(mlir::MLIRContext *ctx,
/// Helper to get the default layout for 2D block operations.
template <typename Ty>
-static LayoutInfo getSIMTLayoutInforForBlockIO(Ty ty,
- const xegpu::uArch::uArch *uArch,
- unsigned packingSize) {
+static LayoutInfo getSIMTLayoutInfoBlockIO(Ty ty,
+ const xegpu::uArch::uArch *uArch,
+ unsigned packingSize) {
// Expecting a 1D or 2D vector.
assert((ty.getRank() == 1 || ty.getRank() == 2) &&
"Expected 1D or 2D vector.");
@@ -309,10 +309,8 @@ static LayoutInfo getSIMTLayoutInforForBlockIO(Ty ty,
}
/// Helper to get the default layout for a vector type.
-static LayoutInfo
-getSIMTLayoutInforForScatterIO(VectorType vectorTy,
- const xegpu::uArch::uArch *uArch,
- unsigned packingSize) {
+static LayoutInfo getSIMTLayoutInfoScatterIO(VectorType vectorTy,
+ const xegpu::uArch::uArch *uArch) {
// Expecting a 1D or 2D vector.
assert((vectorTy.getRank() == 1 || vectorTy.getRank() == 2) &&
"Expected 1D or 2D vector.");
@@ -320,6 +318,7 @@ getSIMTLayoutInforForScatterIO(VectorType vectorTy,
assert(vectorTy.getElementType().isIntOrFloat() &&
"Expected int or float element type.");
// If the rank is 1, then return default layout for 1D vector.
if (vectorTy.getRank() == 1)
return getDefaultSIMTLayoutInfo(vectorTy.getContext(), 1, uArch);
+  const unsigned packingSize{uArch->getGeneralPackedFormatBitSize()};
// Packing factor is determined by the element type bitwidth.
@@ -354,7 +353,7 @@ getSIMTLayoutInfoForDPASOperand(VectorType vectorTy, unsigned operandNum,
xegpu::LayoutAttr::get(vectorTy.getContext(), layout, data));
}
// Otherwise, return the default layout for the vector type.
- return getSIMTLayoutInforForBlockIO(vectorTy, uArch, packingSize);
+ return getSIMTLayoutInfoBlockIO(vectorTy, uArch, packingSize);
}
//===----------------------------------------------------------------------===//
@@ -583,7 +582,7 @@ void LayoutInfoPropagation::visitPrefetchNdOp(
prefetchLayout =
LayoutInfo(xegpu::LayoutAttr::get(tdescTy.getContext(), instData));
else
- prefetchLayout = getSIMTLayoutInforForBlockIO(
+ prefetchLayout = getSIMTLayoutInfoBlockIO(
tdescTy, uArch, uArchInstruction->getPackedFormatBitSize());
prefetch.setLayoutAttr(
@@ -836,9 +835,9 @@ void LayoutInfoPropagation::visitStoreNdOp(
storeLayout =
LayoutInfo(xegpu::LayoutAttr::get(dataTy.getContext(), instData));
else
- storeLayout = getSIMTLayoutInforForBlockIO(
- store.getValueType(), uArch,
- uArchInstruction->getPackedFormatBitSize());
+ storeLayout =
+ getSIMTLayoutInfoBlockIO(store.getValueType(), uArch,
+ uArchInstruction->getPackedFormatBitSize());
store.setLayoutAttr(
dyn_cast<xegpu::DistributeLayoutAttr>(storeLayout.get()));
}
@@ -996,8 +995,7 @@ void LayoutInfoPropagation::visitLoadGatherOp(
loadLayout =
LayoutInfo(xegpu::LayoutAttr::get(load.getContext(), instData));
else
- loadLayout = getSIMTLayoutInforForScatterIO(
- payloadTy, uArch, uArch->getGeneralPackedFormatBitSize());
+ loadLayout = getSIMTLayoutInfoScatterIO(payloadTy, uArch);
// Mask operand should have 1D default layout.
maskLayout = getDefaultSIMTLayoutInfo(load->getContext(), 1, subgroupSize);
@@ -1073,8 +1071,7 @@ void LayoutInfoPropagation::visitStoreScatterOp(
"Expected the first dimension of 2D tensor descriptor to be "
"equal to "
"subgroup size.");
- payloadLayout = getSIMTLayoutInforForScatterIO(
- payloadTy, uArch, uArch->getGeneralPackedFormatBitSize());
+ payloadLayout = getSIMTLayoutInfoScatterIO(payloadTy, uArch);
}
maskLayout =
@@ -1094,6 +1091,7 @@ void LayoutInfoPropagation::visitStoreScatterOp(
propagateIfChanged(operands[3], operands[3]->meet(maskLayout));
}
+// store_matrix is treated as a scattered store with a 2D vector payload.
void LayoutInfoPropagation::visitStoreMatrixOp(
xegpu::StoreMatrixOp storeMatrix, ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results) {
@@ -1102,16 +1100,32 @@ void LayoutInfoPropagation::visitStoreMatrixOp(
std::distance(storeMatrix.operand_begin(),
llvm::find(storeMatrix->getOperands(), operand));
- auto uArch = getUArch(getChipStr(storeMatrix).value_or(""));
- const int subgroupSize = uArch->getSubgroupSize();
- SmallVector<int> instData = {1, 8};
+  xegpu::DistributeLayoutAttr anchorLayout = storeMatrix.getLayoutAttr();
LayoutInfo layout;
- if (layoutKind == LayoutKind::InstData)
- layout =
- LayoutInfo(xegpu::LayoutAttr::get(storeMatrix.getContext(), instData));
- else
- layout =
- getDefaultSIMTLayoutInfo(storeMatrix->getContext(), 2, subgroupSize);
+  if (hasParamsOfLayoutKind(anchorLayout)) {
+    layout = LayoutInfo(anchorLayout);
+  } else {
+    // The default layouts below assume a 2D vector payload. NOTE(review): the
+    // op may also carry a scalar or another-rank vector; leave those
+    // unassigned instead of hitting cast<VectorType>/assert on them.
+    auto payloadTy = dyn_cast<VectorType>(operand.getType());
+    if (!payloadTy || payloadTy.getRank() != 2)
+      return;
+    auto uArch = getUArch(getChipStr(storeMatrix).value_or(""));
+    if (layoutKind == LayoutKind::InstData) {
+      // Cap on the innermost inst_data dimension. NOTE(review): fixed at 16;
+      // consider sourcing it from the uArch. Layout is not set as an anchor.
+      constexpr int64_t kMaxInstLaneVec = 16;
+      const int subgroupSize = uArch->getSubgroupSize();
+      const int instLaneVec = static_cast<int>(
+          std::min(payloadTy.getShape().back(), kMaxInstLaneVec));
+      SmallVector<int> instData = {subgroupSize, instLaneVec};
+      layout = LayoutInfo(
+          xegpu::LayoutAttr::get(storeMatrix.getContext(), instData));
+    } else {
+      layout = getSIMTLayoutInfoScatterIO(payloadTy, uArch);
+    }
+  }
propagateIfChanged(operands[index], operands[index]->meet(layout));
}
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir b/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
index ac0b35b43c59d..0dc138f554249 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
@@ -158,7 +158,7 @@ func.func @scatter_ops_chunksize(%src: memref<256xf16>) {
gpu.module @test {
// CHECK-LABEL: func.func @store_matrix(
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: !xegpu.mem_desc<16x64xf16>) {
-// CHECK: %{{.*}} = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [1, 8]>} dense<0.000000e+00> : vector<16x16xf16>
+// CHECK: %{{.*}} = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [16, 16]>} dense<0.000000e+00> : vector<16x16xf16>
func.func @store_matrix(%arg0: !xegpu.mem_desc<16x64xf16>) {
%cst = arith.constant dense<0.0000> : vector<16x16xf16>
xegpu.store_matrix %cst, %arg0[8, 8]: vector<16x16xf16>, !xegpu.mem_desc<16x64xf16>
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout.mlir b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
index 64e36373b7943..bf6c5d992a47f 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
@@ -702,7 +702,7 @@ func.func @vector_broadcast_scalar_to_vector(%arg0: !xegpu.tensor_desc<16x16xf16
// -----
gpu.module @test {
// CHECK-LABEL: func.func @store_matrix(
-// CHECK: %[[CST:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} dense<0.000000e+00> : vector<16x16xf16>
+// CHECK: %[[CST:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>} dense<0.000000e+00> : vector<16x16xf16>
// CHECK-NEXT: xegpu.store_matrix %[[CST]], %arg0[8, 8] : vector<16x16xf16>, !xegpu.mem_desc<16x64xf16>
func.func @store_matrix(%arg0: !xegpu.mem_desc<16x64xf16>) {
More information about the Mlir-commits
mailing list