[Mlir-commits] [mlir] [MLIR][XeGPU] Add simple rank-based sg layout creation (PR #172867)

Artem Kroviakov llvmlistbot at llvm.org
Fri Jan 23 02:14:57 PST 2026


https://github.com/akroviakov updated https://github.com/llvm/llvm-project/pull/172867

>From 592f103e379ef4991a12098f1ba8d6baf9e853c7 Mon Sep 17 00:00:00 2001
From: Artem Kroviakov <artem.kroviakov at intel.com>
Date: Thu, 18 Dec 2025 16:25:37 +0000
Subject: [PATCH 1/5] [MLIR][XeGPU] Add simple rank-based sg layout creation

---
 .../XeGPU/Transforms/XeGPUPropagateLayout.cpp | 164 ++++++++++++++++--
 .../XeGPU/propagate-layout-subgroup.mlir      |  77 ++++++++
 2 files changed, 225 insertions(+), 16 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 1341fc21e7fd4..1215f567d2cd9 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -531,6 +531,55 @@ bool LayoutInfoPropagation::hasParamsOfLayoutKind(
   return false;
 }
 
+FailureOr<std::pair<SmallVector<int>, SmallVector<int>>>
+chooseLayout(llvm::ArrayRef<int64_t> wgShape, int64_t sgCount) {
+  const size_t rank = wgShape.size();
+  // Returns {sg_layout, sg_data}, with sg_data[i] = wgShape[i] / sg_layout[i].
+  // Step 1. Factorize sgCount into prime factors (trial division).
+  SmallVector<int> layout;
+  int64_t temp = sgCount;
+  for (int64_t i = 2; i * i <= temp; ++i) {
+    while (temp % i == 0) {
+      layout.push_back(i);
+      temp /= i;
+    }
+  }
+  if (temp > 1) // The cofactor left after trial division is itself prime.
+    layout.push_back(temp);
+  // NOTE(review): sgCount 1 or a prime fails here for rank 2; pad with 1s?
+  if (layout.size() < rank)
+    return failure();
+  // NOTE(review): rank 0 with sgCount > 1 reads layout[1] out of bounds below.
+  // Step 2. Fuse two smallest factors until we have `rank` factors.
+  while (layout.size() > rank) {
+    std::sort(layout.begin(), layout.end());
+    int64_t a = layout[0];
+    int64_t b = layout[1];
+    layout.erase(layout.begin());
+    layout[0] = a * b;
+  }
+  // Factors map to dims in their current order; each dim must divide evenly.
+  SmallVector<int> data;
+  for (auto [i, dim] : llvm::enumerate(layout)) {
+    if (wgShape[i] % dim != 0)
+      return failure();
+    data.push_back(wgShape[i] / dim);
+  }
+  return std::make_pair(layout, data);
+}
+
+FailureOr<int64_t> getNumSg(Operation *op, const int sgSize) {
+  // #subgroups = product(enclosing gpu.func's known_block_size) / sgSize.
+  auto gpuFunc = op->getParentOfType<gpu::GPUFuncOp>();
+  if (!gpuFunc) // Not nested in a gpu.func.
+    return failure();
+  auto knownBlockSize = gpuFunc.getKnownBlockSize();
+  if (!knownBlockSize.has_value()) // The gpu.func has no known_block_size.
+    return failure();
+  const int flatBlockSize = llvm::product_of(knownBlockSize.value()); // Per-dim split is ignored; only the total work-item count matters.
+  return flatBlockSize / sgSize; // NOTE(review): truncating division; sgSize == 0 or a block size that is not a multiple of sgSize is not checked.
+}
+
 void LayoutInfoPropagation::visitPrefetchNdOp(
     xegpu::PrefetchNdOp prefetch, ArrayRef<LayoutInfoLattice *> operands,
     ArrayRef<const LayoutInfoLattice *> results) {
@@ -746,30 +795,89 @@ void LayoutInfoPropagation::visitDpasOp(
           LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataA));
       dpasBLayout =
           LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataB));
-    } else {
+    } else if (layoutKind == LayoutKind::Lane) {
       dpasALayout = getSIMTLayoutInfoForDPASOperand(
           aTy, 0, uArch, uArchInstruction->getPackedFormatBitSizeA());
       dpasBLayout = getSIMTLayoutInfoForDPASOperand(
           bTy, 1, uArch, uArchInstruction->getPackedFormatBitSizeB());
+    } else {
+      auto numSgOrErr = getNumSg(dpas, subgroupSize);
+      if (failed(numSgOrErr)) {
+        dpas.emitWarning(
+            "Unable to determine the number of subgroups for the operation.");
+        return;
+      }
+      auto layoutDataAOrErr = chooseLayout(aTy.getShape(), numSgOrErr.value());
+      if (failed(layoutDataAOrErr)) {
+        dpas.emitWarning(
+            "Unable to determine suitable subgroup layout and data for A.");
+        return;
+      }
+      auto [sgLayoutA, sgDataA] = layoutDataAOrErr.value();
+
+      dpasALayout = LayoutInfo(xegpu::LayoutAttr::get(
+          aTy.getContext(), DenseI32ArrayAttr::get(aTy.getContext(), sgLayoutA),
+          DenseI32ArrayAttr::get(aTy.getContext(), sgDataA),
+          /*inst_data =*/nullptr, /*lane_layout =*/nullptr,
+          /*lane_data =*/nullptr, /*order =*/nullptr));
+
+      auto layoutDataBOrErr = chooseLayout(bTy.getShape(), numSgOrErr.value());
+      if (failed(layoutDataBOrErr)) {
+        dpas.emitWarning(
+            "Unable to determine suitable subgroup layout and data for B.");
+        return;
+      }
+      auto [sgLayoutB, sgDataB] = layoutDataBOrErr.value();
+
+      dpasBLayout = LayoutInfo(xegpu::LayoutAttr::get(
+          bTy.getContext(), DenseI32ArrayAttr::get(bTy.getContext(), sgLayoutB),
+          DenseI32ArrayAttr::get(bTy.getContext(), sgDataB),
+          /*inst_data =*/nullptr, /*lane_layout =*/nullptr,
+          /*lane_data =*/nullptr, /*order =*/nullptr));
     }
 
     if (operands.size() > 2) {
       VectorType cTy = dpas.getAccType();
+      const unsigned dataCLen = bTy.getShape().back();
+      auto supportedCLen =
+          uArchInstruction->getSupportedN(bTy.getElementType());
+      const int maxCLen =
+          xegpu::getLargestDivisor(dataCLen, ArrayRef<unsigned>(supportedCLen));
+      if (maxCLen == -1) {
+        dpas.emitWarning(
+            "No suitable instruction multiple found for the given shape.");
+        return;
+      }
+      SmallVector<int> instDataCD = {maxALen, maxCLen};
       if (layoutKind == LayoutKind::InstData) {
-        const unsigned dataCLen = bTy.getShape().back();
-        auto supportedCLen =
-            uArchInstruction->getSupportedN(bTy.getElementType());
-        const int maxCLen = xegpu::getLargestDivisor(
-            dataCLen, ArrayRef<unsigned>(supportedCLen));
-        if (maxCLen == -1)
-          dpas.emitWarning(
-              "No suitable instruction multiple found for the given shape.");
-        SmallVector<int> instDataC = {maxALen, maxCLen};
         dpasCDLayout =
-            LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataC));
-      } else
+            LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataCD));
+      } else if (layoutKind == LayoutKind::Lane) {
         dpasCDLayout = getSIMTLayoutInfoForDPASOperand(
             cTy, 2, uArch, uArchInstruction->getPackedFormatBitSizeB());
+      } else {
+        auto numSgOrErr = getNumSg(dpas, subgroupSize);
+        if (failed(numSgOrErr)) {
+          dpas.emitWarning(
+              "Unable to determine the number of subgroups for the operation.");
+          return;
+        }
+        auto layoutDataAOrErr =
+            chooseLayout(cTy.getShape(), numSgOrErr.value());
+        if (failed(layoutDataAOrErr)) {
+          dpas.emitWarning(
+              "Unable to determine suitable subgroup layout and data for A.");
+          return;
+        }
+        auto [sgLayoutCD, sgDataCD] = layoutDataAOrErr.value();
+
+        dpasCDLayout = LayoutInfo(xegpu::LayoutAttr::get(
+            cTy.getContext(),
+            DenseI32ArrayAttr::get(cTy.getContext(), sgLayoutCD),
+            DenseI32ArrayAttr::get(cTy.getContext(), sgDataCD),
+            /*inst_data =*/nullptr, /*lane_layout =*/nullptr,
+            /*lane_data =*/nullptr, /*order =*/nullptr));
+      }
 
       dpas.setLayoutCdAttr(
           dyn_cast<xegpu::DistributeLayoutAttr>(dpasCDLayout.get()));
