[Mlir-commits] [mlir] [MLIR][XeGPU] Add layout propagation for `xegpu.store_matrix` (PR #174952)

Artem Kroviakov llvmlistbot at llvm.org
Wed Jan 21 04:37:23 PST 2026


https://github.com/akroviakov updated https://github.com/llvm/llvm-project/pull/174952

>From 1f17d603ef49ebbbe520523e2433af3f9427ad04 Mon Sep 17 00:00:00 2001
From: Artem Kroviakov <artem.kroviakov at intel.com>
Date: Thu, 8 Jan 2026 11:35:19 +0000
Subject: [PATCH 1/2] [MLIR][XeGPU] Add layout propagation for
 `xegpu.store_matrix`

---
 .../XeGPU/Transforms/XeGPUPropagateLayout.cpp | 29 +++++++++++++++++++
 .../XeGPU/propagate-layout-inst-data.mlir     | 12 ++++++++
 mlir/test/Dialect/XeGPU/propagate-layout.mlir | 21 +++++++++++---
 3 files changed, 58 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 1341fc21e7fd4..50bbe1f0b19fa 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -420,6 +420,10 @@ class LayoutInfoPropagation
                         ArrayRef<LayoutInfoLattice *> operands,
                         ArrayRef<const LayoutInfoLattice *> results);
 
+  void visitStoreMatrixOp(xegpu::StoreMatrixOp store,
+                          ArrayRef<LayoutInfoLattice *> operands,
+                          ArrayRef<const LayoutInfoLattice *> results);
+
   bool hasParamsOfLayoutKind(xegpu::DistributeLayoutAttr anchorLayout);
 
 public:
