[Mlir-commits] [mlir] [MLIR][XeGPU] Add simple rank-based sg layout creation (PR #172867)

Jianhui Li llvmlistbot at llvm.org
Thu Jan 22 11:13:25 PST 2026


================
@@ -740,44 +825,148 @@ void LayoutInfoPropagation::visitDpasOp(
           "No suitable instruction multiple found for the given shape.");
     SmallVector<int> instDataA = {maxALen, subgroupSize};
     SmallVector<int> instDataB = {subgroupSize, maxBLen};
-
+    SmallVector<int> instDataCD;
+    if (hasAcc) {
+      const unsigned dataCLen = bTy.getShape().back();
+      auto supportedCLen =
+          uArchInstruction->getSupportedN(cTy.getElementType());
+      const int maxCLen =
+          xegpu::getLargestDivisor(dataCLen, ArrayRef<unsigned>(supportedCLen));
+      if (maxCLen == -1) {
+        dpas.emitWarning(
+            "No suitable instruction multiple found for the given shape.");
+        return;
+      }
+      instDataCD = {maxALen, maxCLen};
+    }
     if (layoutKind == LayoutKind::InstData) {
       dpasALayout =
           LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataA));
       dpasBLayout =
           LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataB));
-    } else {
+      if (hasAcc) {
+        dpasCDLayout =
+            LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataCD));
+      }
+    } else if (layoutKind == LayoutKind::Lane) {
       dpasALayout = getSIMTLayoutInfoForDPASOperand(
           aTy, 0, uArch, uArchInstruction->getPackedFormatBitSizeA());
       dpasBLayout = getSIMTLayoutInfoForDPASOperand(
           bTy, 1, uArch, uArchInstruction->getPackedFormatBitSizeB());
-    }
-
-    if (operands.size() > 2) {
-      VectorType cTy = dpas.getAccType();
-      if (layoutKind == LayoutKind::InstData) {
-        const unsigned dataCLen = bTy.getShape().back();
-        auto supportedCLen =
-            uArchInstruction->getSupportedN(bTy.getElementType());
-        const int maxCLen = xegpu::getLargestDivisor(
-            dataCLen, ArrayRef<unsigned>(supportedCLen));
-        if (maxCLen == -1)
-          dpas.emitWarning(
-              "No suitable instruction multiple found for the given shape.");
-        SmallVector<int> instDataC = {maxALen, maxCLen};
-        dpasCDLayout =
-            LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataC));
-      } else
+      if (hasAcc) {
         dpasCDLayout = getSIMTLayoutInfoForDPASOperand(
             cTy, 2, uArch, uArchInstruction->getPackedFormatBitSizeB());
+      }
+    } else { // Subgroup
+      auto numSgOrErr = getNumSg(dpas, subgroupSize);
+      if (failed(numSgOrErr)) {
+        dpas.emitWarning(
+            "Unable to determine the number of subgroups for the operation.");
+        return;
+      }
 
-      dpas.setLayoutCdAttr(
-          dyn_cast<xegpu::DistributeLayoutAttr>(dpasCDLayout.get()));
+      // Step 1. Get all valid layouts for A, B, and C operands.
+      // All operands must have at least one valid subgroup layout.
+      LayoutInfo layoutD = results[0]->getValue();
+      SmallVector<int> sgLayoutD = layoutD.getSgLayout();
+      assert(!sgLayoutD.empty() && "Expected layout for DPAS result.");
+      auto layoutDVal = std::make_pair(sgLayoutD[0], sgLayoutD[1]);
+
+      auto layoutsA =
+          getValidLayouts(aTy.getShape(), instDataA, numSgOrErr.value());
+      auto layoutsB =
+          getValidLayouts(bTy.getShape(), instDataB, numSgOrErr.value());
+      SmallVector<std::pair<int, int>> layoutsC;
+      if (hasAcc)
+        layoutsC =
+            getValidLayouts(cTy.getShape(), instDataCD, numSgOrErr.value());
+
+      if (layoutsA.empty() || layoutsB.empty() ||
+          (hasAcc && layoutsC.empty())) {
+        dpas.emitWarning(
+            "Unable to determine suitable subgroup layout for A/B/C matrices.");
+        return;
+      }
+
+      // Step 2. If the result D layout can be reused for all operands, that
+      // layout is chosen. Otherwise, pick the most balanced subgroup layout
+      // that is valid for A, B and C (if present) operands
+      llvm::DenseSet<std::pair<int, int>> setA(layoutsA.begin(),
+                                               layoutsA.end());
+      llvm::DenseSet<std::pair<int, int>> setC;
+      if (hasAcc)
+        setC = llvm::DenseSet<std::pair<int, int>>(layoutsC.begin(),
+                                                   layoutsC.end());
+      std::optional<std::pair<int, int>> bestPick;
+      for (auto &l : layoutsB) {
+        // Is in valid A layouts
+        if (setA.contains(l)) {
+          // Is in valid C layouts
+          if (hasAcc && !setC.contains(l))
+            continue;
+          // Is in (A and B and C) and matches D -> best pick
+          if (l == layoutDVal) {
+            bestPick = l;
+            break;
+          }
+          // Is in (A and B and C), balanced layout comes first
+          if (!bestPick)
+            bestPick = l;
+        }
+      }
+      // Step 3. If there is no subgroup layout compatible with A, B and C (if
+      // present) operands, we fail.
+      SmallVector<int> sgLayoutA;
+      SmallVector<int> sgLayoutB;
+      SmallVector<int> sgLayoutC;
+      if (bestPick) {
+        sgLayoutA = {bestPick->first, bestPick->second};
+        sgLayoutB = sgLayoutA;
+        sgLayoutC = sgLayoutA;
+      } else {
+        dpas.emitWarning("Unable to find common subgroup layout for matrices.");
+        return;
+      }
+      SmallVector<int> sgDataA = {
+          static_cast<int>(aTy.getShape()[0]) / sgLayoutA[0],
+          static_cast<int>(aTy.getShape()[1]) / sgLayoutA[1]};
+      SmallVector<int> sgDataB = {
+          static_cast<int>(bTy.getShape()[0]) / sgLayoutB[0],
+          static_cast<int>(bTy.getShape()[1]) / sgLayoutB[1]};
+      SmallVector<int> sgDataC;
+      if (hasAcc)
+        sgDataC = {static_cast<int>(dpas.getResultType().getShape()[0]) /
+                       sgLayoutC[0],
+                   static_cast<int>(dpas.getResultType().getShape()[1]) /
+                       sgLayoutC[1]};
+
+      dpasALayout = LayoutInfo(xegpu::LayoutAttr::get(
+          aTy.getContext(), DenseI32ArrayAttr::get(aTy.getContext(), sgLayoutA),
+          DenseI32ArrayAttr::get(aTy.getContext(), sgDataA),
+          /*inst_data =*/nullptr, /*lane_layout =*/nullptr,
+          /*lane_data =*/nullptr, /*order =*/nullptr));
+
+      dpasBLayout = LayoutInfo(xegpu::LayoutAttr::get(
+          bTy.getContext(), DenseI32ArrayAttr::get(bTy.getContext(), sgLayoutB),
+          DenseI32ArrayAttr::get(bTy.getContext(), sgDataB),
+          /*inst_data =*/nullptr, /*lane_layout =*/nullptr,
+          /*lane_data =*/nullptr, /*order =*/nullptr));
+      if (hasAcc) {
+        dpasCDLayout = LayoutInfo(xegpu::LayoutAttr::get(
+            cTy.getContext(),
+            DenseI32ArrayAttr::get(cTy.getContext(), sgLayoutD),
+            DenseI32ArrayAttr::get(cTy.getContext(), sgDataC),
+            /*inst_data =*/nullptr, /*lane_layout =*/nullptr,
+            /*lane_data =*/nullptr, /*order =*/nullptr));
+      }
     }
     dpas.setLayoutAAttr(
         dyn_cast<xegpu::DistributeLayoutAttr>(dpasALayout.get()));
     dpas.setLayoutBAttr(
         dyn_cast<xegpu::DistributeLayoutAttr>(dpasBLayout.get()));
+    if (hasAcc)
+      dpas.setLayoutCdAttr(
----------------
Jianhui-Li wrote:

Just a comment for a future PR: once we have the layout conflict mechanism set up, dpas should always set the layout for C/D, since layout conflict detection needs to read the DPAS op's output layout and compare it with its uses.

https://github.com/llvm/llvm-project/pull/172867


More information about the Mlir-commits mailing list