@@ -828,10 +936,34 @@ void LayoutInfoPropagation::visitStoreNdOp(
     if (layoutKind == LayoutKind::InstData)
       storeLayout =
           LayoutInfo(xegpu::LayoutAttr::get(dataTy.getContext(), instData));
-    else
-      storeLayout = getSIMTLayoutInforForBlockIO(
-          store.getValueType(), uArch,
-          uArchInstruction->getPackedFormatBitSize());
+    else if (layoutKind == LayoutKind::Lane)
+      storeLayout =
+          getSIMTLayoutInforForBlockIO(store.getValueType(), uArch,
+                                   uArchInstruction->getPackedFormatBitSize());
+    else { // LayoutKind::Subgroup
+      auto sgSize = uArch->getSubgroupSize();
+      auto numSgOrErr = getNumSg(store, sgSize);
+      if (failed(numSgOrErr)) {
+        store.emitWarning(
+            "Unable to determine the number of subgroups for the operation.");
+        return;
+      }
+      auto layoutDataOrErr =
+          chooseLayout(dataTy.getShape(), numSgOrErr.value());
+      if (failed(layoutDataOrErr)) {
+        store.emitWarning(
+            "Unable to determine suitable subgroup layout and data.");
+        return;
+      }
+      auto [sgLayout, sgData] = layoutDataOrErr.value();
+
+      storeLayout = LayoutInfo(xegpu::LayoutAttr::get(
+          dataTy.getContext(),
+          DenseI32ArrayAttr::get(dataTy.getContext(), sgLayout),
+          DenseI32ArrayAttr::get(dataTy.getContext(), sgData),
+          /*inst_data =*/nullptr, /*lane_layout =*/nullptr,
+          /*lane_data =*/nullptr, /*order =*/nullptr));
+    }
     store.setLayoutAttr(
         dyn_cast<xegpu::DistributeLayoutAttr>(storeLayout.get()));
   }
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir b/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir
index 092a4cf442782..b6dfbe79e2712 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir
@@ -49,3 +49,80 @@ gpu.module @test {
     return
   }
 }
+
+// -----
+gpu.module @test {
+  // CHECK-LABEL: vector_transpose
+  // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
+  // CHECK-SAME: %[[ARG_1:.*]]: memref<128x256xf32>
+  gpu.func @vector_transpose(%src: memref<256x128xf32>, %src1: memref<128x256xf32>) kernel attributes
+      {known_block_size = array<i32: 1, 32, 16>} {
+    // CHECK: %[[TDESC_LD:.*]] = xegpu.create_nd_tdesc %[[ARG_0]] : memref<256x128xf32> ->
+    // CHECK-SAME: !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [4, 8], sg_data = [64, 16]>>
+    // CHECK: %[[TDESC_ST:.*]] = xegpu.create_nd_tdesc %[[ARG_1]] : memref<128x256xf32> ->
+    // CHECK-SAME: !xegpu.tensor_desc<128x256xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 64]>>
+    // Expected: the stored value gets sg_layout [8, 4]; the load feeding the transpose gets the permuted [4, 8].
+    // CHECK: %[[LOAD:.*]] = xegpu.load_nd %[[TDESC_LD]][0, 0] <{layout = #xegpu.layout<sg_layout = [4, 8], sg_data = [64, 16]>}>
+    // CHECK-SAME: {layout_result_0 = #xegpu.layout<sg_layout = [4, 8], sg_data = [64, 16]>} :
+    // CHECK-SAME: !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [4, 8], sg_data = [64, 16]>> -> vector<256x128xf32>
+
+    // CHECK: %[[TRANSPOSED:.*]] = vector.transpose %[[LOAD]], [1, 0]
+    // CHECK-SAME: {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 64]>} : vector<256x128xf32> to vector<128x256xf32>
+
+    // CHECK: xegpu.store_nd %[[TRANSPOSED]], %[[TDESC_ST]][0, 0]
+    // CHECK-SAME: <{layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 64]>}> : vector<128x256xf32>,
+    // CHECK-SAME: !xegpu.tensor_desc<128x256xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 64]>>
+    %tdesc = xegpu.create_nd_tdesc %src : memref<256x128xf32> -> !xegpu.tensor_desc<256x128xf32>
+    %tdesc1 = xegpu.create_nd_tdesc %src1 : memref<128x256xf32> -> !xegpu.tensor_desc<128x256xf32>
+    %load = xegpu.load_nd %tdesc[0, 0] : !xegpu.tensor_desc<256x128xf32> -> vector<256x128xf32>
+    %trans = vector.transpose %load, [1, 0] : vector<256x128xf32> to vector<128x256xf32>
+    xegpu.store_nd %trans, %tdesc1[0, 0] : vector<128x256xf32>, !xegpu.tensor_desc<128x256xf32>
+    gpu.return
+  }
+}
+
+// -----
+gpu.module @test {
+  // CHECK-LABEL: dpas
+  // CHECK-SAME: %[[A_MEMREF:.*]]: memref<128x128xf16>, %[[B_MEMREF:.*]]: memref<128x128xf16>
+  // CHECK-SAME: %[[CD_MEMREF:.*]]: memref<128x128xf32>
+  gpu.func @dpas(%a: memref<128x128xf16>, %b: memref<128x128xf16>, %d: memref<128x128xf32>) kernel attributes
+      {known_block_size = array<i32: 1, 64, 16>} {
+  // CHECK: %[[TDESC_A:.*]] = xegpu.create_nd_tdesc %[[A_MEMREF]] : memref<128x128xf16> ->
+  // CHECK-SAME: !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [16, 4], sg_data = [8, 32]>>
+
+  // CHECK: %[[A_LOADED:.*]] = xegpu.load_nd %[[TDESC_A]]
+  // CHECK-SAME: <{layout = #xegpu.layout<sg_layout = [16, 4], sg_data = [8, 32]>}>
+  // CHECK-SAME: {layout_result_0 = #xegpu.layout<sg_layout = [16, 4], sg_data = [8, 32]>} :
+  // CHECK-SAME: !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [16, 4], sg_data = [8, 32]>> -> vector<128x128xf16>
+
+  // CHECK: %[[TDESC_B:.*]] = xegpu.create_nd_tdesc %[[B_MEMREF]] : memref<128x128xf16> ->
+  // CHECK-SAME: !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [16, 4], sg_data = [8, 32]>>
+
+  // CHECK: %[[B_LOADED:.*]] = xegpu.load_nd %[[TDESC_B]] <{layout = #xegpu.layout<sg_layout = [16, 4], sg_data = [8, 32]>}>
+  // CHECK-SAME: {layout_result_0 = #xegpu.layout<sg_layout = [16, 4], sg_data = [8, 32]>} :
+  // CHECK-SAME: !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [16, 4], sg_data = [8, 32]>> -> vector<128x128xf16>
+
+  // CHECK: %[[DPAS_RES:.*]] = xegpu.dpas %[[A_LOADED]], %[[B_LOADED]]
+  // CHECK-SAME: {layout_a = #xegpu.layout<sg_layout = [16, 4], sg_data = [8, 32]>,
+  // CHECK-SAME: layout_b = #xegpu.layout<sg_layout = [16, 4], sg_data = [8, 32]>,
+  // CHECK-SAME: layout_result_0 = #xegpu.layout<sg_layout = [16, 4], sg_data = [8, 32]>} :
+  // CHECK-SAME: vector<128x128xf16>, vector<128x128xf16> -> vector<128x128xf32>
+
+  // CHECK: %[[TDESC_ST:.*]] = xegpu.create_nd_tdesc %[[CD_MEMREF]] : memref<128x128xf32> ->
+  // CHECK-SAME: !xegpu.tensor_desc<128x128xf32, #xegpu.layout<sg_layout = [16, 4], sg_data = [8, 32]>>
+
+  // CHECK: xegpu.store_nd %[[DPAS_RES]], %[[TDESC_ST]][0, 0]
+  // CHECK-SAME: <{layout = #xegpu.layout<sg_layout = [16, 4], sg_data = [8, 32]>}> :
+  // CHECK-SAME: vector<128x128xf32>, !xegpu.tensor_desc<128x128xf32, #xegpu.layout<sg_layout = [16, 4], sg_data = [8, 32]>>
+    // 1 * 64 * 16 = 1024 work items; the expected [16, 4] sg_layout (64 subgroups) implies a subgroup size of 16.
+    %tdesc_a = xegpu.create_nd_tdesc %a : memref<128x128xf16> -> !xegpu.tensor_desc<128x128xf16>
+    %load_a =  xegpu.load_nd %tdesc_a : !xegpu.tensor_desc<128x128xf16> -> vector<128x128xf16>
+    %tdesc_b = xegpu.create_nd_tdesc %b : memref<128x128xf16> -> !xegpu.tensor_desc<128x128xf16>
+    %load_b =  xegpu.load_nd %tdesc_b : !xegpu.tensor_desc<128x128xf16> -> vector<128x128xf16>
+    %dpas = xegpu.dpas %load_a, %load_b : vector<128x128xf16>, vector<128x128xf16> -> vector<128x128xf32>
+    %tdesc_cd = xegpu.create_nd_tdesc %d : memref<128x128xf32> -> !xegpu.tensor_desc<128x128xf32>
+    xegpu.store_nd %dpas, %tdesc_cd[0, 0] : vector<128x128xf32>, !xegpu.tensor_desc<128x128xf32>
+    gpu.return
+  }
+}

