[Mlir-commits] [mlir] [MLIR][XeGPU] Add simple rank-based sg layout creation (PR #172867)
Jianhui Li
llvmlistbot at llvm.org
Thu Jan 22 11:13:25 PST 2026
================
@@ -740,44 +825,148 @@ void LayoutInfoPropagation::visitDpasOp(
"No suitable instruction multiple found for the given shape.");
SmallVector<int> instDataA = {maxALen, subgroupSize};
SmallVector<int> instDataB = {subgroupSize, maxBLen};
-
+ SmallVector<int> instDataCD;
+ if (hasAcc) {
+ const unsigned dataCLen = bTy.getShape().back();
+ auto supportedCLen =
+ uArchInstruction->getSupportedN(cTy.getElementType());
+ const int maxCLen =
+ xegpu::getLargestDivisor(dataCLen, ArrayRef<unsigned>(supportedCLen));
+ if (maxCLen == -1) {
+ dpas.emitWarning(
+ "No suitable instruction multiple found for the given shape.");
+ return;
+ }
+ instDataCD = {maxALen, maxCLen};
+ }
if (layoutKind == LayoutKind::InstData) {
dpasALayout =
LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataA));
dpasBLayout =
LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataB));
- } else {
+ if (hasAcc) {
+ dpasCDLayout =
+ LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataCD));
+ }
+ } else if (layoutKind == LayoutKind::Lane) {
dpasALayout = getSIMTLayoutInfoForDPASOperand(
aTy, 0, uArch, uArchInstruction->getPackedFormatBitSizeA());
dpasBLayout = getSIMTLayoutInfoForDPASOperand(
bTy, 1, uArch, uArchInstruction->getPackedFormatBitSizeB());
- }
-
- if (operands.size() > 2) {
- VectorType cTy = dpas.getAccType();
- if (layoutKind == LayoutKind::InstData) {
- const unsigned dataCLen = bTy.getShape().back();
- auto supportedCLen =
- uArchInstruction->getSupportedN(bTy.getElementType());
- const int maxCLen = xegpu::getLargestDivisor(
- dataCLen, ArrayRef<unsigned>(supportedCLen));
- if (maxCLen == -1)
- dpas.emitWarning(
- "No suitable instruction multiple found for the given shape.");
- SmallVector<int> instDataC = {maxALen, maxCLen};
- dpasCDLayout =
- LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataC));
- } else
+ if (hasAcc) {
dpasCDLayout = getSIMTLayoutInfoForDPASOperand(
cTy, 2, uArch, uArchInstruction->getPackedFormatBitSizeB());
+ }
+ } else { // Subgroup
+ auto numSgOrErr = getNumSg(dpas, subgroupSize);
+ if (failed(numSgOrErr)) {
+ dpas.emitWarning(
+ "Unable to determine the number of subgroups for the operation.");
+ return;
+ }
- dpas.setLayoutCdAttr(
- dyn_cast<xegpu::DistributeLayoutAttr>(dpasCDLayout.get()));
+ // Step 1. Get all valid layouts for A, B, and C operands.
+ // All operands must have at least one valid subgroup layout.
+ LayoutInfo layoutD = results[0]->getValue();
+ SmallVector<int> sgLayoutD = layoutD.getSgLayout();
+ assert(!sgLayoutD.empty() && "Expected layout for DPAS result.");
+ auto layoutDVal = std::make_pair(sgLayoutD[0], sgLayoutD[1]);
+
+ auto layoutsA =
+ getValidLayouts(aTy.getShape(), instDataA, numSgOrErr.value());
+ auto layoutsB =
+ getValidLayouts(bTy.getShape(), instDataB, numSgOrErr.value());
+ SmallVector<std::pair<int, int>> layoutsC;
+ if (hasAcc)
+ layoutsC =
+ getValidLayouts(cTy.getShape(), instDataCD, numSgOrErr.value());
+
+ if (layoutsA.empty() || layoutsB.empty() ||
+ (hasAcc && layoutsC.empty())) {
+ dpas.emitWarning(
+ "Unable to determine suitable subgroup layout for A/B/C matrices.");
+ return;
+ }
+
+ // Step 2. If the result D layout can be reused for all operands, that
+ // layout is chosen. Otherwise, pick the most balanced subgroup layout
+ // that is valid for A, B and C (if present) operands
+ llvm::DenseSet<std::pair<int, int>> setA(layoutsA.begin(),
+ layoutsA.end());
+ llvm::DenseSet<std::pair<int, int>> setC;
+ if (hasAcc)
+ setC = llvm::DenseSet<std::pair<int, int>>(layoutsC.begin(),
+ layoutsC.end());
+ std::optional<std::pair<int, int>> bestPick;
+ for (auto &l : layoutsB) {
+ // Is in valid A layouts
+ if (setA.contains(l)) {
+ // Is in valid C layouts
+ if (hasAcc && !setC.contains(l))
+ continue;
+ // Is in (A and B and C) and matches D -> best pick
+ if (l == layoutDVal) {
+ bestPick = l;
+ break;
+ }
+ // Is in (A and B and C), balanced layout comes first
+ if (!bestPick)
+ bestPick = l;
+ }
+ }
+ // Step 3. If there is no subgroup layout compatible with A, B and C (if
+ // present) operands, we fail.
+ SmallVector<int> sgLayoutA;
+ SmallVector<int> sgLayoutB;
+ SmallVector<int> sgLayoutC;
+ if (bestPick) {
+ sgLayoutA = {bestPick->first, bestPick->second};
+ sgLayoutB = sgLayoutA;
+ sgLayoutC = sgLayoutA;
+ } else {
+ dpas.emitWarning("Unable to find common subgroup layout for matrices.");
+ return;
+ }
+ SmallVector<int> sgDataA = {
+ static_cast<int>(aTy.getShape()[0]) / sgLayoutA[0],
+ static_cast<int>(aTy.getShape()[1]) / sgLayoutA[1]};
+ SmallVector<int> sgDataB = {
+ static_cast<int>(bTy.getShape()[0]) / sgLayoutB[0],
+ static_cast<int>(bTy.getShape()[1]) / sgLayoutB[1]};
+ SmallVector<int> sgDataC;
+ if (hasAcc)
+ sgDataC = {static_cast<int>(dpas.getResultType().getShape()[0]) /
+ sgLayoutC[0],
+ static_cast<int>(dpas.getResultType().getShape()[1]) /
+ sgLayoutC[1]};
+
+ dpasALayout = LayoutInfo(xegpu::LayoutAttr::get(
+ aTy.getContext(), DenseI32ArrayAttr::get(aTy.getContext(), sgLayoutA),
+ DenseI32ArrayAttr::get(aTy.getContext(), sgDataA),
+ /*inst_data =*/nullptr, /*lane_layout =*/nullptr,
+ /*lane_data =*/nullptr, /*order =*/nullptr));
+
+ dpasBLayout = LayoutInfo(xegpu::LayoutAttr::get(
+ bTy.getContext(), DenseI32ArrayAttr::get(bTy.getContext(), sgLayoutB),
+ DenseI32ArrayAttr::get(bTy.getContext(), sgDataB),
+ /*inst_data =*/nullptr, /*lane_layout =*/nullptr,
+ /*lane_data =*/nullptr, /*order =*/nullptr));
+ if (hasAcc) {
+ dpasCDLayout = LayoutInfo(xegpu::LayoutAttr::get(
+ cTy.getContext(),
+ DenseI32ArrayAttr::get(cTy.getContext(), sgLayoutD),
+ DenseI32ArrayAttr::get(cTy.getContext(), sgDataC),
+ /*inst_data =*/nullptr, /*lane_layout =*/nullptr,
+ /*lane_data =*/nullptr, /*order =*/nullptr));
+ }
}
dpas.setLayoutAAttr(
dyn_cast<xegpu::DistributeLayoutAttr>(dpasALayout.get()));
dpas.setLayoutBAttr(
dyn_cast<xegpu::DistributeLayoutAttr>(dpasBLayout.get()));
+ if (hasAcc)
+ dpas.setLayoutCdAttr(
----------------
Jianhui-Li wrote:
Just a comment for a future PR: note that once we have the layout conflict mechanism set up, dpas should always set the layout for C/D, since layout conflict detection needs to read the DPAS output layout and compare it with its uses.
https://github.com/llvm/llvm-project/pull/172867
More information about the Mlir-commits
mailing list