[Mlir-commits] [mlir] 2357408 - [MLIR][XeGPU] Add layout propagation for `xegpu.store_matrix` (#174952)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jan 23 08:08:35 PST 2026


Author: Artem Kroviakov
Date: 2026-01-23T17:08:30+01:00
New Revision: 2357408dd2c7700d8153f26bb081bb1065287c41

URL: https://github.com/llvm/llvm-project/commit/2357408dd2c7700d8153f26bb081bb1065287c41
DIFF: https://github.com/llvm/llvm-project/commit/2357408dd2c7700d8153f26bb081bb1065287c41.diff

LOG: [MLIR][XeGPU] Add layout propagation for `xegpu.store_matrix` (#174952)

Added: 
    

Modified: 
    mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
    mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
    mlir/test/Dialect/XeGPU/propagate-layout.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index ae59a2e92759c..e8eadb6de5b30 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -289,9 +289,9 @@ static LayoutInfo getDefaultSIMTLayoutInfo(mlir::MLIRContext *ctx,
 
 /// Helper to get the default layout for 2D block operations.
 template <typename Ty>
-static LayoutInfo getSIMTLayoutInforForBlockIO(Ty ty,
-                                               const xegpu::uArch::uArch *uArch,
-                                               unsigned packingSize) {
+static LayoutInfo getSIMTLayoutInfoBlockIO(Ty ty,
+                                           const xegpu::uArch::uArch *uArch,
+                                           unsigned packingSize) {
   // Expecting a 1D or 2D vector.
   assert((ty.getRank() == 1 || ty.getRank() == 2) &&
          "Expected 1D or 2D vector.");
@@ -309,10 +309,8 @@ static LayoutInfo getSIMTLayoutInforForBlockIO(Ty ty,
 }
 
 /// Helper to get the default layout for a vector type.
-static LayoutInfo
-getSIMTLayoutInforForScatterIO(VectorType vectorTy,
-                               const xegpu::uArch::uArch *uArch,
-                               unsigned packingSize) {
+static LayoutInfo getSIMTLayoutInfoScatterIO(VectorType vectorTy,
+                                             const xegpu::uArch::uArch *uArch) {
   // Expecting a 1D or 2D vector.
   assert((vectorTy.getRank() == 1 || vectorTy.getRank() == 2) &&
          "Expected 1D or 2D vector.");
@@ -320,6 +318,7 @@ getSIMTLayoutInforForScatterIO(VectorType vectorTy,
   assert(vectorTy.getElementType().isIntOrFloat() &&
          "Expected int or float element type.");
   // If the rank is 1, then return default layout for 1D vector.
+  const unsigned packingSize{uArch->getGeneralPackedFormatBitSize()};
   if (vectorTy.getRank() == 1)
     return getDefaultSIMTLayoutInfo(vectorTy.getContext(), 1, uArch);
   // Packing factor is determined by the element type bitwidth.
@@ -354,7 +353,7 @@ getSIMTLayoutInfoForDPASOperand(VectorType vectorTy, unsigned operandNum,
         xegpu::LayoutAttr::get(vectorTy.getContext(), layout, data));
   }
   // Otherwise, return the default layout for the vector type.
-  return getSIMTLayoutInforForBlockIO(vectorTy, uArch, packingSize);
+  return getSIMTLayoutInfoBlockIO(vectorTy, uArch, packingSize);
 }
 
 //===----------------------------------------------------------------------===//
@@ -420,6 +419,10 @@ class LayoutInfoPropagation
                         ArrayRef<LayoutInfoLattice *> operands,
                         ArrayRef<const LayoutInfoLattice *> results);
 
+  void visitStoreMatrixOp(xegpu::StoreMatrixOp store,
+                          ArrayRef<LayoutInfoLattice *> operands,
+                          ArrayRef<const LayoutInfoLattice *> results);
+
   bool hasParamsOfLayoutKind(xegpu::DistributeLayoutAttr anchorLayout);
 
 public:
@@ -493,6 +496,9 @@ LogicalResult LayoutInfoPropagation::visitOperation(
       .Case<vector::ShapeCastOp>([&](auto shapeCastOp) {
         visitShapeCastOp(shapeCastOp, operands, results);
       })
+      .Case<xegpu::StoreMatrixOp>([&](auto storeMatrixOp) {
+        visitStoreMatrixOp(storeMatrixOp, operands, results);
+      })
       // All other ops.
       .Default([&](Operation *op) {
         for (const LayoutInfoLattice *resultInfo : results) {
@@ -624,7 +630,7 @@ void LayoutInfoPropagation::visitPrefetchNdOp(
       prefetchLayout =
           LayoutInfo(xegpu::LayoutAttr::get(tdescTy.getContext(), instData));
     else
-      prefetchLayout = getSIMTLayoutInforForBlockIO(
+      prefetchLayout = getSIMTLayoutInfoBlockIO(
           tdescTy, uArch, uArchInstruction->getPackedFormatBitSize());
 
     prefetch.setLayoutAttr(
@@ -976,9 +982,9 @@ void LayoutInfoPropagation::visitStoreNdOp(
       storeLayout =
           LayoutInfo(xegpu::LayoutAttr::get(dataTy.getContext(), instData));
     else if (layoutKind == LayoutKind::Lane)
-      storeLayout = getSIMTLayoutInforForBlockIO(
-          store.getValueType(), uArch,
-          uArchInstruction->getPackedFormatBitSize());
+      storeLayout =
+          getSIMTLayoutInfoBlockIO(store.getValueType(), uArch,
+                                   uArchInstruction->getPackedFormatBitSize());
     else { // LayoutKind::Subgroup
       auto sgSize = uArch->getSubgroupSize();
       auto numSgOrErr = getNumSg(store, sgSize);
@@ -1162,8 +1168,7 @@ void LayoutInfoPropagation::visitLoadGatherOp(
       loadLayout =
           LayoutInfo(xegpu::LayoutAttr::get(load.getContext(), instData));
     else
-      loadLayout = getSIMTLayoutInforForScatterIO(
-          payloadTy, uArch, uArch->getGeneralPackedFormatBitSize());
+      loadLayout = getSIMTLayoutInfoScatterIO(payloadTy, uArch);
 
     // Mask operand should have 1D default layout.
     maskLayout = getDefaultSIMTLayoutInfo(load->getContext(), 1, subgroupSize);
@@ -1239,8 +1244,7 @@ void LayoutInfoPropagation::visitStoreScatterOp(
                "Expected the first dimension of 2D tensor descriptor to be "
                "equal to "
                "subgroup size.");
-      payloadLayout = getSIMTLayoutInforForScatterIO(
-          payloadTy, uArch, uArch->getGeneralPackedFormatBitSize());
+      payloadLayout = getSIMTLayoutInfoScatterIO(payloadTy, uArch);
     }
 
     maskLayout =
@@ -1260,6 +1264,34 @@ void LayoutInfoPropagation::visitStoreScatterOp(
     propagateIfChanged(operands[3], operands[3]->meet(maskLayout));
 }
 
+// Store matrix is a flavor of scattered store for 2D shapes.
+void LayoutInfoPropagation::visitStoreMatrixOp(
+    xegpu::StoreMatrixOp storeMatrix, ArrayRef<LayoutInfoLattice *> operands,
+    ArrayRef<const LayoutInfoLattice *> results) {
+  Value operand = storeMatrix.getData();
+  unsigned index =
+      std::distance(storeMatrix.operand_begin(),
+                    llvm::find(storeMatrix->getOperands(), operand));
+
+  xegpu::DistributeLayoutAttr anchorLayout = storeMatrix.getLayoutAttr();
+  LayoutInfo layout;
+  if (hasParamsOfLayoutKind(anchorLayout)) {
+    layout = LayoutInfo(anchorLayout);
+  } else {
+    VectorType payloadTy = llvm::cast<VectorType>(operand.getType());
+    assert(payloadTy.getRank() == 2 && "Expecting 2D vector for store matrix.");
+    auto uArch = getUArch(getChipStr(storeMatrix).value_or(""));
+    SmallVector<int> instData = {1, uArch->getSubgroupSize()};
+    if (layoutKind == LayoutKind::InstData)
+      layout = LayoutInfo(
+          xegpu::LayoutAttr::get(storeMatrix.getContext(), instData));
+    else
+      layout = getSIMTLayoutInfoScatterIO(payloadTy, uArch);
+  }
+
+  propagateIfChanged(operands[index], operands[index]->meet(layout));
+}
+
 namespace {
 //===----------------------------------------------------------------------===//
 // RunLayoutInfoPropagation

diff  --git a/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir b/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
index 5f70831f45e97..aef9d7fc9e03a 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
@@ -153,3 +153,15 @@ func.func @scatter_ops_chunksize(%src: memref<256xf16>) {
   return
 }
 }
+
+// -----
+gpu.module @test {
+// CHECK-LABEL: func.func @store_matrix(
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: !xegpu.mem_desc<16x64xf16>) {
+// CHECK: %{{.*}} = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [1, 16]>} dense<0.000000e+00> : vector<16x16xf16>
+func.func @store_matrix(%arg0: !xegpu.mem_desc<16x64xf16>) {
+  %cst = arith.constant dense<0.0000> : vector<16x16xf16>
+  xegpu.store_matrix %cst, %arg0[8, 8]: vector<16x16xf16>, !xegpu.mem_desc<16x64xf16>
+  return
+}
+}

diff  --git a/mlir/test/Dialect/XeGPU/propagate-layout.mlir b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
index b88d8e1a78a26..bf6c5d992a47f 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
@@ -32,7 +32,7 @@ func.func @dpas_f16(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: me
 gpu.module @test {
 // CHECK-LABEL: func.func @dpas_i8(
 // CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: vector<8x32xi8>, %[[ARG1:[0-9a-zA-Z]+]]: vector<32x16xi8>, %[[ARG2:[0-9a-zA-Z]+]]: memref<8x16xi32>) {
-// CHECK: %[[T0:.*]] = xegpu.dpas %[[ARG0]], %[[ARG1]] {layout_a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>, layout_b = #xegpu.layout<lane_layout = [1, 16], lane_data = [4, 1]>, layout_cd = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} 
+// CHECK: %[[T0:.*]] = xegpu.dpas %[[ARG0]], %[[ARG1]] {layout_a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>, layout_b = #xegpu.layout<lane_layout = [1, 16], lane_data = [4, 1]>, layout_cd = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
 
 func.func @dpas_i8(%arg0: vector<8x32xi8>, %arg1: vector<32x16xi8>, %arg2: memref<8x16xi32>) {
   %c0 = arith.constant 0 : index
@@ -109,7 +109,7 @@ gpu.module @test {
 // CHECK-NEXT: %[[CST0:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<16xi1>
 // CHECK-NEXT: %[[T2:.*]] = xegpu.create_tdesc %[[ARG1]], %[[CST]] : memref<256xf16>, vector<16xindex> ->
 // CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16 : i64>, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>>
-// CHECK-NEXT: %{{.*}} = xegpu.load %[[T2]], %[[CST0]]  <{layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>}> 
+// CHECK-NEXT: %{{.*}} = xegpu.load %[[T2]], %[[CST0]]  <{layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>}>
 // CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16 : i64>, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>>, vector<16xi1> -> vector<16x16xf16>
 func.func @load_gather_with_chunksize(%arg0: memref<8x16xf16>, %arg1: memref<256xf16>, %arg2: memref<8x16xf32>) {
   %c0 = arith.constant 0 : index
@@ -240,7 +240,7 @@ gpu.module @test {
 // CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<256xf16>) {
 // CHECK: %[[MASK:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [8], lane_data = [1]>} dense<true> : vector<16xi1>
 // CHECK: %[[OFFSETS:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [8], lane_data = [1]>} dense<12> : vector<16xindex>
-// CHECK: %[[LOAD_VEC:.*]] = xegpu.load %[[ARG0]][%[[OFFSETS]]], %[[MASK]] 
+// CHECK: %[[LOAD_VEC:.*]] = xegpu.load %[[ARG0]][%[[OFFSETS]]], %[[MASK]]
 // CHECK-SAME: memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16>
 // CHECK: %[[ADD_RES:.*]] = arith.addf %[[LOAD_VEC]], %[[LOAD_VEC]] {layout_result_0 = #xegpu.layout<lane_layout = [8], lane_data = [1]>} : vector<16xf16>
 // CHECK: xegpu.store %[[ADD_RES]], %[[ARG0]][%[[OFFSETS]]], %[[MASK]]
@@ -697,4 +697,17 @@ func.func @vector_broadcast_scalar_to_vector(%arg0: !xegpu.tensor_desc<16x16xf16
   xegpu.store_nd %6, %arg0  : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16>
   return
 }
-}
\ No newline at end of file
+}
+
+// -----
+gpu.module @test {
+// CHECK-LABEL: func.func @store_matrix(
+// CHECK:         %[[CST:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>} dense<0.000000e+00> : vector<16x16xf16>
+// CHECK-NEXT:     xegpu.store_matrix %[[CST]], %arg0[8, 8] : vector<16x16xf16>, !xegpu.mem_desc<16x64xf16>
+
+func.func @store_matrix(%arg0: !xegpu.mem_desc<16x64xf16>) {
+  %cst = arith.constant dense<0.0000> : vector<16x16xf16>
+  xegpu.store_matrix %cst, %arg0[8, 8]: vector<16x16xf16>, !xegpu.mem_desc<16x64xf16>
+  return
+}
+}


        


More information about the Mlir-commits mailing list