>From 38531d76ad765e5963f862c6f5be54ff553466a5 Mon Sep 17 00:00:00 2001
From: Artem Kroviakov <artem.kroviakov at intel.com>
Date: Fri, 19 Dec 2025 13:07:25 +0000
Subject: [PATCH 2/5] Adjust dpas propagation

---
 .../XeGPU/Transforms/XeGPUPropagateLayout.cpp | 40 +++++++++----------
 .../XeGPU/propagate-layout-subgroup.mlir      | 18 ++++-----
 2 files changed, 27 insertions(+), 31 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 1215f567d2cd9..6b99f2cd43c71 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -800,37 +800,36 @@ void LayoutInfoPropagation::visitDpasOp(
           aTy, 0, uArch, uArchInstruction->getPackedFormatBitSizeA());
       dpasBLayout = getSIMTLayoutInfoForDPASOperand(
           bTy, 1, uArch, uArchInstruction->getPackedFormatBitSizeB());
-    } else {
+    } else { // Subgroup
       auto numSgOrErr = getNumSg(dpas, subgroupSize);
       if (failed(numSgOrErr)) {
         dpas.emitWarning(
             "Unable to determine the number of subgroups for the operation.");
         return;
       }
-      auto layoutDataAOrErr = chooseLayout(aTy.getShape(), numSgOrErr.value());
-      if (failed(layoutDataAOrErr)) {
+      auto layoutDataCDOrErr =
+          chooseLayout(dpas.getResultType().getShape(), numSgOrErr.value());
+      if (failed(layoutDataCDOrErr)) {
         dpas.emitWarning(
-            "Unable to determine suitable subgroup layout and data for A.");
+            "Unable to determine suitable subgroup layout and data for C/D.");
         return;
       }
-      auto [sgLayoutA, sgDataA] = layoutDataAOrErr.value();
+      auto [sgLayoutCD, sgDataCD] = layoutDataCDOrErr.value();
+      auto sgDataA = sgDataCD;
+      sgDataA[1] = aTy.getShape()[1];
+      auto sgDataB = sgDataCD;
+      sgDataB[0] = bTy.getShape()[0];
 
       dpasALayout = LayoutInfo(xegpu::LayoutAttr::get(
-          aTy.getContext(), DenseI32ArrayAttr::get(aTy.getContext(), sgLayoutA),
+          aTy.getContext(),
+          DenseI32ArrayAttr::get(aTy.getContext(), sgLayoutCD),
           DenseI32ArrayAttr::get(aTy.getContext(), sgDataA),
           /*inst_data =*/nullptr, /*lane_layout =*/nullptr,
           /*lane_data =*/nullptr, /*order =*/nullptr));
 
-      auto layoutDataBOrErr = chooseLayout(bTy.getShape(), numSgOrErr.value());
-      if (failed(layoutDataBOrErr)) {
-        dpas.emitWarning(
-            "Unable to determine suitable subgroup layout and data for B.");
-        return;
-      }
-      auto [sgLayoutB, sgDataB] = layoutDataBOrErr.value();
-
       dpasBLayout = LayoutInfo(xegpu::LayoutAttr::get(
-          bTy.getContext(), DenseI32ArrayAttr::get(bTy.getContext(), sgLayoutB),
+          bTy.getContext(),
+          DenseI32ArrayAttr::get(bTy.getContext(), sgLayoutCD),
           DenseI32ArrayAttr::get(bTy.getContext(), sgDataB),
           /*inst_data =*/nullptr, /*lane_layout =*/nullptr,
           /*lane_data =*/nullptr, /*order =*/nullptr));
@@ -862,15 +861,14 @@ void LayoutInfoPropagation::visitDpasOp(
               "Unable to determine the number of subgroups for the operation.");
           return;
         }
-        auto layoutDataAOrErr =
-            chooseLayout(cTy.getShape(), numSgOrErr.value());
-        if (failed(layoutDataAOrErr)) {
+        auto layoutDataCDOrErr =
+            chooseLayout(dpas.getResultType().getShape(), numSgOrErr.value());
+        if (failed(layoutDataCDOrErr)) {
           dpas.emitWarning(
-              "Unable to determine suitable subgroup layout and data for A.");
+              "Unable to determine suitable subgroup layout and data for C/D.");
           return;
         }
