[Mlir-commits] [mlir] [MLIR][XeGPU] Use the `setupDpasLayout` utility for dpas layout propagation (PR #180937)
Artem Kroviakov
llvmlistbot at llvm.org
Thu Feb 12 03:04:10 PST 2026
https://github.com/akroviakov updated https://github.com/llvm/llvm-project/pull/180937
>From cb5e7a6d8215dfe9abd31f072def78d7a4692292 Mon Sep 17 00:00:00 2001
From: Artem Kroviakov <artem.kroviakov at intel.com>
Date: Wed, 11 Feb 2026 13:16:22 +0000
Subject: [PATCH 1/5] [MLIR][XeGPU] Use the setupLayout utility for dpas
propagation
---
.../XeGPU/Transforms/XeGPULayoutImpl.h | 8 +
.../XeGPU/Transforms/XeGPULayoutImpl.cpp | 207 ++++++++++++++++-
.../XeGPU/Transforms/XeGPUPropagateLayout.cpp | 219 +++---------------
3 files changed, 245 insertions(+), 189 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h b/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h
index 182607c22c584..0d5210f07f05a 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h
@@ -161,6 +161,14 @@ DistributeLayoutAttr setupStoreMatrixAnchorLayout(LayoutKind layoutKind,
VectorType vectorTy,
const uArch::uArch *uArch);
+/// Sets up the anchor layouts for dpas operands (A, B, and C/D).
+/// The numSg and consumerLayout (optional) are only used by sg layout creation.
+std::optional<std::tuple<DistributeLayoutAttr, DistributeLayoutAttr,
+ DistributeLayoutAttr>>
+setupDpasLayout(LayoutKind layoutKind, VectorType aTy, VectorType bTy,
+ VectorType cdTy, DistributeLayoutAttr consumerLayout,
+ const uArch::uArch *uArch, int numSg);
+
} // namespace xegpu
} // namespace mlir
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
index a4e47fca96d34..83f0180961e65 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
@@ -852,4 +852,209 @@ xegpu::setupStoreMatrixAnchorLayout(xegpu::LayoutKind layoutKind,
return setupGenericStoreAnchorLayout(layoutKind, context, false, maxChunkSize,
srcShape, subgroupSize);
-}
\ No newline at end of file
+}
+
+// This function returns the default lane layout for a given vector type.
+// - `packingSize` means multiple consecutive elements can be accessed together
+// as a single unit.
+// - `vnni` means data packing is column-wise (i.e., 2x1xf16 with vnni vs.
+// 1x2xf16 w/o vnni).
+template <typename RankedTy>
+static xegpu::LayoutAttr
+getDefaultLaneLayout(RankedTy ty, const xegpu::uArch::uArch *uArch,
+ std::optional<unsigned> packingSize = std::nullopt,
+ bool vnni = false) {
+ // Expecting a 1D or 2D vector.
+ assert(((ty.getRank() == 1 && !vnni) || ty.getRank() == 2) &&
+ "Expected 1D non-vnni or 2D vector.");
+ // Expecting int or float element type.
+ assert(ty.getElementType().isIntOrFloat() &&
+ "Expected int or float element type.");
+
+ auto context = ty.getContext();
+ auto rank = ty.getRank();
+ SmallVector<int> laneLayout(rank, 1);
+ SmallVector<int> laneData(rank, 1);
+ if (packingSize.has_value()) {
+ unsigned bitwidth = ty.getElementType().getIntOrFloatBitWidth();
+ int &laneDataPos = vnni ? laneData[rank - 2] : laneData.back();
+ laneDataPos = bitwidth < *packingSize ? *packingSize / bitwidth : 1;
+ }
+ laneLayout.back() = uArch->getSubgroupSize();
+ return xegpu::LayoutAttr::get(context, laneLayout, laneData);
+}
+
+// This function returns all layouts for the given sgCount, whose sgData:
+// 1. Evenly divides the wgShape.
+// 2. Is a multiple of instData.
+// Example:
+// wgShape = [128, 64], instData = [8, 16], sgCount = 32
+// Returns layouts:
+// [(8,4), (16,2)], which correspond to sgData [16,16] and [8,32].
+using LayoutRepresentation = std::pair<int64_t, int64_t>;
+SmallVector<LayoutRepresentation> getValidLayouts(ArrayRef<int64_t> wgShape,
+ ArrayRef<int64_t> instData,
+ int64_t sgCount) {
+ SmallVector<LayoutRepresentation> candidates;
+ for (int sgLayout0 = 1; sgLayout0 <= sgCount; ++sgLayout0) {
+ if (sgCount % sgLayout0)
+ continue;
+ int64_t sgLayout1 = sgCount / sgLayout0;
+ int64_t sgData0 = wgShape[0] / sgLayout0;
+ int64_t sgData1 = wgShape[1] / sgLayout1;
+ if ((wgShape[0] % sgLayout0 || wgShape[1] % sgLayout1) ||
+ (sgData0 % instData[0] || sgData1 % instData[1]))
+ continue;
+ candidates.emplace_back(sgLayout0, sgLayout1);
+ }
+ // Sort primarily by how balanced they are
+ // (i.e., minimize the absolute difference between the two dimensions), and
+ // secondarily by the first dimension in ascending order.
+ llvm::sort(candidates, [](const LayoutRepresentation &lhs,
+ const LayoutRepresentation &rhs) {
+ int diffLhs = std::abs(lhs.first - lhs.second);
+ int diffRhs = std::abs(rhs.first - rhs.second);
+ if (diffLhs != diffRhs)
+ return diffLhs < diffRhs;
+ return lhs.first < rhs.first;
+ });
+ return candidates;
+}
+
+/// Sets up the anchor layouts for a dpas operands (A, B, and C/D).
+/// The numSg and consumerLayout (optional) are only used by sg layout creation.
+std::optional<
+ std::tuple<xegpu::DistributeLayoutAttr, xegpu::DistributeLayoutAttr,
+ xegpu::DistributeLayoutAttr>>
+xegpu::setupDpasLayout(xegpu::LayoutKind layoutKind, VectorType aTy,
+ VectorType bTy, VectorType cdTy,
+ xegpu::DistributeLayoutAttr consumerLayout,
+ const xegpu::uArch::uArch *uArch, int numSg) {
+ auto context = aTy.getContext();
+ const auto *uArchInstruction =
+ dyn_cast<xegpu::uArch::SubgroupMatrixMultiplyAcc>(uArch->getInstruction(
+ xegpu::uArch::InstructionKind::SubgroupMatrixMultiplyAcc));
+
+ if (layoutKind == xegpu::LayoutKind::Subgroup) {
+ assert(numSg > 0 &&
+ "Number of subgroups must be provided for sg layout creation.");
+ auto instDataLayouts =
+ xegpu::setupDpasLayout(xegpu::LayoutKind::InstData, aTy, bTy, cdTy,
+ consumerLayout, uArch, numSg);
+ if (!instDataLayouts)
+ return std::nullopt;
+ auto [instLayoutA, instLayoutB, instLayoutCD] = *instDataLayouts;
+ SmallVector<int64_t> instDataA = instLayoutA.getEffectiveInstDataAsInt();
+ SmallVector<int64_t> instDataB = instLayoutB.getEffectiveInstDataAsInt();
+ SmallVector<int64_t> instDataCD = instLayoutCD.getEffectiveInstDataAsInt();
+ assert(instDataA.size() == 2 && instDataB.size() == 2 &&
+ instDataCD.size() == 2 &&
+ "Sg layout creation expects valid 2D inst data");
+
+ std::optional<LayoutRepresentation> layoutDVal = std::nullopt;
+ if (consumerLayout) {
+ SmallVector<int64_t> sgLayoutD =
+ consumerLayout.getEffectiveSgLayoutAsInt();
+ layoutDVal = std::make_pair(sgLayoutD[0], sgLayoutD[1]);
+ }
+
+ auto layoutsA = getValidLayouts(aTy.getShape(), instDataA, numSg);
+ auto layoutsB = getValidLayouts(bTy.getShape(), instDataB, numSg);
+ auto layoutsCD = getValidLayouts(cdTy.getShape(), instDataCD, numSg);
+ if (layoutsA.empty() || layoutsB.empty() || layoutsCD.empty())
+ return std::nullopt;
+
+ // Step 2. If the result D layout can be reused for all operands, that
+ // layout is chosen. Otherwise, pick the most balanced subgroup layout
+ // that is valid for A, B and C (if present) operands
+ llvm::DenseSet<LayoutRepresentation> setA(layoutsA.begin(), layoutsA.end());
+ llvm::DenseSet<LayoutRepresentation> setCD(layoutsCD.begin(),
+ layoutsCD.end());
+ std::optional<LayoutRepresentation> bestPick;
+ for (auto &l : layoutsB) {
+ if (setA.contains(l) && setCD.contains(l)) {
+ // Is in (A and B and CD) and matches consumer -> best pick
+ if (layoutDVal.has_value() && l == *layoutDVal) {
+ bestPick = l;
+ break;
+ }
+ // Is in (A and B and CD), balanced layout comes first
+ if (!bestPick)
+ bestPick = l;
+ }
+ }
+ // Step 3. If there is no subgroup layout compatible with A, B and C (if
+ // present) operands, we fail.
+ if (!bestPick)
+ return std::nullopt;
+ SmallVector<int> sgLayout = {static_cast<int>(bestPick->first),
+ static_cast<int>(bestPick->second)};
+ SmallVector<int> sgDataA = {
+ static_cast<int>(aTy.getShape()[0] / sgLayout[0]),
+ static_cast<int>(aTy.getShape()[1] / sgLayout[1])};
+ SmallVector<int> sgDataB = {
+ static_cast<int>(bTy.getShape()[0] / sgLayout[0]),
+ static_cast<int>(bTy.getShape()[1] / sgLayout[1])};
+ SmallVector<int> sgDataCD = {
+ static_cast<int>(cdTy.getShape()[0] / sgLayout[0]),
+ static_cast<int>(cdTy.getShape()[1] / sgLayout[1])};
+
+ auto dpasALayout = xegpu::LayoutAttr::get(
+ context, DenseI32ArrayAttr::get(context, sgLayout),
+ DenseI32ArrayAttr::get(context, sgDataA),
+ /*inst_data =*/nullptr, /*lane_layout =*/nullptr,
+ /*lane_data =*/nullptr, /*order =*/nullptr);
+
+ auto dpasBLayout = xegpu::LayoutAttr::get(
+ context, DenseI32ArrayAttr::get(context, sgLayout),
+ DenseI32ArrayAttr::get(context, sgDataB),
+ /*inst_data =*/nullptr, /*lane_layout =*/nullptr,
+ /*lane_data =*/nullptr, /*order =*/nullptr);
+
+ auto dpasCDLayout = xegpu::LayoutAttr::get(
+ context, DenseI32ArrayAttr::get(context, sgLayout),
+ DenseI32ArrayAttr::get(context, sgDataCD),
+ /*inst_data =*/nullptr, /*lane_layout =*/nullptr,
+ /*lane_data =*/nullptr, /*order =*/nullptr);
+ return std::make_tuple(dpasALayout, dpasBLayout, dpasCDLayout);
+ } else if (layoutKind == xegpu::LayoutKind::InstData) {
+ const int subgroupSize = uArch->getSubgroupSize();
+ const unsigned dataALen = aTy.getShape().front();
+ auto supportedALen = uArchInstruction->getSupportedM(aTy.getElementType());
+ const int maxALen =
+ xegpu::getLargestDivisor(dataALen, ArrayRef<unsigned>(supportedALen));
+
+ const unsigned dataBLen = bTy.getShape().back();
+ auto supportedBLen = uArchInstruction->getSupportedN(bTy.getElementType());
+ const int maxBLen =
+ xegpu::getLargestDivisor(dataBLen, ArrayRef<unsigned>(supportedBLen));
+
+ auto supportedCLen = uArchInstruction->getSupportedN(cdTy.getElementType());
+ const int maxCLen =
+ xegpu::getLargestDivisor(dataBLen, ArrayRef<unsigned>(supportedCLen));
+ if (maxALen == -1 || maxBLen == -1 || maxCLen == -1)
+ return std::nullopt;
+
+ SmallVector<int> instDataA(aTy.getRank(), 1);
+ instDataA[aTy.getRank() - 2] = maxALen;
+ instDataA[aTy.getRank() - 1] = subgroupSize;
+ SmallVector<int> instDataB(bTy.getRank(), 1);
+ instDataB[bTy.getRank() - 2] = subgroupSize;
+ instDataB[bTy.getRank() - 1] = maxBLen;
+ SmallVector<int> instDataCD(cdTy.getRank(), 1);
+ instDataCD[cdTy.getRank() - 2] = maxALen;
+ instDataCD[cdTy.getRank() - 1] = maxCLen;
+ return std::make_tuple(xegpu::LayoutAttr::get(context, instDataA),
+ xegpu::LayoutAttr::get(context, instDataB),
+ xegpu::LayoutAttr::get(context, instDataCD));
+ } else if (layoutKind == xegpu::LayoutKind::Lane) {
+ auto aLayout = getDefaultLaneLayout(
+ aTy, uArch, uArchInstruction->getPackedFormatBitSizeA());
+ auto bLayout = getDefaultLaneLayout(
+ bTy, uArch, uArchInstruction->getPackedFormatBitSizeB(), true);
+ auto cdLayout = getDefaultLaneLayout(
+ cdTy, uArch, uArchInstruction->getPackedFormatBitSizeB());
+ return std::make_tuple(aLayout, bLayout, cdLayout);
+ }
+ return std::nullopt;
+}
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index ccfab7350e351..907a2c9cced90 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -308,33 +308,6 @@ static LayoutInfo getSIMTLayoutInfoBlockIO(Ty ty,
ty.getContext(), {1, uArch->getSubgroupSize()}, {1, packingFactor}));
}
-/// Helper Function to get the expected layouts for DPAS operands. `lane_data`
-/// is set according to the following criteria:
-/// * For A operand, the data must be packed in minimum
-/// `packedSizeInBitsForDefault`
-/// * For B operand, the data must be packed in minimum
-/// `packedSizeInBitsForDpasB`
-static LayoutInfo
-getSIMTLayoutInfoForDPASOperand(VectorType vectorTy, unsigned operandNum,
- const xegpu::uArch::uArch *uArch,
- unsigned packingSize) {
- Type elementTy = vectorTy.getElementType();
- assert(elementTy.isIntOrFloat() &&
- "Expected int or float type in DPAS operands");
- SmallVector<int32_t, 2> layout({1, uArch->getSubgroupSize()});
- // For B operand, data must be packed in minimum `packedDpasBSizeInBits` and
- // must have the VNNI format.
- if (operandNum == 1 && elementTy.getIntOrFloatBitWidth() < packingSize) {
- SmallVector<int32_t, 2> data(
- {static_cast<int32_t>(packingSize / elementTy.getIntOrFloatBitWidth()),
- 1});
- return LayoutInfo(
- xegpu::LayoutAttr::get(vectorTy.getContext(), layout, data));
- }
- // Otherwise, return the default layout for the vector type.
- return getSIMTLayoutInfoBlockIO(vectorTy, uArch, packingSize);
-}
-
//===----------------------------------------------------------------------===//
// LayoutInfoPropagation
//===----------------------------------------------------------------------===//
@@ -765,180 +738,50 @@ void LayoutInfoPropagation::visitDpasOp(
dpasBLayout = LayoutInfo(anchorLayoutB);
dpasCDLayout = LayoutInfo(anchorLayoutCD);
} else {
+ auto uArch = getUArch(getChipStr(dpas).value_or(""));
VectorType aTy = dpas.getLhsType();
VectorType bTy = dpas.getRhsType();
- VectorType cTy;
- const bool hasAcc = operands.size() > 2;
- if (hasAcc)
- cTy = dpas.getAccType();
-
- auto uArch = getUArch(getChipStr(dpas).value_or(""));
- const int subgroupSize = uArch->getSubgroupSize();
- const auto *uArchInstruction =
- dyn_cast<xegpu::uArch::SubgroupMatrixMultiplyAcc>(uArch->getInstruction(
- xegpu::uArch::InstructionKind::SubgroupMatrixMultiplyAcc));
-
- const unsigned dataALen = aTy.getShape().front();
- auto supportedALen = uArchInstruction->getSupportedM(aTy.getElementType());
- const int maxALen =
- xegpu::getLargestDivisor(dataALen, ArrayRef<unsigned>(supportedALen));
- if (maxALen == -1)
- dpas.emitWarning(
- "No suitable instruction multiple found for the given shape.");
-
- const unsigned dataBLen = bTy.getShape().back();
- auto supportedBLen = uArchInstruction->getSupportedN(bTy.getElementType());
-
- const int maxBLen =
- xegpu::getLargestDivisor(dataBLen, ArrayRef<unsigned>(supportedBLen));
-
- if (maxBLen == -1)
- dpas.emitWarning(
- "No suitable instruction multiple found for the given shape.");
- SmallVector<int> instDataA = {maxALen, subgroupSize};
- SmallVector<int> instDataB = {subgroupSize, maxBLen};
- SmallVector<int> instDataCD;
- if (hasAcc) {
- const unsigned dataCLen = bTy.getShape().back();
- auto supportedCLen =
- uArchInstruction->getSupportedN(cTy.getElementType());
- const int maxCLen =
- xegpu::getLargestDivisor(dataCLen, ArrayRef<unsigned>(supportedCLen));
- if (maxCLen == -1) {
- dpas.emitWarning(
- "No suitable instruction multiple found for the given shape.");
- return;
- }
- instDataCD = {maxALen, maxCLen};
- }
- if (layoutKind == xegpu::LayoutKind::InstData) {
- dpasALayout =
- LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataA));
- dpasBLayout =
- LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataB));
- if (hasAcc) {
- dpasCDLayout =
- LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataCD));
- }
- } else if (layoutKind == xegpu::LayoutKind::Lane) {
- dpasALayout = getSIMTLayoutInfoForDPASOperand(
- aTy, 0, uArch, uArchInstruction->getPackedFormatBitSizeA());
- dpasBLayout = getSIMTLayoutInfoForDPASOperand(
- bTy, 1, uArch, uArchInstruction->getPackedFormatBitSizeB());
- if (hasAcc) {
- dpasCDLayout = getSIMTLayoutInfoForDPASOperand(
- cTy, 2, uArch, uArchInstruction->getPackedFormatBitSizeB());
- }
- } else { // Subgroup
- auto numSgOrErr = getNumSg(dpas, subgroupSize);
+ VectorType cdTy = dpas.getResultType();
+ LayoutInfo consumerLayout = results[0]->getValue();
+ if (!consumerLayout.isAssigned())
+ dpas.emitWarning("No consumer layout was found for the DPAS result.");
+
+ auto consumerLayoutAttr =
+ dyn_cast_if_present<xegpu::DistributeLayoutAttr>(consumerLayout.get());
+ xegpu::DistributeLayoutAttr requiredCDLayoutAttr, requiredALayout,
+ requiredBLayout;
+
+ int numSg = 0;
+ if (layoutKind == xegpu::LayoutKind::Subgroup) {
+ auto numSgOrErr = getNumSg(dpas, uArch->getSubgroupSize());
if (failed(numSgOrErr)) {
dpas.emitWarning(
"Unable to determine the number of subgroups for the operation.");
return;
}
+ numSg = numSgOrErr.value();
+ }
+ auto layouts = xegpu::setupDpasLayout(layoutKind, aTy, bTy, cdTy,
+ consumerLayoutAttr, uArch, numSg);
+ if (!layouts.has_value()) {
+ dpas.emitWarning(
+ "Failed to determine required layouts for DPAS operands.");
+ return;
+ }
- // Step 1. Get all valid layouts for A, B, and C operands.
- // All operands must have at least one valid subgroup layout.
- LayoutInfo layoutD = results[0]->getValue();
- SmallVector<int> sgLayoutD = layoutD.getSgLayout();
- assert(!sgLayoutD.empty() && "Expected layout for DPAS result.");
- auto layoutDVal = std::make_pair(sgLayoutD[0], sgLayoutD[1]);
-
- auto layoutsA =
- getValidLayouts(aTy.getShape(), instDataA, numSgOrErr.value());
- auto layoutsB =
- getValidLayouts(bTy.getShape(), instDataB, numSgOrErr.value());
- SmallVector<std::pair<int, int>> layoutsC;
- if (hasAcc)
- layoutsC =
- getValidLayouts(cTy.getShape(), instDataCD, numSgOrErr.value());
-
- if (layoutsA.empty() || layoutsB.empty() ||
- (hasAcc && layoutsC.empty())) {
- dpas.emitWarning(
- "Unable to determine suitable subgroup layout for A/B/C matrices.");
- return;
- }
-
- // Step 2. If the result D layout can be reused for all operands, that
- // layout is chosen. Otherwise, pick the most balanced subgroup layout
- // that is valid for A, B and C (if present) operands
- llvm::DenseSet<std::pair<int, int>> setA(layoutsA.begin(),
- layoutsA.end());
- llvm::DenseSet<std::pair<int, int>> setC;
- if (hasAcc)
- setC = llvm::DenseSet<std::pair<int, int>>(layoutsC.begin(),
- layoutsC.end());
- std::optional<std::pair<int, int>> bestPick;
- for (auto &l : layoutsB) {
- if (setA.contains(l)) {
- if (hasAcc && !setC.contains(l))
- continue;
- // Is in (A and B and C) and matches D -> best pick
- if (l == layoutDVal) {
- bestPick = l;
- break;
- }
- // Is in (A and B and C), balanced layout comes first
- if (!bestPick)
- bestPick = l;
- }
- }
- // Step 3. If there is no subgroup layout compatible with A, B and C (if
- // present) operands, we fail.
- SmallVector<int> sgLayout;
- if (bestPick) {
- sgLayout = {bestPick->first, bestPick->second};
- } else {
- dpas.emitWarning("Unable to find common subgroup layout for matrices.");
- return;
- }
- SmallVector<int> sgDataA = {
- static_cast<int>(aTy.getShape()[0]) / sgLayout[0],
- static_cast<int>(aTy.getShape()[1]) / sgLayout[1]};
- SmallVector<int> sgDataB = {
- static_cast<int>(bTy.getShape()[0]) / sgLayout[0],
- static_cast<int>(bTy.getShape()[1]) / sgLayout[1]};
- SmallVector<int> sgDataC;
- if (hasAcc)
- sgDataC = {
- static_cast<int>(dpas.getResultType().getShape()[0]) / sgLayout[0],
- static_cast<int>(dpas.getResultType().getShape()[1]) / sgLayout[1]};
-
- dpasALayout = LayoutInfo(xegpu::LayoutAttr::get(
- aTy.getContext(), DenseI32ArrayAttr::get(aTy.getContext(), sgLayout),
- DenseI32ArrayAttr::get(aTy.getContext(), sgDataA),
- /*inst_data =*/nullptr, /*lane_layout =*/nullptr,
- /*lane_data =*/nullptr, /*order =*/nullptr));
+ std::tie(requiredALayout, requiredBLayout, requiredCDLayoutAttr) = *layouts;
- dpasBLayout = LayoutInfo(xegpu::LayoutAttr::get(
- bTy.getContext(), DenseI32ArrayAttr::get(bTy.getContext(), sgLayout),
- DenseI32ArrayAttr::get(bTy.getContext(), sgDataB),
- /*inst_data =*/nullptr, /*lane_layout =*/nullptr,
- /*lane_data =*/nullptr, /*order =*/nullptr));
- if (hasAcc) {
- dpasCDLayout = LayoutInfo(xegpu::LayoutAttr::get(
- cTy.getContext(),
- DenseI32ArrayAttr::get(cTy.getContext(), sgLayout),
- DenseI32ArrayAttr::get(cTy.getContext(), sgDataC),
- /*inst_data =*/nullptr, /*lane_layout =*/nullptr,
- /*lane_data =*/nullptr, /*order =*/nullptr));
- }
- }
- dpas.setLayoutAAttr(
- dyn_cast<xegpu::DistributeLayoutAttr>(dpasALayout.get()));
- dpas.setLayoutBAttr(
- dyn_cast<xegpu::DistributeLayoutAttr>(dpasBLayout.get()));
- if (hasAcc)
- dpas.setLayoutCdAttr(
- dyn_cast<xegpu::DistributeLayoutAttr>(dpasCDLayout.get()));
+ dpas.setLayoutAAttr(requiredALayout);
+ dpas.setLayoutBAttr(requiredBLayout);
+ dpas.setLayoutCdAttr(requiredCDLayoutAttr);
+ dpasALayout = LayoutInfo(requiredALayout);
+ dpasBLayout = LayoutInfo(requiredBLayout);
+ dpasCDLayout = LayoutInfo(requiredCDLayoutAttr);
}
-
propagateIfChanged(operands[0], operands[0]->meet(dpasALayout));
propagateIfChanged(operands[1], operands[1]->meet(dpasBLayout));
- if (operands.size() > 2) {
+ if (operands.size() > 2)
propagateIfChanged(operands[2], operands[2]->meet(dpasCDLayout));
- }
}
/// Set the layout for the value and tensor descriptor operands in StoreNdOp.
>From 5e508e4b374267b7be50b6d9bfcd9bd79f4e7831 Mon Sep 17 00:00:00 2001
From: Artem Kroviakov <artem.kroviakov at intel.com>
Date: Wed, 11 Feb 2026 14:24:15 +0000
Subject: [PATCH 2/5] Revisit dpas later when the layout is assigned.
---
mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp | 7 ++++---
1 file changed, 4 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 907a2c9cced90..711b7247932c0 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -738,13 +738,14 @@ void LayoutInfoPropagation::visitDpasOp(
dpasBLayout = LayoutInfo(anchorLayoutB);
dpasCDLayout = LayoutInfo(anchorLayoutCD);
} else {
+ LayoutInfo consumerLayout = results[0]->getValue();
+ if (!consumerLayout.isAssigned())
+ return;
+
auto uArch = getUArch(getChipStr(dpas).value_or(""));
VectorType aTy = dpas.getLhsType();
VectorType bTy = dpas.getRhsType();
VectorType cdTy = dpas.getResultType();
- LayoutInfo consumerLayout = results[0]->getValue();
- if (!consumerLayout.isAssigned())
- dpas.emitWarning("No consumer layout was found for the DPAS result.");
auto consumerLayoutAttr =
dyn_cast_if_present<xegpu::DistributeLayoutAttr>(consumerLayout.get());
>From 8f9be189c1ce0305e10557f5a1b41e0b35daf114 Mon Sep 17 00:00:00 2001
From: Artem Kroviakov <artem.kroviakov at intel.com>
Date: Wed, 11 Feb 2026 15:42:58 +0000
Subject: [PATCH 3/5] Add dpas loop test
---
.../XeGPU/propagate-layout-subgroup.mlir | 51 ++++++++++++++++++-
1 file changed, 50 insertions(+), 1 deletion(-)
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir b/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir
index 190b54912488f..e987a1c9725ca 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir
@@ -163,4 +163,53 @@ gpu.module @test {
xegpu.store %reduce, %dst[%offset], %mask {layout = #xegpu.slice<#xegpu.layout<sg_layout=[4, 8], sg_data=[8, 16]>, dims = [1]>} : vector<32xf32>, memref<32xf32>, vector<32xindex>, vector<32xi1>
gpu.return
}
-}
\ No newline at end of file
+}
+
+// -----
+gpu.module @test {
+ // CHECK-LABEL: for_loop_dpas
+ gpu.func @for_loop_dpas(%arg0: memref<2048x8192xf16>, %arg1: memref<8192x4096xf16>, %arg2: memref<2048x4096xf32>) kernel attributes {known_block_size = array<i32: 8, 1, 16>} {
+ %cst = arith.constant dense<0.000000e+00> : vector<128x128xf32>
+ %c128 = arith.constant 128 : index
+ %c8192 = arith.constant 8192 : index
+ %c0 = arith.constant 0 : index
+ %block_id_x = gpu.block_id x
+ %block_id_y = gpu.block_id y
+ %0 = affine.apply affine_map<()[s0] -> (s0 * 128)>()[%block_id_x]
+ %1 = affine.apply affine_map<()[s0] -> (s0 * 128)>()[%block_id_y]
+ // CHECK: %2 = scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%{{.*}} = %{{.*}}) -> (vector<128x128xf32>) {
+ // CHECK-NEXT: xegpu.create_nd_tdesc %{{.*}} : memref<2048x8192xf16> ->
+ // CHECK-SAME: !xegpu.tensor_desc<128x128xf16, #xegpu.block_tdesc_attr<boundary_check = false>,
+ // CHECK-SAME: #xegpu.layout<sg_layout = [2, 4], sg_data = [64, 32]>>
+
+ // CHECK-NEXT: xegpu.load_nd %{{.*}} <{layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [64, 32]>}>
+
+ // CHECK-NEXT: xegpu.create_nd_tdesc %{{.*}} : memref<8192x4096xf16> ->
+ // CHECK-SAME: !xegpu.tensor_desc<128x128xf16, #xegpu.block_tdesc_attr<boundary_check = false>,
+ // CHECK-SAME: #xegpu.layout<sg_layout = [2, 4], sg_data = [64, 32]>>
+
+ // CHECK-NEXT: xegpu.load_nd %6[%arg3, %block_id_y] <{layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [64, 32]>}>
+
+ // CHECK-NEXT: xegpu.dpas %{{.*}} {
+ // CHECK-SAME: layout_a = #xegpu.layout<sg_layout = [2, 4], sg_data = [64, 32]>,
+ // CHECK-SAME: layout_b = #xegpu.layout<sg_layout = [2, 4], sg_data = [64, 32]>,
+ // CHECK-SAME: layout_cd = #xegpu.layout<sg_layout = [2, 4], sg_data = [64, 32]>}
+ // CHECK-SAME: : vector<128x128xf16>, vector<128x128xf16>, vector<128x128xf32> -> vector<128x128xf32>
+
+ // CHECK-NEXT: scf.yield %{{.*}} : vector<128x128xf32>
+ // CHECK-NEXT: } {layout_result_0 = #xegpu.layout<sg_layout = [2, 4], sg_data = [64, 32]>}
+ // CHECK: xegpu.store_nd %{{.*}} <{layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [64, 32]>}>
+
+ %2 = scf.for %arg3 = %c0 to %c8192 step %c128 iter_args(%arg4 = %cst) -> (vector<128x128xf32>) {
+ %4 = xegpu.create_nd_tdesc %arg0 : memref<2048x8192xf16> -> !xegpu.tensor_desc<128x128xf16, #xegpu.block_tdesc_attr<boundary_check = false>>
+ %5 = xegpu.load_nd %4[%block_id_x, %arg3] : !xegpu.tensor_desc<128x128xf16, #xegpu.block_tdesc_attr<boundary_check = false>> -> vector<128x128xf16>
+ %6 = xegpu.create_nd_tdesc %arg1 : memref<8192x4096xf16> -> !xegpu.tensor_desc<128x128xf16, #xegpu.block_tdesc_attr<boundary_check = false>>
+ %7 = xegpu.load_nd %6[%arg3, %block_id_y] : !xegpu.tensor_desc<128x128xf16, #xegpu.block_tdesc_attr<boundary_check = false>> -> vector<128x128xf16>
+ %8 = xegpu.dpas %5, %7, %arg4 : vector<128x128xf16>, vector<128x128xf16>, vector<128x128xf32> -> vector<128x128xf32>
+ scf.yield %8 : vector<128x128xf32>
+ }
+ %3 = xegpu.create_nd_tdesc %arg2 : memref<2048x4096xf32> -> !xegpu.tensor_desc<128x128xf32, #xegpu.block_tdesc_attr<boundary_check = false>>
+ xegpu.store_nd %2, %3[%0, %1] : vector<128x128xf32>, !xegpu.tensor_desc<128x128xf32, #xegpu.block_tdesc_attr<boundary_check = false>>
+ gpu.return
+ }
+}
>From 7b02102e0344e52ffa64bda4effff24d107244c8 Mon Sep 17 00:00:00 2001
From: Artem Kroviakov <artem.kroviakov at intel.com>
Date: Wed, 11 Feb 2026 15:57:30 +0000
Subject: [PATCH 4/5] Only care about consumer for the sg layout
---
.../XeGPU/Transforms/XeGPUPropagateLayout.cpp | 12 ++++++------
1 file changed, 6 insertions(+), 6 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 711b7247932c0..bc309c9029878 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -738,22 +738,22 @@ void LayoutInfoPropagation::visitDpasOp(
dpasBLayout = LayoutInfo(anchorLayoutB);
dpasCDLayout = LayoutInfo(anchorLayoutCD);
} else {
- LayoutInfo consumerLayout = results[0]->getValue();
- if (!consumerLayout.isAssigned())
- return;
-
auto uArch = getUArch(getChipStr(dpas).value_or(""));
VectorType aTy = dpas.getLhsType();
VectorType bTy = dpas.getRhsType();
VectorType cdTy = dpas.getResultType();
- auto consumerLayoutAttr =
- dyn_cast_if_present<xegpu::DistributeLayoutAttr>(consumerLayout.get());
+ xegpu::DistributeLayoutAttr consumerLayoutAttr = nullptr;
xegpu::DistributeLayoutAttr requiredCDLayoutAttr, requiredALayout,
requiredBLayout;
int numSg = 0;
if (layoutKind == xegpu::LayoutKind::Subgroup) {
+ LayoutInfo consumerLayout = results[0]->getValue();
+ if (!consumerLayout.isAssigned())
+ return;
+ consumerLayoutAttr =
+ dyn_cast<xegpu::DistributeLayoutAttr>(consumerLayout.get());
auto numSgOrErr = getNumSg(dpas, uArch->getSubgroupSize());
if (failed(numSgOrErr)) {
dpas.emitWarning(
>From 05f69ca25245b69e339be081d1cfa436f9e0c834 Mon Sep 17 00:00:00 2001
From: Artem Kroviakov <artem.kroviakov at intel.com>
Date: Thu, 12 Feb 2026 11:03:49 +0000
Subject: [PATCH 5/5] Feedback
---
.../XeGPU/Transforms/XeGPULayoutImpl.cpp | 127 ++++++++++--------
1 file changed, 70 insertions(+), 57 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
index 83f0180961e65..ca62b2893eb68 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
@@ -860,10 +860,9 @@ xegpu::setupStoreMatrixAnchorLayout(xegpu::LayoutKind layoutKind,
// - `vnni` means data packing is column-wise (i.e., 2x1xf16 with vnni vs.
// 1x2xf16 w/o vnni).
template <typename RankedTy>
-static xegpu::LayoutAttr
-getDefaultLaneLayout(RankedTy ty, const xegpu::uArch::uArch *uArch,
- std::optional<unsigned> packingSize = std::nullopt,
- bool vnni = false) {
+static xegpu::LayoutAttr getDefaultLaneLayout2DBlockIo(
+ RankedTy ty, const xegpu::uArch::uArch *uArch,
+ std::optional<unsigned> packingSize = std::nullopt, bool vnni = false) {
// Expecting a 1D or 2D vector.
assert(((ty.getRank() == 1 && !vnni) || ty.getRank() == 2) &&
"Expected 1D non-vnni or 2D vector.");
@@ -892,9 +891,9 @@ getDefaultLaneLayout(RankedTy ty, const xegpu::uArch::uArch *uArch,
// Returns layouts:
// [(8,4), (16,2)], which correspond to sgData [16,16] and [8,32].
using LayoutRepresentation = std::pair<int64_t, int64_t>;
-SmallVector<LayoutRepresentation> getValidLayouts(ArrayRef<int64_t> wgShape,
- ArrayRef<int64_t> instData,
- int64_t sgCount) {
+static SmallVector<LayoutRepresentation>
+getValidLayouts(ArrayRef<int64_t> wgShape, ArrayRef<int64_t> instData,
+ int64_t sgCount) {
SmallVector<LayoutRepresentation> candidates;
for (int sgLayout0 = 1; sgLayout0 <= sgCount; ++sgLayout0) {
if (sgCount % sgLayout0)
@@ -921,7 +920,7 @@ SmallVector<LayoutRepresentation> getValidLayouts(ArrayRef<int64_t> wgShape,
return candidates;
}
-/// Sets up the anchor layouts for a dpas operands (A, B, and C/D).
+/// Sets up the anchor layouts for dpas operands (A, B, and C/D).
/// The numSg and consumerLayout (optional) are only used by sg layout creation.
std::optional<
std::tuple<xegpu::DistributeLayoutAttr, xegpu::DistributeLayoutAttr,
@@ -935,52 +934,84 @@ xegpu::setupDpasLayout(xegpu::LayoutKind layoutKind, VectorType aTy,
dyn_cast<xegpu::uArch::SubgroupMatrixMultiplyAcc>(uArch->getInstruction(
xegpu::uArch::InstructionKind::SubgroupMatrixMultiplyAcc));
+ auto getInstDataVectors = [&]()
+ -> std::optional<std::tuple<SmallVector<int64_t>, SmallVector<int64_t>,
+ SmallVector<int64_t>>> {
+ const int subgroupSize = uArch->getSubgroupSize();
+ const unsigned dataALen = aTy.getShape().front();
+ auto supportedALen = uArchInstruction->getSupportedM(aTy.getElementType());
+ const int maxALen =
+ xegpu::getLargestDivisor(dataALen, ArrayRef<unsigned>(supportedALen));
+
+ const unsigned dataBLen = bTy.getShape().back();
+ auto supportedBLen = uArchInstruction->getSupportedN(bTy.getElementType());
+ const int maxBLen =
+ xegpu::getLargestDivisor(dataBLen, ArrayRef<unsigned>(supportedBLen));
+
+ auto supportedCLen = uArchInstruction->getSupportedN(cdTy.getElementType());
+ const int maxCLen =
+ xegpu::getLargestDivisor(dataBLen, ArrayRef<unsigned>(supportedCLen));
+ if (maxALen == -1 || maxBLen == -1 || maxCLen == -1)
+ return std::nullopt;
+
+ SmallVector<int64_t> instDataA(aTy.getRank(), 1);
+ instDataA[aTy.getRank() - 2] = maxALen;
+ instDataA[aTy.getRank() - 1] = subgroupSize;
+ SmallVector<int64_t> instDataB(bTy.getRank(), 1);
+ instDataB[bTy.getRank() - 2] = subgroupSize;
+ instDataB[bTy.getRank() - 1] = maxBLen;
+ SmallVector<int64_t> instDataCD(cdTy.getRank(), 1);
+ instDataCD[cdTy.getRank() - 2] = maxALen;
+ instDataCD[cdTy.getRank() - 1] = maxCLen;
+ return std::make_tuple(instDataA, instDataB, instDataCD);
+ };
+
if (layoutKind == xegpu::LayoutKind::Subgroup) {
assert(numSg > 0 &&
"Number of subgroups must be provided for sg layout creation.");
- auto instDataLayouts =
- xegpu::setupDpasLayout(xegpu::LayoutKind::InstData, aTy, bTy, cdTy,
- consumerLayout, uArch, numSg);
- if (!instDataLayouts)
+ auto instDataVecs = getInstDataVectors();
+ if (!instDataVecs)
return std::nullopt;
- auto [instLayoutA, instLayoutB, instLayoutCD] = *instDataLayouts;
- SmallVector<int64_t> instDataA = instLayoutA.getEffectiveInstDataAsInt();
- SmallVector<int64_t> instDataB = instLayoutB.getEffectiveInstDataAsInt();
- SmallVector<int64_t> instDataCD = instLayoutCD.getEffectiveInstDataAsInt();
+ auto [instDataA, instDataB, instDataCD] = *instDataVecs;
assert(instDataA.size() == 2 && instDataB.size() == 2 &&
instDataCD.size() == 2 &&
"Sg layout creation expects valid 2D inst data");
- std::optional<LayoutRepresentation> layoutDVal = std::nullopt;
- if (consumerLayout) {
+ std::optional<LayoutRepresentation> consumerSgLayout = std::nullopt;
+ if (consumerLayout && consumerLayout.isForWorkgroup()) {
SmallVector<int64_t> sgLayoutD =
consumerLayout.getEffectiveSgLayoutAsInt();
- layoutDVal = std::make_pair(sgLayoutD[0], sgLayoutD[1]);
+ consumerSgLayout = std::make_pair(sgLayoutD[0], sgLayoutD[1]);
}
+ // Step 1. Get all valid layouts for A, B and C/D operands.
+ // Order them from most balanced to least balanced.
auto layoutsA = getValidLayouts(aTy.getShape(), instDataA, numSg);
auto layoutsB = getValidLayouts(bTy.getShape(), instDataB, numSg);
auto layoutsCD = getValidLayouts(cdTy.getShape(), instDataCD, numSg);
if (layoutsA.empty() || layoutsB.empty() || layoutsCD.empty())
return std::nullopt;
- // Step 2. If the result D layout can be reused for all operands, that
+ // Step 2. If the consumer layout can be reused for all operands, that
// layout is chosen. Otherwise, pick the most balanced subgroup layout
// that is valid for A, B and C (if present) operands
llvm::DenseSet<LayoutRepresentation> setA(layoutsA.begin(), layoutsA.end());
llvm::DenseSet<LayoutRepresentation> setCD(layoutsCD.begin(),
layoutsCD.end());
std::optional<LayoutRepresentation> bestPick;
- for (auto &l : layoutsB) {
- if (setA.contains(l) && setCD.contains(l)) {
+ for (auto &sgLayout : layoutsB) {
+ if (setA.contains(sgLayout) && setCD.contains(sgLayout)) {
// Is in (A and B and CD) and matches consumer -> best pick
- if (layoutDVal.has_value() && l == *layoutDVal) {
- bestPick = l;
+ if (consumerSgLayout.has_value() && sgLayout == *consumerSgLayout) {
+ bestPick = sgLayout;
break;
}
- // Is in (A and B and CD), balanced layout comes first
+ // Is in (A and B and CD). layoutsB is ordered from most balanced to
+ // least balanced, so the first one we see is the most balanced one.
+ // Remember it, and later only update it if there is one that matches
+ // the consumer.
if (!bestPick)
- bestPick = l;
+ bestPick = sgLayout;
}
}
// Step 3. If there is no subgroup layout compatible with A, B and C (if
@@ -1018,41 +1049,23 @@ xegpu::setupDpasLayout(xegpu::LayoutKind layoutKind, VectorType aTy,
/*lane_data =*/nullptr, /*order =*/nullptr);
return std::make_tuple(dpasALayout, dpasBLayout, dpasCDLayout);
} else if (layoutKind == xegpu::LayoutKind::InstData) {
- const int subgroupSize = uArch->getSubgroupSize();
- const unsigned dataALen = aTy.getShape().front();
- auto supportedALen = uArchInstruction->getSupportedM(aTy.getElementType());
- const int maxALen =
- xegpu::getLargestDivisor(dataALen, ArrayRef<unsigned>(supportedALen));
-
- const unsigned dataBLen = bTy.getShape().back();
- auto supportedBLen = uArchInstruction->getSupportedN(bTy.getElementType());
- const int maxBLen =
- xegpu::getLargestDivisor(dataBLen, ArrayRef<unsigned>(supportedBLen));
-
- auto supportedCLen = uArchInstruction->getSupportedN(cdTy.getElementType());
- const int maxCLen =
- xegpu::getLargestDivisor(dataBLen, ArrayRef<unsigned>(supportedCLen));
- if (maxALen == -1 || maxBLen == -1 || maxCLen == -1)
+ auto instDataVecs = getInstDataVectors();
+ if (!instDataVecs)
return std::nullopt;
-
- SmallVector<int> instDataA(aTy.getRank(), 1);
- instDataA[aTy.getRank() - 2] = maxALen;
- instDataA[aTy.getRank() - 1] = subgroupSize;
- SmallVector<int> instDataB(bTy.getRank(), 1);
- instDataB[bTy.getRank() - 2] = subgroupSize;
- instDataB[bTy.getRank() - 1] = maxBLen;
- SmallVector<int> instDataCD(cdTy.getRank(), 1);
- instDataCD[cdTy.getRank() - 2] = maxALen;
- instDataCD[cdTy.getRank() - 1] = maxCLen;
- return std::make_tuple(xegpu::LayoutAttr::get(context, instDataA),
- xegpu::LayoutAttr::get(context, instDataB),
- xegpu::LayoutAttr::get(context, instDataCD));
+ auto [instDataA, instDataB, instDataCD] = *instDataVecs;
+ return std::make_tuple(
+ xegpu::LayoutAttr::get(
+ context, SmallVector<int>(instDataA.begin(), instDataA.end())),
+ xegpu::LayoutAttr::get(
+ context, SmallVector<int>(instDataB.begin(), instDataB.end())),
+ xegpu::LayoutAttr::get(
+ context, SmallVector<int>(instDataCD.begin(), instDataCD.end())));
} else if (layoutKind == xegpu::LayoutKind::Lane) {
- auto aLayout = getDefaultLaneLayout(
+ auto aLayout = getDefaultLaneLayout2DBlockIo(
aTy, uArch, uArchInstruction->getPackedFormatBitSizeA());
- auto bLayout = getDefaultLaneLayout(
+ auto bLayout = getDefaultLaneLayout2DBlockIo(
bTy, uArch, uArchInstruction->getPackedFormatBitSizeB(), true);
- auto cdLayout = getDefaultLaneLayout(
+ auto cdLayout = getDefaultLaneLayout2DBlockIo(
cdTy, uArch, uArchInstruction->getPackedFormatBitSizeB());
return std::make_tuple(aLayout, bLayout, cdLayout);
}
More information about the Mlir-commits
mailing list