@@ -493,6 +497,9 @@ LogicalResult LayoutInfoPropagation::visitOperation(
       .Case<vector::ShapeCastOp>([&](auto shapeCastOp) {
         visitShapeCastOp(shapeCastOp, operands, results);
       })
+      .Case<xegpu::StoreMatrixOp>([&](auto storeMatrixOp) {
+        visitStoreMatrixOp(storeMatrixOp, operands, results);
+      })
       // All other ops.
       .Default([&](Operation *op) {
         for (const LayoutInfoLattice *resultInfo : results) {
@@ -1087,6 +1094,28 @@ void LayoutInfoPropagation::visitStoreScatterOp(
     propagateIfChanged(operands[3], operands[3]->meet(maskLayout));
 }
 
+void LayoutInfoPropagation::visitStoreMatrixOp(
+    xegpu::StoreMatrixOp storeMatrix, ArrayRef<LayoutInfoLattice *> operands,
+    ArrayRef<const LayoutInfoLattice *> results) {
+  Value operand = storeMatrix.getData();
+  unsigned index =
+      std::distance(storeMatrix.operand_begin(),
+                    llvm::find(storeMatrix->getOperands(), operand));
+
+  auto uArch = getUArch(getChipStr(storeMatrix).value_or(""));
+  const int subgroupSize = uArch->getSubgroupSize();
+  SmallVector<int> instData = {1, 8};
+  LayoutInfo layout;
+  if (layoutKind == LayoutKind::InstData)
+    layout =
+        LayoutInfo(xegpu::LayoutAttr::get(storeMatrix.getContext(), instData));
+  else
+    layout =
+        getDefaultSIMTLayoutInfo(storeMatrix->getContext(), 2, subgroupSize);
+
+  propagateIfChanged(operands[index], operands[index]->meet(layout));
+}
+
 namespace {
 //===----------------------------------------------------------------------===//
 // RunLayoutInfoPropagation
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir b/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
index 5f70831f45e97..ac0b35b43c59d 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
@@ -153,3 +153,15 @@ func.func @scatter_ops_chunksize(%src: memref<256xf16>) {
   return
 }
 }
+
+// -----
+gpu.module @test {
+// CHECK-LABEL: func.func @store_matrix(
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: !xegpu.mem_desc<16x64xf16>) {
+// CHECK: %{{.*}} = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [1, 8]>} dense<0.000000e+00> : vector<16x16xf16>
+func.func @store_matrix(%arg0: !xegpu.mem_desc<16x64xf16>) {
+  %cst = arith.constant dense<0.0000> : vector<16x16xf16>
+  xegpu.store_matrix %cst, %arg0[8, 8]: vector<16x16xf16>, !xegpu.mem_desc<16x64xf16>
+  return
+}
+}
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout.mlir b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
index b88d8e1a78a26..64e36373b7943 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
@@ -32,7 +32,7 @@ func.func @dpas_f16(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: me
 gpu.module @test {
 // CHECK-LABEL: func.func @dpas_i8(
 // CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: vector<8x32xi8>, %[[ARG1:[0-9a-zA-Z]+]]: vector<32x16xi8>, %[[ARG2:[0-9a-zA-Z]+]]: memref<8x16xi32>) {
-// CHECK: %[[T0:.*]] = xegpu.dpas %[[ARG0]], %[[ARG1]] {layout_a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>, layout_b = #xegpu.layout<lane_layout = [1, 16], lane_data = [4, 1]>, layout_cd = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} 
+// CHECK: %[[T0:.*]] = xegpu.dpas %[[ARG0]], %[[ARG1]] {layout_a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>, layout_b = #xegpu.layout<lane_layout = [1, 16], lane_data = [4, 1]>, layout_cd = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
 
 func.func @dpas_i8(%arg0: vector<8x32xi8>, %arg1: vector<32x16xi8>, %arg2: memref<8x16xi32>) {
   %c0 = arith.constant 0 : index
@@ -109,7 +109,7 @@ gpu.module @test {
 // CHECK-NEXT: %[[CST0:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<16xi1>
 // CHECK-NEXT: %[[T2:.*]] = xegpu.create_tdesc %[[ARG1]], %[[CST]] : memref<256xf16>, vector<16xindex> ->
 // CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16 : i64>, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>>
-// CHECK-NEXT: %{{.*}} = xegpu.load %[[T2]], %[[CST0]]  <{layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>}> 
+// CHECK-NEXT: %{{.*}} = xegpu.load %[[T2]], %[[CST0]]  <{layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>}>
 // CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16 : i64>, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>>, vector<16xi1> -> vector<16x16xf16>
 func.func @load_gather_with_chunksize(%arg0: memref<8x16xf16>, %arg1: memref<256xf16>, %arg2: memref<8x16xf32>) {
   %c0 = arith.constant 0 : index
@@ -240,7 +240,7 @@ gpu.module @test {
 // CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<256xf16>) {
 // CHECK: %[[MASK:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [8], lane_data = [1]>} dense<true> : vector<16xi1>
 // CHECK: %[[OFFSETS:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [8], lane_data = [1]>} dense<12> : vector<16xindex>
-// CHECK: %[[LOAD_VEC:.*]] = xegpu.load %[[ARG0]][%[[OFFSETS]]], %[[MASK]] 
+// CHECK: %[[LOAD_VEC:.*]] = xegpu.load %[[ARG0]][%[[OFFSETS]]], %[[MASK]]
 // CHECK-SAME: memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16>
 // CHECK: %[[ADD_RES:.*]] = arith.addf %[[LOAD_VEC]], %[[LOAD_VEC]] {layout_result_0 = #xegpu.layout<lane_layout = [8], lane_data = [1]>} : vector<16xf16>
 // CHECK: xegpu.store %[[ADD_RES]], %[[ARG0]][%[[OFFSETS]]], %[[MASK]]
@@ -697,4 +697,17 @@ func.func @vector_broadcast_scalar_to_vector(%arg0: !xegpu.tensor_desc<16x16xf16
   xegpu.store_nd %6, %arg0  : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16>
   return
 }
-}
\ No newline at end of file
+}
+
+// -----
+gpu.module @test {
+// CHECK-LABEL: func.func @store_matrix(
+// CHECK:         %[[CST:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} dense<0.000000e+00> : vector<16x16xf16>
+// CHECK-NEXT:     xegpu.store_matrix %[[CST]], %arg0[8, 8] : vector<16x16xf16>, !xegpu.mem_desc<16x64xf16>
+
+func.func @store_matrix(%arg0: !xegpu.mem_desc<16x64xf16>) {
+  %cst = arith.constant dense<0.0000> : vector<16x16xf16>
+  xegpu.store_matrix %cst, %arg0[8, 8]: vector<16x16xf16>, !xegpu.mem_desc<16x64xf16>
+  return
+}
+}

>From 88a49d377a71e217030c0e92e0e62e7a27f825a5 Mon Sep 17 00:00:00 2001
From: Artem Kroviakov <artem.kroviakov at intel.com>
Date: Wed, 21 Jan 2026 12:37:08 +0000
Subject: [PATCH 2/2] Reuse scatter ops restrictions for store matrix

---
 .../XeGPU/Transforms/XeGPUPropagateLayout.cpp | 56 ++++++++++---------
 .../XeGPU/propagate-layout-inst-data.mlir     |  2 +-
 mlir/test/Dialect/XeGPU/propagate-layout.mlir |  2 +-
 3 files changed, 33 insertions(+), 27 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 50bbe1f0b19fa..37e8173902a80 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -289,9 +289,9 @@ static LayoutInfo getDefaultSIMTLayoutInfo(mlir::MLIRContext *ctx,
 
 /// Helper to get the default layout for 2D block operations.
 template <typename Ty>
-static LayoutInfo getSIMTLayoutInforForBlockIO(Ty ty,
-                                               const xegpu::uArch::uArch *uArch,
-                                               unsigned packingSize) {
+static LayoutInfo getSIMTLayoutInfoBlockIO(Ty ty,
+                                           const xegpu::uArch::uArch *uArch,
+                                           unsigned packingSize) {
   // Expecting a 1D or 2D vector.
   assert((ty.getRank() == 1 || ty.getRank() == 2) &&
          "Expected 1D or 2D vector.");
@@ -309,10 +309,8 @@ static LayoutInfo getSIMTLayoutInforForBlockIO(Ty ty,
 }
 
 /// Helper to get the default layout for a vector type.
-static LayoutInfo
-getSIMTLayoutInforForScatterIO(VectorType vectorTy,
-                               const xegpu::uArch::uArch *uArch,
-                               unsigned packingSize) {
+static LayoutInfo getSIMTLayoutInfoScatterIO(VectorType vectorTy,
+                                             const xegpu::uArch::uArch *uArch) {
   // Expecting a 1D or 2D vector.
   assert((vectorTy.getRank() == 1 || vectorTy.getRank() == 2) &&
          "Expected 1D or 2D vector.");
@@ -320,6 +318,7 @@ getSIMTLayoutInforForScatterIO(VectorType vectorTy,
   assert(vectorTy.getElementType().isIntOrFloat() &&
          "Expected int or float element type.");
   // If the rank is 1, then return default layout for 1D vector.
+  const unsigned packingSize{uArch->getGeneralPackedFormatBitSize()};
   if (vectorTy.getRank() == 1)
     return getDefaultSIMTLayoutInfo(vectorTy.getContext(), 1, uArch);
   // Packing factor is determined by the element type bitwidth.
@@ -354,7 +353,7 @@ getSIMTLayoutInfoForDPASOperand(VectorType vectorTy, unsigned operandNum,
         xegpu::LayoutAttr::get(vectorTy.getContext(), layout, data));
   }
   // Otherwise, return the default layout for the vector type.
-  return getSIMTLayoutInforForBlockIO(vectorTy, uArch, packingSize);
+  return getSIMTLayoutInfoBlockIO(vectorTy, uArch, packingSize);
 }
 
 //===----------------------------------------------------------------------===//
@@ -583,7 +582,7 @@ void LayoutInfoPropagation::visitPrefetchNdOp(
       prefetchLayout =
           LayoutInfo(xegpu::LayoutAttr::get(tdescTy.getContext(), instData));
     else
-      prefetchLayout = getSIMTLayoutInforForBlockIO(
+      prefetchLayout = getSIMTLayoutInfoBlockIO(
           tdescTy, uArch, uArchInstruction->getPackedFormatBitSize());
 
     prefetch.setLayoutAttr(
@@ -836,9 +835,9 @@ void LayoutInfoPropagation::visitStoreNdOp(
       storeLayout =
           LayoutInfo(xegpu::LayoutAttr::get(dataTy.getContext(), instData));
     else
-      storeLayout = getSIMTLayoutInforForBlockIO(
-          store.getValueType(), uArch,
-          uArchInstruction->getPackedFormatBitSize());
+      storeLayout =
+          getSIMTLayoutInfoBlockIO(store.getValueType(), uArch,
+                                   uArchInstruction->getPackedFormatBitSize());
     store.setLayoutAttr(
         dyn_cast<xegpu::DistributeLayoutAttr>(storeLayout.get()));
   }
@@ -996,8 +995,7 @@ void LayoutInfoPropagation::visitLoadGatherOp(
       loadLayout =
           LayoutInfo(xegpu::LayoutAttr::get(load.getContext(), instData));
     else
-      loadLayout = getSIMTLayoutInforForScatterIO(
-          payloadTy, uArch, uArch->getGeneralPackedFormatBitSize());
+      loadLayout = getSIMTLayoutInfoScatterIO(payloadTy, uArch);
 
     // Mask operand should have 1D default layout.
     maskLayout = getDefaultSIMTLayoutInfo(load->getContext(), 1, subgroupSize);
@@ -1073,8 +1071,7 @@ void LayoutInfoPropagation::visitStoreScatterOp(
                "Expected the first dimension of 2D tensor descriptor to be "
                "equal to "
                "subgroup size.");
-      payloadLayout = getSIMTLayoutInforForScatterIO(
-          payloadTy, uArch, uArch->getGeneralPackedFormatBitSize());
+      payloadLayout = getSIMTLayoutInfoScatterIO(payloadTy, uArch);
     }
 
     maskLayout =
@@ -1094,6 +1091,7 @@ void LayoutInfoPropagation::visitStoreScatterOp(
     propagateIfChanged(operands[3], operands[3]->meet(maskLayout));
 }
 
+// Store matrix is a flavor of scattered store for 2D shapes.
 void LayoutInfoPropagation::visitStoreMatrixOp(
     xegpu::StoreMatrixOp storeMatrix, ArrayRef<LayoutInfoLattice *> operands,
     ArrayRef<const LayoutInfoLattice *> results) {
@@ -1102,16 +1100,35 @@ void LayoutInfoPropagation::visitStoreMatrixOp(
       std::distance(storeMatrix.operand_begin(),
                     llvm::find(storeMatrix->getOperands(), operand));

-  auto uArch = getUArch(getChipStr(storeMatrix).value_or(""));
-  const int subgroupSize = uArch->getSubgroupSize();
-  SmallVector<int> instData = {1, 8};
+  xegpu::DistributeLayoutAttr anchorLayout = storeMatrix.getLayoutAttr();
   LayoutInfo layout;
-  if (layoutKind == LayoutKind::InstData)
-    layout =
-        LayoutInfo(xegpu::LayoutAttr::get(storeMatrix.getContext(), instData));
-  else
-    layout =
-        getDefaultSIMTLayoutInfo(storeMatrix->getContext(), 2, subgroupSize);
+  if (hasParamsOfLayoutKind(anchorLayout)) {
+    // An anchor layout on the op with params of the requested kind wins
+    // over the defaults computed below.
+    layout = LayoutInfo(anchorLayout);
+  } else {
+    // Only vector payloads get a default layout; skip anything else.
+    auto payloadTy = llvm::dyn_cast<VectorType>(operand.getType());
+    if (!payloadTy)
+      return;
+    assert(payloadTy.getRank() == 2 && "Expecting 2D vector for store matrix.");
+    auto uArch = getUArch(getChipStr(storeMatrix).value_or(""));
+    if (layoutKind == LayoutKind::InstData) {
+      // inst_data: dim 0 spans the subgroup size, dim 1 at most
+      // kMaxInstLaneVec elements (clamped to the payload's last dim).
+      // NOTE(review): 16 is the previously hard-coded limit; consider
+      // sourcing it from the uArch description.
+      constexpr int kMaxInstLaneVec = 16;
+      const int subgroupSize = uArch->getSubgroupSize();
+      const int instLaneVec = std::min(
+          static_cast<int>(payloadTy.getShape().back()), kMaxInstLaneVec);
+      SmallVector<int> instData = {subgroupSize, instLaneVec};
+      layout = LayoutInfo(
+          xegpu::LayoutAttr::get(storeMatrix.getContext(), instData));
+    } else {
+      layout = getSIMTLayoutInfoScatterIO(payloadTy, uArch);
+    }
+  }

   propagateIfChanged(operands[index], operands[index]->meet(layout));
 }
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir b/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
index ac0b35b43c59d..0dc138f554249 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
@@ -158,7 +158,7 @@ func.func @scatter_ops_chunksize(%src: memref<256xf16>) {
 gpu.module @test {
 // CHECK-LABEL: func.func @store_matrix(
 // CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: !xegpu.mem_desc<16x64xf16>) {
-// CHECK: %{{.*}} = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [1, 8]>} dense<0.000000e+00> : vector<16x16xf16>
+// CHECK: %{{.*}} = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [16, 16]>} dense<0.000000e+00> : vector<16x16xf16>
 func.func @store_matrix(%arg0: !xegpu.mem_desc<16x64xf16>) {
   %cst = arith.constant dense<0.0000> : vector<16x16xf16>
   xegpu.store_matrix %cst, %arg0[8, 8]: vector<16x16xf16>, !xegpu.mem_desc<16x64xf16>
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout.mlir b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
index 64e36373b7943..bf6c5d992a47f 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
@@ -702,7 +702,7 @@ func.func @vector_broadcast_scalar_to_vector(%arg0: !xegpu.tensor_desc<16x16xf16
 // -----
 gpu.module @test {
 // CHECK-LABEL: func.func @store_matrix(
-// CHECK:         %[[CST:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} dense<0.000000e+00> : vector<16x16xf16>
+// CHECK:         %[[CST:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>} dense<0.000000e+00> : vector<16x16xf16>
 // CHECK-NEXT:     xegpu.store_matrix %[[CST]], %arg0[8, 8] : vector<16x16xf16>, !xegpu.mem_desc<16x64xf16>
 
 func.func @store_matrix(%arg0: !xegpu.mem_desc<16x64xf16>) {



More information about the Mlir-commits mailing list