-        auto [sgLayoutCD, sgDataCD] = layoutDataAOrErr.value();
-
+        auto [sgLayoutCD, sgDataCD] = layoutDataCDOrErr.value();
         dpasCDLayout = LayoutInfo(xegpu::LayoutAttr::get(
             cTy.getContext(),
             DenseI32ArrayAttr::get(cTy.getContext(), sgLayoutCD),
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir b/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir
index b6dfbe79e2712..13f92f12331a2 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir
@@ -89,23 +89,21 @@ gpu.module @test {
   gpu.func @dpas(%a: memref<128x128xf16>, %b: memref<128x128xf16>, %d: memref<128x128xf32>) kernel attributes
       {known_block_size = array<i32: 1, 64, 16>} {
   // CHECK: %[[TDESC_A:.*]] = xegpu.create_nd_tdesc %[[A_MEMREF]] : memref<128x128xf16> ->
-  // CHECK-SAME: !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [16, 4], sg_data = [8, 32]>>
+  // CHECK-SAME: !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [16, 4], sg_data = [8, 128]>>
 
   // CHECK: %[[A_LOADED:.*]] = xegpu.load_nd %[[TDESC_A]]
-  // CHECK-SAME: <{layout = #xegpu.layout<sg_layout = [16, 4], sg_data = [8, 32]>}>
-  // CHECK-SAME: {layout_result_0 = #xegpu.layout<sg_layout = [16, 4], sg_data = [8, 32]>} :
-  // CHECK-SAME: !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [16, 4], sg_data = [8, 32]>> -> vector<128x128xf16>
+  // CHECK-SAME: <{layout = #xegpu.layout<sg_layout = [16, 4], sg_data = [8, 128]>}>
+  // CHECK-SAME: : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [16, 4], sg_data = [8, 128]>> -> vector<128x128xf16>
 
   // CHECK: %[[TDESC_B:.*]] = xegpu.create_nd_tdesc %[[B_MEMREF]] : memref<128x128xf16> ->
-  // CHECK-SAME: !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [16, 4], sg_data = [8, 32]>>
+  // CHECK-SAME: !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [16, 4], sg_data = [128, 32]>>
 
-  // CHECK: %[[B_LOADED:.*]] = xegpu.load_nd %[[TDESC_B]] <{layout = #xegpu.layout<sg_layout = [16, 4], sg_data = [8, 32]>}>
-  // CHECK-SAME: {layout_result_0 = #xegpu.layout<sg_layout = [16, 4], sg_data = [8, 32]>} :
-  // CHECK-SAME: !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [16, 4], sg_data = [8, 32]>> -> vector<128x128xf16>
+  // CHECK: %[[B_LOADED:.*]] = xegpu.load_nd %[[TDESC_B]] <{layout = #xegpu.layout<sg_layout = [16, 4], sg_data = [128, 32]>}>
+  // CHECK-SAME: : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [16, 4], sg_data = [128, 32]>> -> vector<128x128xf16>
 
   // CHECK: %[[DPAS_RES:.*]] = xegpu.dpas %[[A_LOADED]], %[[B_LOADED]]
-  // CHECK-SAME: {layout_a = #xegpu.layout<sg_layout = [16, 4], sg_data = [8, 32]>,
-  // CHECK-SAME: layout_b = #xegpu.layout<sg_layout = [16, 4], sg_data = [8, 32]>,
+  // CHECK-SAME: {layout_a = #xegpu.layout<sg_layout = [16, 4], sg_data = [8, 128]>,
+  // CHECK-SAME: layout_b = #xegpu.layout<sg_layout = [16, 4], sg_data = [128, 32]>,
   // CHECK-SAME: layout_result_0 = #xegpu.layout<sg_layout = [16, 4], sg_data = [8, 32]>} :
   // CHECK-SAME: vector<128x128xf16>, vector<128x128xf16> -> vector<128x128xf32>
 

>From 062cdbec1321982183049337293582fdaf8f4a39 Mon Sep 17 00:00:00 2001
From: Artem Kroviakov <artem.kroviakov at intel.com>
Date: Thu, 22 Jan 2026 11:44:51 +0000
Subject: [PATCH 3/5] Try to pick a common layout if possible

---
 .../XeGPU/Transforms/XeGPUPropagateLayout.cpp | 212 ++++++++++++------
 .../XeGPU/propagate-layout-subgroup.mlir      |  39 ++--
 2 files changed, 167 insertions(+), 84 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 6b99f2cd43c71..15f189a7a77d3 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -531,6 +531,38 @@ bool LayoutInfoPropagation::hasParamsOfLayoutKind(
   return false;
 }
 
+SmallVector<std::pair<int, int>> getValidLayouts(ArrayRef<int64_t> wgShape,
+                                                 ArrayRef<int> instData,
+                                                 int64_t sgCount) {
+  SmallVector<std::pair<int, int>> candidates; // (sgLayout0, sgLayout1) pairs
+  // Try every sgCount = sgLayout0 * sgLayout1. NOTE(review): assumes rank-2 wgShape/instData; index [1] is read without a size check.
+  for (int sgLayout0 = 1; sgLayout0 <= sgCount; ++sgLayout0) {
+    if (sgCount % sgLayout0)
+      continue;
+    int sgLayout1 = sgCount / sgLayout0;
+    int sgData0 = wgShape[0] / sgLayout0;
+    int sgData1 = wgShape[1] / sgLayout1;
+    // Keep only even splits whose per-subgroup sg_data is a multiple of instData.
+    if ((wgShape[0] % sgLayout0 || wgShape[1] % sgLayout1) ||
+        (sgData0 % instData[0] || sgData1 % instData[1]))
+      continue;
+    // NOTE(review): an instData entry of 0 would make the modulo above UB.
+    candidates.emplace_back(sgLayout0, sgLayout1);
+  }
+  // Most balanced first (smallest |sgLayout0 - sgLayout1|), ties broken by the
+  // smaller sgLayout0. Candidates are distinct, so the order is fully
+  // determined and the front element is the preferred layout.
+  llvm::sort(candidates, [](const std::pair<int, int> &lhs,
+                            const std::pair<int, int> &rhs) {
+    int64_t diffLhs = std::abs(lhs.first - lhs.second);
+    int64_t diffRhs = std::abs(rhs.first - rhs.second);
+    if (diffLhs != diffRhs)
+      return diffLhs < diffRhs;
+    return lhs.first < rhs.first;
+  });
+  return candidates; // Empty when no layout fits.
+}
+
 FailureOr<std::pair<SmallVector<int>, SmallVector<int>>>
 chooseLayout(llvm::ArrayRef<int64_t> wgShape, int64_t sgCount) {
   const size_t rank = wgShape.size();
@@ -763,6 +795,10 @@ void LayoutInfoPropagation::visitDpasOp(
 
     VectorType aTy = dpas.getLhsType();
     VectorType bTy = dpas.getRhsType();
+    VectorType cTy;
+    const bool hasAcc = operands.size() > 2;
+    if (hasAcc)
+      cTy = dpas.getAccType();
 
     auto uArch = getUArch(getChipStr(dpas).value_or(""));
     const int subgroupSize = uArch->getSubgroupSize();
@@ -789,17 +825,38 @@ void LayoutInfoPropagation::visitDpasOp(
           "No suitable instruction multiple found for the given shape.");
     SmallVector<int> instDataA = {maxALen, subgroupSize};
     SmallVector<int> instDataB = {subgroupSize, maxBLen};
-
+    SmallVector<int> instDataCD;
+    if (hasAcc) {
+      const unsigned dataCLen = bTy.getShape().back();
+      auto supportedCLen =
+          uArchInstruction->getSupportedN(cTy.getElementType());
+      const int maxCLen =
+          xegpu::getLargestDivisor(dataCLen, ArrayRef<unsigned>(supportedCLen));
+      if (maxCLen == -1) {
+        dpas.emitWarning(
+            "No suitable instruction multiple found for the given shape.");
+        return;
+      }
+      instDataCD = {maxALen, maxCLen};
+    }
     if (layoutKind == LayoutKind::InstData) {
       dpasALayout =
           LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataA));
       dpasBLayout =
           LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataB));
+      if (hasAcc) {
+        dpasCDLayout =
+            LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataCD));
+      }
     } else if (layoutKind == LayoutKind::Lane) {
       dpasALayout = getSIMTLayoutInfoForDPASOperand(
           aTy, 0, uArch, uArchInstruction->getPackedFormatBitSizeA());
       dpasBLayout = getSIMTLayoutInfoForDPASOperand(
           bTy, 1, uArch, uArchInstruction->getPackedFormatBitSizeB());
+      if (hasAcc) {
+        dpasCDLayout = getSIMTLayoutInfoForDPASOperand(
+            cTy, 2, uArch, uArchInstruction->getPackedFormatBitSizeB());
+      }
     } else { // Subgroup
       auto numSgOrErr = getNumSg(dpas, subgroupSize);
       if (failed(numSgOrErr)) {
@@ -807,83 +864,108 @@ void LayoutInfoPropagation::visitDpasOp(
             "Unable to determine the number of subgroups for the operation.");
         return;
       }
-      auto layoutDataCDOrErr =
-          chooseLayout(dpas.getResultType().getShape(), numSgOrErr.value());
-      if (failed(layoutDataCDOrErr)) {
+      // Step 1. Retrieve D layout. Get all valid layouts for A and B
+      LayoutInfo layoutD = results[0]->getValue();
+      SmallVector<int> sgLayoutD = layoutD.getSgLayout();
+      assert(!sgLayoutD.empty() && "Expected layout for DPAS result.");
+      auto layoutDVal = std::make_pair(sgLayoutD[0], sgLayoutD[1]);
+
+      auto layoutsA =
+          getValidLayouts(aTy.getShape(), instDataA, numSgOrErr.value());
+      auto layoutsB =
+          getValidLayouts(bTy.getShape(), instDataB, numSgOrErr.value());
+      SmallVector<std::pair<int, int>> layoutsC;
+      if (hasAcc)
+        layoutsC =
+            getValidLayouts(cTy.getShape(), instDataCD, numSgOrErr.value());
+
+      if (layoutsA.empty() || layoutsB.empty() ||
+          (hasAcc && layoutsC.empty())) {
         dpas.emitWarning(
-            "Unable to determine suitable subgroup layout and data for C/D.");
+            "Unable to determine suitable subgroup layout for A/B/C matrices.");
         return;
       }
-      auto [sgLayoutCD, sgDataCD] = layoutDataCDOrErr.value();
-      auto sgDataA = sgDataCD;
-      sgDataA[1] = aTy.getShape()[1];
-      auto sgDataB = sgDataCD;
-      sgDataB[0] = bTy.getShape()[0];
+
+      // Step 2. Find common layouts.
+      // Ideally, we want to find given layout D in all A, B and C candidates.
+
+      // Ensure D layout matches one of C layouts.
+      if (hasAcc && llvm::find(layoutsC, layoutDVal) == layoutsC.end()) {
+        dpas.emitWarning("Subgroup layout for D does not match any valid C "
+                         "subgroup layout.");
+        return;
+      }
+      // The best pick is layout D. If not found, we will pick any common layout
+      // between A and B.
+      std::optional<std::pair<int, int>> bestPick;
+      llvm::DenseSet<std::pair<int, int>> setA(layoutsA.begin(),
+                                               layoutsA.end());
+      SmallVector<std::pair<int, int>> common;
+      for (auto &l : layoutsB) {
+        if (setA.contains(l)) {
+          if (l == layoutDVal) {
+            bestPick = l;
+            break;
+          }
+          common.push_back(l);
+        }
+      }
+      // Step 3. The best pick either matches D or is any common layout between
+      // A and B. If no common layout, warn and pick any valid layout.
+      SmallVector<int> sgLayoutA;
+      SmallVector<int> sgLayoutB;
+      if (!bestPick && !common.empty())
+        bestPick = common[0];
+      if (bestPick) {
+        sgLayoutA = {bestPick->first, bestPick->second};
+        sgLayoutB = sgLayoutA;
+      } else {
+        dpas.emitWarning(
+            "Unable to find common subgroup layout for matrices matching "
+            "layout of result. Picking any valid layout.");
+        sgLayoutA = {layoutsA[0].first, layoutsA[0].second};
+        sgLayoutB = {layoutsB[0].first, layoutsB[0].second};
+      }
+      SmallVector<int> sgDataA = {
+          static_cast<int>(aTy.getShape()[0]) / sgLayoutA[0],
+          static_cast<int>(aTy.getShape()[1]) / sgLayoutA[1]};
+      SmallVector<int> sgDataB = {
+          static_cast<int>(bTy.getShape()[0]) / sgLayoutB[0],
+          static_cast<int>(bTy.getShape()[1]) / sgLayoutB[1]};
+      SmallVector<int> sgDataC;
+      if (hasAcc)
+        sgDataC = {static_cast<int>(dpas.getResultType().getShape()[0]) /
+                       sgLayoutD[0],
+                   static_cast<int>(dpas.getResultType().getShape()[1]) /
+                       sgLayoutD[1]};
 
       dpasALayout = LayoutInfo(xegpu::LayoutAttr::get(
-          aTy.getContext(),
-          DenseI32ArrayAttr::get(aTy.getContext(), sgLayoutCD),
+          aTy.getContext(), DenseI32ArrayAttr::get(aTy.getContext(), sgLayoutA),
           DenseI32ArrayAttr::get(aTy.getContext(), sgDataA),
           /*inst_data =*/nullptr, /*lane_layout =*/nullptr,
           /*lane_data =*/nullptr, /*order =*/nullptr));
 
       dpasBLayout = LayoutInfo(xegpu::LayoutAttr::get(
-          bTy.getContext(),
-          DenseI32ArrayAttr::get(bTy.getContext(), sgLayoutCD),
+          bTy.getContext(), DenseI32ArrayAttr::get(bTy.getContext(), sgLayoutB),
           DenseI32ArrayAttr::get(bTy.getContext(), sgDataB),
           /*inst_data =*/nullptr, /*lane_layout =*/nullptr,
           /*lane_data =*/nullptr, /*order =*/nullptr));
-    }
-
-    if (operands.size() > 2) {
-      VectorType cTy = dpas.getAccType();
-      const unsigned dataCLen = bTy.getShape().back();
-      auto supportedCLen =
-          uArchInstruction->getSupportedN(bTy.getElementType());
-      const int maxCLen =
-          xegpu::getLargestDivisor(dataCLen, ArrayRef<unsigned>(supportedCLen));
-      if (maxCLen == -1) {
-        dpas.emitWarning(
-            "No suitable instruction multiple found for the given shape.");
-        return;
-      }
-      SmallVector<int> instDataCD = {maxALen, maxCLen};
-      if (layoutKind == LayoutKind::InstData) {
-        dpasCDLayout =
-            LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataCD));
-      } else if (layoutKind == LayoutKind::Lane) {
-        dpasCDLayout = getSIMTLayoutInfoForDPASOperand(
-            cTy, 2, uArch, uArchInstruction->getPackedFormatBitSizeB());
-      } else {
-        auto numSgOrErr = getNumSg(dpas, subgroupSize);
-        if (failed(numSgOrErr)) {
-          dpas.emitWarning(
-              "Unable to determine the number of subgroups for the operation.");
-          return;
-        }
-        auto layoutDataCDOrErr =
-            chooseLayout(dpas.getResultType().getShape(), numSgOrErr.value());
-        if (failed(layoutDataCDOrErr)) {
-          dpas.emitWarning(
-              "Unable to determine suitable subgroup layout and data for C/D.");
-          return;
-        }
-        auto [sgLayoutCD, sgDataCD] = layoutDataCDOrErr.value();
+      if (hasAcc) {
         dpasCDLayout = LayoutInfo(xegpu::LayoutAttr::get(
             cTy.getContext(),
-            DenseI32ArrayAttr::get(cTy.getContext(), sgLayoutCD),
-            DenseI32ArrayAttr::get(cTy.getContext(), sgDataCD),
+            DenseI32ArrayAttr::get(cTy.getContext(), sgLayoutD),
+            DenseI32ArrayAttr::get(cTy.getContext(), sgDataC),
             /*inst_data =*/nullptr, /*lane_layout =*/nullptr,
             /*lane_data =*/nullptr, /*order =*/nullptr));
       }
-
-      dpas.setLayoutCdAttr(
-          dyn_cast<xegpu::DistributeLayoutAttr>(dpasCDLayout.get()));
     }
     dpas.setLayoutAAttr(
         dyn_cast<xegpu::DistributeLayoutAttr>(dpasALayout.get()));
     dpas.setLayoutBAttr(
         dyn_cast<xegpu::DistributeLayoutAttr>(dpasBLayout.get()));
