[Mlir-commits] [mlir] [MLIR][XeGPU] Use the `setupDpasLayout` utility for dpas layout propagation (PR #180937)
Artem Kroviakov
llvmlistbot at llvm.org
Thu Feb 12 02:27:59 PST 2026
================
@@ -852,4 +852,209 @@ xegpu::setupStoreMatrixAnchorLayout(xegpu::LayoutKind layoutKind,
return setupGenericStoreAnchorLayout(layoutKind, context, false, maxChunkSize,
srcShape, subgroupSize);
-}
\ No newline at end of file
+}
+
+// Builds the default lane-level layout for a 1D or 2D ranked type `ty`.
+// Every lane-layout entry is 1 except the innermost one, which is set to the
+// subgroup size reported by `uArch`. Every lane-data entry is 1 unless
+// `packingSize` is provided:
+// - `packingSize` is compared against the element bit width; when the element
+//   is narrower, `packingSize / bitwidth` consecutive elements are treated as
+//   a single unit owned by one lane.
+// - `vnni` selects which dimension receives that packing factor: the
+//   second-to-last one with vnni (column-wise, e.g. 2x1xf16) or the last one
+//   without it (e.g. 1x2xf16).
+template <typename RankedTy>
+static xegpu::LayoutAttr
+getDefaultLaneLayout(RankedTy ty, const xegpu::uArch::uArch *uArch,
+                     std::optional<unsigned> packingSize = std::nullopt,
+                     bool vnni = false) {
+  auto numDims = ty.getRank();
+  // Only 1D (without vnni) and 2D shapes are handled.
+  assert(((numDims == 1 && !vnni) || numDims == 2) &&
+         "Expected 1D non-vnni or 2D vector.");
+  // The bit width query below requires an int or float element type.
+  assert(ty.getElementType().isIntOrFloat() &&
+         "Expected int or float element type.");
+
+  SmallVector<int> lanesPerDim(numDims, 1);
+  SmallVector<int> elemsPerLane(numDims, 1);
+  // Lanes are spread along the innermost dimension only.
+  lanesPerDim.back() = uArch->getSubgroupSize();
+
+  if (packingSize) {
+    const unsigned elemBits = ty.getElementType().getIntOrFloatBitWidth();
+    // Elements at least as wide as the packing size are not packed.
+    const unsigned packedElems =
+        elemBits < *packingSize ? *packingSize / elemBits : 1;
+    if (vnni)
+      elemsPerLane[numDims - 2] = packedElems;
+    else
+      elemsPerLane.back() = packedElems;
+  }
+  return xegpu::LayoutAttr::get(ty.getContext(), lanesPerDim, elemsPerLane);
+}
+
+// Enumerates all 2D subgroup layouts (sgLayout0, sgLayout1) with
+// sgLayout0 * sgLayout1 == sgCount whose per-subgroup tile
+// sgData = wgShape / sgLayout:
+// 1. Evenly divides the wgShape.
+// 2. Is a multiple of instData.
+// Both `wgShape` and `instData` must be 2D and `instData` must be positive.
+// The result is ordered from the most balanced layout (smallest
+// |sgLayout0 - sgLayout1|) to the least balanced one; ties are broken by the
+// smaller sgLayout0. An empty vector is returned when no layout qualifies
+// (including sgCount <= 0).
+// Example:
+// wgShape = [128, 64], instData = [8, 16], sgCount = 32
+// Returns layouts:
+// [(8,4), (16,2)], which correspond to sgData [16,16] and [8,32].
+using LayoutRepresentation = std::pair<int64_t, int64_t>;
+SmallVector<LayoutRepresentation> getValidLayouts(ArrayRef<int64_t> wgShape,
+                                                  ArrayRef<int64_t> instData,
+                                                  int64_t sgCount) {
+  // Index 1 of both arrays is read below and instData is used as a modulus.
+  assert(wgShape.size() == 2 && instData.size() == 2 &&
+         "Expected 2D workgroup shape and inst data.");
+  assert(instData[0] > 0 && instData[1] > 0 &&
+         "Expected positive inst data.");
+
+  SmallVector<LayoutRepresentation> candidates;
+  // Use the same width as sgCount so the counter cannot be narrowed.
+  for (int64_t sgLayout0 = 1; sgLayout0 <= sgCount; ++sgLayout0) {
+    // Only factor pairs of sgCount form a complete subgroup grid.
+    if (sgCount % sgLayout0 != 0)
+      continue;
+    const int64_t sgLayout1 = sgCount / sgLayout0;
+    // Condition 1: the workgroup shape splits evenly across the subgroups.
+    if (wgShape[0] % sgLayout0 != 0 || wgShape[1] % sgLayout1 != 0)
+      continue;
+    const int64_t sgData0 = wgShape[0] / sgLayout0;
+    const int64_t sgData1 = wgShape[1] / sgLayout1;
+    // Condition 2: each subgroup tile is a whole number of instruction tiles.
+    if (sgData0 % instData[0] != 0 || sgData1 % instData[1] != 0)
+      continue;
+    candidates.emplace_back(sgLayout0, sgLayout1);
+  }
+  // Sort primarily by how balanced they are
+  // (i.e., minimize the absolute difference between the two dimensions), and
+  // secondarily by the first dimension in ascending order.
+  llvm::sort(candidates, [](const LayoutRepresentation &lhs,
+                            const LayoutRepresentation &rhs) {
+    // Keep the difference in 64 bits; the pair members are int64_t.
+    const int64_t diffLhs = std::abs(lhs.first - lhs.second);
+    const int64_t diffRhs = std::abs(rhs.first - rhs.second);
+    if (diffLhs != diffRhs)
+      return diffLhs < diffRhs;
+    return lhs.first < rhs.first;
+  });
+  return candidates;
+}
+
+/// Sets up the anchor layouts for the dpas operands (A, B, and C/D).
+/// The numSg and consumerLayout (optional) are only used by sg layout creation.
+std::optional<
+ std::tuple<xegpu::DistributeLayoutAttr, xegpu::DistributeLayoutAttr,
+ xegpu::DistributeLayoutAttr>>
+xegpu::setupDpasLayout(xegpu::LayoutKind layoutKind, VectorType aTy,
+ VectorType bTy, VectorType cdTy,
+ xegpu::DistributeLayoutAttr consumerLayout,
+ const xegpu::uArch::uArch *uArch, int numSg) {
+ auto context = aTy.getContext();
+ const auto *uArchInstruction =
+ dyn_cast<xegpu::uArch::SubgroupMatrixMultiplyAcc>(uArch->getInstruction(
+ xegpu::uArch::InstructionKind::SubgroupMatrixMultiplyAcc));
+
+ if (layoutKind == xegpu::LayoutKind::Subgroup) {
+ assert(numSg > 0 &&
+ "Number of subgroups must be provided for sg layout creation.");
+ auto instDataLayouts =
+ xegpu::setupDpasLayout(xegpu::LayoutKind::InstData, aTy, bTy, cdTy,
+ consumerLayout, uArch, numSg);
+ if (!instDataLayouts)
+ return std::nullopt;
+ auto [instLayoutA, instLayoutB, instLayoutCD] = *instDataLayouts;
+ SmallVector<int64_t> instDataA = instLayoutA.getEffectiveInstDataAsInt();
+ SmallVector<int64_t> instDataB = instLayoutB.getEffectiveInstDataAsInt();
+ SmallVector<int64_t> instDataCD = instLayoutCD.getEffectiveInstDataAsInt();
+ assert(instDataA.size() == 2 && instDataB.size() == 2 &&
+ instDataCD.size() == 2 &&
+ "Sg layout creation expects valid 2D inst data");
+
+ std::optional<LayoutRepresentation> layoutDVal = std::nullopt;
+ if (consumerLayout) {
+ SmallVector<int64_t> sgLayoutD =
+ consumerLayout.getEffectiveSgLayoutAsInt();
+ layoutDVal = std::make_pair(sgLayoutD[0], sgLayoutD[1]);
+ }
+
+ auto layoutsA = getValidLayouts(aTy.getShape(), instDataA, numSg);
+ auto layoutsB = getValidLayouts(bTy.getShape(), instDataB, numSg);
+ auto layoutsCD = getValidLayouts(cdTy.getShape(), instDataCD, numSg);
+ if (layoutsA.empty() || layoutsB.empty() || layoutsCD.empty())
+ return std::nullopt;
+
+ // Step 2. If the result D layout can be reused for all operands, that
+ // layout is chosen. Otherwise, pick the most balanced subgroup layout
+ // that is valid for A, B and C (if present) operands
+ llvm::DenseSet<LayoutRepresentation> setA(layoutsA.begin(), layoutsA.end());
+ llvm::DenseSet<LayoutRepresentation> setCD(layoutsCD.begin(),
+ layoutsCD.end());
+ std::optional<LayoutRepresentation> bestPick;
+ for (auto &l : layoutsB) {
+ if (setA.contains(l) && setCD.contains(l)) {
+ // Is in (A and B and CD) and matches consumer -> best pick
+ if (layoutDVal.has_value() && l == *layoutDVal) {
+ bestPick = l;
+ break;
+ }
+ // Is in (A and B and CD), balanced layout comes first
+ if (!bestPick)
+ bestPick = l;
+ }
+ }
+ // Step 3. If there is no subgroup layout compatible with A, B and C (if
+ // present) operands, we fail.
+ if (!bestPick)
+ return std::nullopt;
+ SmallVector<int> sgLayout = {static_cast<int>(bestPick->first),
+ static_cast<int>(bestPick->second)};
+ SmallVector<int> sgDataA = {
+ static_cast<int>(aTy.getShape()[0] / sgLayout[0]),
+ static_cast<int>(aTy.getShape()[1] / sgLayout[1])};
+ SmallVector<int> sgDataB = {
+ static_cast<int>(bTy.getShape()[0] / sgLayout[0]),
+ static_cast<int>(bTy.getShape()[1] / sgLayout[1])};
+ SmallVector<int> sgDataCD = {
+ static_cast<int>(cdTy.getShape()[0] / sgLayout[0]),
+ static_cast<int>(cdTy.getShape()[1] / sgLayout[1])};
+
+ auto dpasALayout = xegpu::LayoutAttr::get(
+ context, DenseI32ArrayAttr::get(context, sgLayout),
+ DenseI32ArrayAttr::get(context, sgDataA),
+ /*inst_data =*/nullptr, /*lane_layout =*/nullptr,
+ /*lane_data =*/nullptr, /*order =*/nullptr);
+
+ auto dpasBLayout = xegpu::LayoutAttr::get(
+ context, DenseI32ArrayAttr::get(context, sgLayout),
+ DenseI32ArrayAttr::get(context, sgDataB),
+ /*inst_data =*/nullptr, /*lane_layout =*/nullptr,
+ /*lane_data =*/nullptr, /*order =*/nullptr);
+
+ auto dpasCDLayout = xegpu::LayoutAttr::get(
+ context, DenseI32ArrayAttr::get(context, sgLayout),
+ DenseI32ArrayAttr::get(context, sgDataCD),
+ /*inst_data =*/nullptr, /*lane_layout =*/nullptr,
+ /*lane_data =*/nullptr, /*order =*/nullptr);
+ return std::make_tuple(dpasALayout, dpasBLayout, dpasCDLayout);
+ } else if (layoutKind == xegpu::LayoutKind::InstData) {
+ const int subgroupSize = uArch->getSubgroupSize();
+ const unsigned dataALen = aTy.getShape().front();
----------------
akroviakov wrote:
It is used in the context of uArch's vector (1D) lengths:
```cpp
const unsigned dataALen = aTy.getShape().front();
auto supportedALen = uArchInstruction->getSupportedM(aTy.getElementType());
const int maxALen =
xegpu::getLargestDivisor(dataALen, ArrayRef<unsigned>(supportedALen));
```
https://github.com/llvm/llvm-project/pull/180937
More information about the Mlir-commits
mailing list