+    if (hasAcc)
+      dpas.setLayoutCdAttr(
+          dyn_cast<xegpu::DistributeLayoutAttr>(dpasCDLayout.get()));
   }
 
   propagateIfChanged(operands[0], operands[0]->meet(dpasALayout));
@@ -935,9 +1017,9 @@ void LayoutInfoPropagation::visitStoreNdOp(
       storeLayout =
           LayoutInfo(xegpu::LayoutAttr::get(dataTy.getContext(), instData));
     else if (layoutKind == LayoutKind::Lane)
-      storeLayout =
-          getSIMTLayoutInforForBlockIO(store.getValueType(), uArch,
-                                   uArchInstruction->getPackedFormatBitSize());
+      storeLayout = getSIMTLayoutInforForBlockIO(
+          store.getValueType(), uArch,
+          uArchInstruction->getPackedFormatBitSize());
     else { // LayoutKind::Subgroup
       auto sgSize = uArch->getSubgroupSize();
       auto numSgOrErr = getNumSg(store, sgSize);
@@ -946,15 +1028,17 @@ void LayoutInfoPropagation::visitStoreNdOp(
             "Unable to determine the number of subgroups for the operation.");
         return;
       }
-      auto layoutDataOrErr =
-          chooseLayout(dataTy.getShape(), numSgOrErr.value());
-      if (failed(layoutDataOrErr)) {
+      auto sgLayouts = getValidLayouts(store.getValueType().getShape(),
+                                       instData, numSgOrErr.value());
+      if (sgLayouts.empty()) {
         store.emitWarning(
-            "Unable to determine suitable subgroup layout and data.");
+            "Unable to determine suitable subgroup layout for store value.");
         return;
       }
-      auto [sgLayout, sgData] = layoutDataOrErr.value();
-
+      SmallVector<int> sgLayout = {sgLayouts[0].first, sgLayouts[0].second};
+      SmallVector<int> sgData = {
+          static_cast<int>(dataTy.getShape()[0]) / sgLayout[0],
+          static_cast<int>(dataTy.getShape()[1]) / sgLayout[1]};
       storeLayout = LayoutInfo(xegpu::LayoutAttr::get(
           dataTy.getContext(),
           DenseI32ArrayAttr::get(dataTy.getContext(), sgLayout),
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir b/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir
index 13f92f12331a2..a0adb731605de 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir
@@ -58,20 +58,19 @@ gpu.module @test {
   gpu.func @vector_transpose(%src: memref<256x128xf32>, %src1: memref<128x256xf32>) kernel attributes
       {known_block_size = array<i32: 1, 32, 16>} {
     // CHECK: %[[TDESC_LD:.*]] = xegpu.create_nd_tdesc %[[ARG_0]] : memref<256x128xf32> ->
-    // CHECK-SAME: !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [4, 8], sg_data = [64, 16]>>
+    // CHECK-SAME: !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>>
     // CHECK: %[[TDESC_ST:.*]] = xegpu.create_nd_tdesc %[[ARG_1]] : memref<128x256xf32> ->
-    // CHECK-SAME: !xegpu.tensor_desc<128x256xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 64]>>
+    // CHECK-SAME: !xegpu.tensor_desc<128x256xf32, #xegpu.layout<sg_layout = [4, 8], sg_data = [32, 32]>>
 
-    // CHECK: %[[LOAD:.*]] = xegpu.load_nd %[[TDESC_LD]][0, 0] <{layout = #xegpu.layout<sg_layout = [4, 8], sg_data = [64, 16]>}>
-    // CHECK-SAME: {layout_result_0 = #xegpu.layout<sg_layout = [4, 8], sg_data = [64, 16]>} :
-    // CHECK-SAME: !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [4, 8], sg_data = [64, 16]>> -> vector<256x128xf32>
+    // CHECK: %[[LOAD:.*]] = xegpu.load_nd %[[TDESC_LD]][0, 0] <{layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>}> :
+    // CHECK-SAME: !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>> -> vector<256x128xf32>
 
     // CHECK: %[[TRANSPOSED:.*]] = vector.transpose %2, [1, 0]
-    // CHECK-SAME {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 64]>} : vector<256x128xf32> to vector<128x256xf32>
+    // CHECK-SAME: {layout_result_0 = #xegpu.layout<sg_layout = [4, 8], sg_data = [32, 32]>} : vector<256x128xf32> to vector<128x256xf32>
 
     // CHECK: xegpu.store_nd %[[TRANSPOSED]], %[[TDESC_ST]][0, 0]
-    // CHECK-SAME: <{layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 64]>}> : vector<128x256xf32>,
-    // CHECK-SAME: !xegpu.tensor_desc<128x256xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 64]>>
+    // CHECK-SAME: <{layout = #xegpu.layout<sg_layout = [4, 8], sg_data = [32, 32]>}> : vector<128x256xf32>,
+    // CHECK-SAME: !xegpu.tensor_desc<128x256xf32, #xegpu.layout<sg_layout = [4, 8], sg_data = [32, 32]>>
     %tdesc = xegpu.create_nd_tdesc %src : memref<256x128xf32> -> !xegpu.tensor_desc<256x128xf32>
     %tdesc1 = xegpu.create_nd_tdesc %src1 : memref<128x256xf32> -> !xegpu.tensor_desc<128x256xf32>
     %load = xegpu.load_nd %tdesc[0, 0] : !xegpu.tensor_desc<256x128xf32> -> vector<256x128xf32>
@@ -89,30 +88,30 @@ gpu.module @test {
   gpu.func @dpas(%a: memref<128x128xf16>, %b: memref<128x128xf16>, %d: memref<128x128xf32>) kernel attributes
       {known_block_size = array<i32: 1, 64, 16>} {
   // CHECK: %[[TDESC_A:.*]] = xegpu.create_nd_tdesc %[[A_MEMREF]] : memref<128x128xf16> ->
-  // CHECK-SAME: !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [16, 4], sg_data = [8, 128]>>
+  // CHECK-SAME: !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>>
 
   // CHECK: %[[A_LOADED:.*]] = xegpu.load_nd %[[TDESC_A]]
-  // CHECK-SAME: <{layout = #xegpu.layout<sg_layout = [16, 4], sg_data = [8, 128]>}>
-  // CHECK-SAME: : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [16, 4], sg_data = [8, 128]>> -> vector<128x128xf16>
+  // CHECK-SAME: <{layout = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>}>
+  // CHECK-SAME: : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>> -> vector<128x128xf16>
 
   // CHECK: %[[TDESC_B:.*]] = xegpu.create_nd_tdesc %[[B_MEMREF]] : memref<128x128xf16> ->
-  // CHECK-SAME: !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [16, 4], sg_data = [128, 32]>>
+  // CHECK-SAME: !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>>
 
-  // CHECK: %[[B_LOADED:.*]] = xegpu.load_nd %[[TDESC_B]] <{layout = #xegpu.layout<sg_layout = [16, 4], sg_data = [128, 32]>}>
-  // CHECK-SAME: : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [16, 4], sg_data = [128, 32]>> -> vector<128x128xf16>
+  // CHECK: %[[B_LOADED:.*]] = xegpu.load_nd %[[TDESC_B]] <{layout = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>}>
+  // CHECK-SAME: : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>> -> vector<128x128xf16>
 
   // CHECK: %[[DPAS_RES:.*]] = xegpu.dpas %[[A_LOADED]], %[[B_LOADED]]
-  // CHECK-SAME: {layout_a = #xegpu.layout<sg_layout = [16, 4], sg_data = [8, 128]>,
-  // CHECK-SAME: layout_b = #xegpu.layout<sg_layout = [16, 4], sg_data = [128, 32]>,
-  // CHECK-SAME: layout_result_0 = #xegpu.layout<sg_layout = [16, 4], sg_data = [8, 32]>} :
+  // CHECK-SAME: {layout_a = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>,
+  // CHECK-SAME: layout_b = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>,
+  // CHECK-SAME: layout_cd = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>} :
   // CHECK-SAME: vector<128x128xf16>, vector<128x128xf16> -> vector<128x128xf32>
 
   // CHECK: %[[TDESC_ST:.*]] = xegpu.create_nd_tdesc %[[CD_MEMREF]] : memref<128x128xf32> ->
-  // CHECK-SAME: !xegpu.tensor_desc<128x128xf32, #xegpu.layout<sg_layout = [16, 4], sg_data = [8, 32]>>
+  // CHECK-SAME: !xegpu.tensor_desc<128x128xf32, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>>
 
   // CHECK: xegpu.store_nd %[[DPAS_RES]], %[[TDESC_ST]][0, 0]
-  // CHECK-SAME: <{layout = #xegpu.layout<sg_layout = [16, 4], sg_data = [8, 32]>}> :
-  // CHECK-SAME: vector<128x128xf32>, !xegpu.tensor_desc<128x128xf32, #xegpu.layout<sg_layout = [16, 4], sg_data = [8, 32]>>
+  // CHECK-SAME: <{layout = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>}> :
+  // CHECK-SAME: vector<128x128xf32>, !xegpu.tensor_desc<128x128xf32, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>>
 
     %tdesc_a = xegpu.create_nd_tdesc %a : memref<128x128xf16> -> !xegpu.tensor_desc<128x128xf16>
     %load_a =  xegpu.load_nd %tdesc_a : !xegpu.tensor_desc<128x128xf16> -> vector<128x128xf16>

>From 763f571de2d262a940097dc2f4301fec51cb38a6 Mon Sep 17 00:00:00 2001
From: Artem Kroviakov <artem.kroviakov at intel.com>
Date: Thu, 22 Jan 2026 18:35:16 +0000
Subject: [PATCH 4/5] Prefer D layout, if not possible pick common for A, B, C,
 otherwise fail

---
 .../XeGPU/Transforms/XeGPUPropagateLayout.cpp | 53 ++++++++++---------
 1 file changed, 27 insertions(+), 26 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 15f189a7a77d3..87fea55b2b9d6 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -864,7 +864,9 @@ void LayoutInfoPropagation::visitDpasOp(
             "Unable to determine the number of subgroups for the operation.");
         return;
       }
-      // Step 1. Retrieve D layout. Get all valid layouts for A and B
+
+      // Step 1. Get all valid layouts for A, B, and C operands.
+      // All operands must have at least one valid subgroup layout.
       LayoutInfo layoutD = results[0]->getValue();
       SmallVector<int> sgLayoutD = layoutD.getSgLayout();
       assert(!sgLayoutD.empty() && "Expected layout for DPAS result.");
@@ -886,45 +888,44 @@ void LayoutInfoPropagation::visitDpasOp(
         return;
       }
 
-      // Step 2. Find common layouts.
-      // Ideally, we want to find given layout D in all A, B and C candidates.
-
-      // Ensure D layout matches one of C layouts.
-      if (hasAcc && llvm::find(layoutsC, layoutDVal) == layoutsC.end()) {
-        dpas.emitWarning("Subgroup layout for D does not match any valid C "
-                         "subgroup layout.");
-        return;
-      }
-      // The best pick is layout D. If not found, we will pick any common layout
-      // between A and B.
-      std::optional<std::pair<int, int>> bestPick;
+      // Step 2. If the result D layout can be reused for all operands, that
+      // layout is chosen. Otherwise, pick the most balanced subgroup layout
+      // that is valid for the A, B and C (if present) operands.
       llvm::DenseSet<std::pair<int, int>> setA(layoutsA.begin(),
                                                layoutsA.end());
-      SmallVector<std::pair<int, int>> common;
+      llvm::DenseSet<std::pair<int, int>> setC;
+      if (hasAcc)
+        setC = llvm::DenseSet<std::pair<int, int>>(layoutsC.begin(),
+                                                   layoutsC.end());
+      std::optional<std::pair<int, int>> bestPick;
       for (auto &l : layoutsB) {
+        // Is in valid A layouts
         if (setA.contains(l)) {
+          // Is in valid C layouts
+          if (hasAcc && !setC.contains(l))
+            continue;
+          // Is in (A and B and C) and matches D -> best pick
           if (l == layoutDVal) {
             bestPick = l;
             break;
           }
-          common.push_back(l);
+          // Is in (A and B and C), balanced layout comes first
+          if (!bestPick)
+            bestPick = l;
         }
       }
-      // Step 3. The best pick either matches D or is any common layout between
-      // A and B. If no common layout, warn and pick any valid layout.
+      // Step 3. If there is no subgroup layout compatible with A, B and C (if
+      // present) operands, we fail.
       SmallVector<int> sgLayoutA;
       SmallVector<int> sgLayoutB;
-      if (!bestPick && !common.empty())
-        bestPick = common[0];
+      SmallVector<int> sgLayoutC;
       if (bestPick) {
         sgLayoutA = {bestPick->first, bestPick->second};
         sgLayoutB = sgLayoutA;
+        sgLayoutC = sgLayoutA;
       } else {
-        dpas.emitWarning(
-            "Unable to find common subgroup layout for matrices matching "
-            "layout of result. Picking any valid layout.");
-        sgLayoutA = {layoutsA[0].first, layoutsA[0].second};
-        sgLayoutB = {layoutsB[0].first, layoutsB[0].second};
+        dpas.emitWarning("Unable to find common subgroup layout for matrices.");
+        return;
       }
       SmallVector<int> sgDataA = {
           static_cast<int>(aTy.getShape()[0]) / sgLayoutA[0],
@@ -935,9 +936,9 @@ void LayoutInfoPropagation::visitDpasOp(
       SmallVector<int> sgDataC;
       if (hasAcc)
         sgDataC = {static_cast<int>(dpas.getResultType().getShape()[0]) /
-                       sgLayoutD[0],
+                       sgLayoutC[0],
                    static_cast<int>(dpas.getResultType().getShape()[1]) /
-                       sgLayoutD[1]};
+                       sgLayoutC[1]};
 
       dpasALayout = LayoutInfo(xegpu::LayoutAttr::get(
           aTy.getContext(), DenseI32ArrayAttr::get(aTy.getContext(), sgLayoutA),

>From 131a4acf49ce50ae0b999ab69482e09a1ef908f8 Mon Sep 17 00:00:00 2001
From: Artem Kroviakov <artem.kroviakov at intel.com>
Date: Fri, 23 Jan 2026 10:02:11 +0000
Subject: [PATCH 5/5] Cleanup

---
 .../XeGPU/Transforms/XeGPUPropagateLayout.cpp | 84 +++++--------------
 1 file changed, 21 insertions(+), 63 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 87fea55b2b9d6..ae59a2e92759c 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -531,22 +531,26 @@ bool LayoutInfoPropagation::hasParamsOfLayoutKind(
   return false;
 }
 
+// Returns all 2-D sg layouts for sgCount (most balanced first) whose sgData:
+// 1. Evenly divides the wgShape.
+// 2. Is a multiple of instData.
+// Example:
+//   wgShape = [128, 64], instData = [8, 16], sgCount = 32
+// Returns layouts:
+//   [(8,4), (16,2)], which correspond to sgData [16,16] and [8,32].
 SmallVector<std::pair<int, int>> getValidLayouts(ArrayRef<int64_t> wgShape,
                                                  ArrayRef<int> instData,
                                                  int64_t sgCount) {
   SmallVector<std::pair<int, int>> candidates;
-  // Find valid multiples of instData
   for (int sgLayout0 = 1; sgLayout0 <= sgCount; ++sgLayout0) {
     if (sgCount % sgLayout0)
       continue;
     int sgLayout1 = sgCount / sgLayout0;
     int sgData0 = wgShape[0] / sgLayout0;
     int sgData1 = wgShape[1] / sgLayout1;
-    // Check divisibility and instruction atomic alignment
     if ((wgShape[0] % sgLayout0 || wgShape[1] % sgLayout1) ||
         (sgData0 % instData[0] || sgData1 % instData[1]))
       continue;
-
     candidates.emplace_back(sgLayout0, sgLayout1);
   }
   // Sort primarily by how balanced they are
@@ -554,8 +558,8 @@ SmallVector<std::pair<int, int>> getValidLayouts(ArrayRef<int64_t> wgShape,
   // secondarily by the first dimension in ascending order.
   llvm::sort(candidates, [](const std::pair<int, int> &lhs,
                             const std::pair<int, int> &rhs) {
-    int64_t diffLhs = std::abs(lhs.first - lhs.second);
-    int64_t diffRhs = std::abs(rhs.first - rhs.second);
+    int diffLhs = std::abs(lhs.first - lhs.second);
+    int diffRhs = std::abs(rhs.first - rhs.second);
     if (diffLhs != diffRhs)
       return diffLhs < diffRhs;
     return lhs.first < rhs.first;
@@ -563,43 +567,6 @@ SmallVector<std::pair<int, int>> getValidLayouts(ArrayRef<int64_t> wgShape,
   return candidates;
 }
 
-FailureOr<std::pair<SmallVector<int>, SmallVector<int>>>
-chooseLayout(llvm::ArrayRef<int64_t> wgShape, int64_t sgCount) {
-  const size_t rank = wgShape.size();
-
-  // Step 1. Factorize sgCount into prime factors.
-  SmallVector<int> layout;
-  int64_t temp = sgCount;
-  for (int64_t i = 2; i * i <= temp; ++i) {
-    while (temp % i == 0) {
-      layout.push_back(i);
-      temp /= i;
-    }
-  }
-  if (temp > 1)
-    layout.push_back(temp);
-
-  if (layout.size() < rank)
-    return failure();
-
-  // Step 2. Fuse two smallest factors until we have `rank` factors.
-  while (layout.size() > rank) {
-    std::sort(layout.begin(), layout.end());
-    int64_t a = layout[0];
-    int64_t b = layout[1];
-    layout.erase(layout.begin());
-    layout[0] = a * b;
-  }
-
-  SmallVector<int> data;
-  for (auto [i, dim] : llvm::enumerate(layout)) {
-    if (wgShape[i] % dim != 0)
-      return failure();
-    data.push_back(wgShape[i] / dim);
-  }
-  return std::make_pair(layout, data);
-}
-
 FailureOr<int64_t> getNumSg(Operation *op, const int sgSize) {
   // Oblivious to workitem layout, the total count matters.
   auto gpuFunc = op->getParentOfType<gpu::GPUFuncOp>();
@@ -790,9 +757,7 @@ void LayoutInfoPropagation::visitDpasOp(
     dpasALayout = LayoutInfo(anchorLayoutA);
     dpasBLayout = LayoutInfo(anchorLayoutB);
     dpasCDLayout = LayoutInfo(anchorLayoutCD);
-
   } else {
-
     VectorType aTy = dpas.getLhsType();
     VectorType bTy = dpas.getRhsType();
     VectorType cTy;
@@ -899,9 +864,7 @@ void LayoutInfoPropagation::visitDpasOp(
                                                    layoutsC.end());
       std::optional<std::pair<int, int>> bestPick;
       for (auto &l : layoutsB) {
-        // Is in valid A layouts
         if (setA.contains(l)) {
-          // Is in valid C layouts
           if (hasAcc && !setC.contains(l))
             continue;
           // Is in (A and B and C) and matches D -> best pick
@@ -916,45 +879,40 @@ void LayoutInfoPropagation::visitDpasOp(
       }
       // Step 3. If there is no subgroup layout compatible with A, B and C (if
       // present) operands, we fail.
-      SmallVector<int> sgLayoutA;
-      SmallVector<int> sgLayoutB;
-      SmallVector<int> sgLayoutC;
+      SmallVector<int> sgLayout;
       if (bestPick) {
-        sgLayoutA = {bestPick->first, bestPick->second};
-        sgLayoutB = sgLayoutA;
-        sgLayoutC = sgLayoutA;
+        sgLayout = {bestPick->first, bestPick->second};
       } else {
         dpas.emitWarning("Unable to find common subgroup layout for matrices.");
         return;
       }
       SmallVector<int> sgDataA = {
-          static_cast<int>(aTy.getShape()[0]) / sgLayoutA[0],
-          static_cast<int>(aTy.getShape()[1]) / sgLayoutA[1]};
+          static_cast<int>(aTy.getShape()[0]) / sgLayout[0],
+          static_cast<int>(aTy.getShape()[1]) / sgLayout[1]};
       SmallVector<int> sgDataB = {
-          static_cast<int>(bTy.getShape()[0]) / sgLayoutB[0],
-          static_cast<int>(bTy.getShape()[1]) / sgLayoutB[1]};
+          static_cast<int>(bTy.getShape()[0]) / sgLayout[0],
+          static_cast<int>(bTy.getShape()[1]) / sgLayout[1]};
       SmallVector<int> sgDataC;
       if (hasAcc)
-        sgDataC = {static_cast<int>(dpas.getResultType().getShape()[0]) /
-                       sgLayoutC[0],
-                   static_cast<int>(dpas.getResultType().getShape()[1]) /
-                       sgLayoutC[1]};
+        sgDataC = {
+            static_cast<int>(dpas.getResultType().getShape()[0]) / sgLayout[0],
+            static_cast<int>(dpas.getResultType().getShape()[1]) / sgLayout[1]};
 
       dpasALayout = LayoutInfo(xegpu::LayoutAttr::get(
-          aTy.getContext(), DenseI32ArrayAttr::get(aTy.getContext(), sgLayoutA),
+          aTy.getContext(), DenseI32ArrayAttr::get(aTy.getContext(), sgLayout),
           DenseI32ArrayAttr::get(aTy.getContext(), sgDataA),
           /*inst_data =*/nullptr, /*lane_layout =*/nullptr,
           /*lane_data =*/nullptr, /*order =*/nullptr));
 
       dpasBLayout = LayoutInfo(xegpu::LayoutAttr::get(
-          bTy.getContext(), DenseI32ArrayAttr::get(bTy.getContext(), sgLayoutB),
+          bTy.getContext(), DenseI32ArrayAttr::get(bTy.getContext(), sgLayout),
           DenseI32ArrayAttr::get(bTy.getContext(), sgDataB),
           /*inst_data =*/nullptr, /*lane_layout =*/nullptr,
           /*lane_data =*/nullptr, /*order =*/nullptr));
       if (hasAcc) {
         dpasCDLayout = LayoutInfo(xegpu::LayoutAttr::get(
             cTy.getContext(),
-            DenseI32ArrayAttr::get(cTy.getContext(), sgLayoutD),
+            DenseI32ArrayAttr::get(cTy.getContext(), sgLayout),
             DenseI32ArrayAttr::get(cTy.getContext(), sgDataC),
             /*inst_data =*/nullptr, /*lane_layout =*/nullptr,
             /*lane_data =*/nullptr, /*order =*/nullptr));



More information about the Mlir-commits mailing list