[Mlir-commits] [mlir] [MLIR][XeGPU] Use the `setupDpasLayout` utility for dpas layout propagation (PR #180937)

Artem Kroviakov llvmlistbot at llvm.org
Thu Feb 12 02:27:59 PST 2026


================
@@ -852,4 +852,209 @@ xegpu::setupStoreMatrixAnchorLayout(xegpu::LayoutKind layoutKind,
 
   return setupGenericStoreAnchorLayout(layoutKind, context, false, maxChunkSize,
                                        srcShape, subgroupSize);
-}
\ No newline at end of file
+}
+
+/// Returns the default lane-level layout for the ranked type `ty`.
+///
+/// All `uArch->getSubgroupSize()` lanes are placed along the innermost
+/// dimension; every other dimension gets a single lane.
+///
+/// - `packingSize` (optional, in bits): when the element bitwidth is smaller
+///   than `packingSize`, `packingSize / bitwidth` consecutive elements are
+///   assigned to each lane as a single unit (the lane data). Otherwise each
+///   lane owns one element.
+/// - `vnni`: when true, the packed lane data is placed on the second-to-last
+///   dimension, i.e. column-wise packing (2x1xf16 with vnni vs. 1x2xf16 w/o
+///   vnni). Only valid for 2D types.
+template <typename RankedTy>
+static xegpu::LayoutAttr
+getDefaultLaneLayout(RankedTy ty, const xegpu::uArch::uArch *uArch,
+                     std::optional<unsigned> packingSize = std::nullopt,
+                     bool vnni = false) {
+  // Expecting a 1D or 2D vector.
+  assert(((ty.getRank() == 1 && !vnni) || ty.getRank() == 2) &&
+         "Expected 1D non-vnni or 2D vector.");
+  // Expecting int or float element type.
+  assert(ty.getElementType().isIntOrFloat() &&
+         "Expected int or float element type.");
+
+  auto context = ty.getContext();
+  auto rank = ty.getRank();
+  SmallVector<int> laneLayout(rank, 1);
+  SmallVector<int> laneData(rank, 1);
+  if (packingSize.has_value()) {
+    const unsigned bitwidth = ty.getElementType().getIntOrFloatBitWidth();
+    // Column-wise (vnni) packing lives on the second-to-last dimension,
+    // row-wise packing on the innermost one.
+    int &laneDataPos = vnni ? laneData[rank - 2] : laneData.back();
+    // Zero-width element types (e.g. i0) pass the isIntOrFloat() check above
+    // but must not reach the division; they are simply left unpacked.
+    laneDataPos = (bitwidth != 0 && bitwidth < *packingSize)
+                      ? static_cast<int>(*packingSize / bitwidth)
+                      : 1;
+  }
+  // The whole subgroup is laid out along the innermost dimension.
+  laneLayout.back() = uArch->getSubgroupSize();
+  return xegpu::LayoutAttr::get(context, laneLayout, laneData);
+}
+
+// This function returns all 2D subgroup layouts (sgLayout[0], sgLayout[1])
+// with sgLayout[0] * sgLayout[1] == sgCount, whose sgData:
+// 1. Evenly divides the wgShape.
+// 2. Is a multiple of instData.
+// The result is sorted from most to least balanced (smallest
+// |sgLayout[0] - sgLayout[1]| first); ties are broken by ascending
+// sgLayout[0]. Only the first two entries of wgShape and instData are read.
+// Example:
+//   wgShape = [128, 64], instData = [8, 16], sgCount = 32
+// Returns layouts:
+//   [(8,4), (16,2)], which correspond to sgData [16,16] and [8,32].
+using LayoutRepresentation = std::pair<int64_t, int64_t>;
+static SmallVector<LayoutRepresentation>
+getValidLayouts(ArrayRef<int64_t> wgShape, ArrayRef<int64_t> instData,
+                int64_t sgCount) {
+  assert(wgShape.size() >= 2 && instData.size() >= 2 &&
+         "Expected at least 2D wgShape and instData.");
+  assert(instData[0] > 0 && instData[1] > 0 &&
+         "Expected positive instData.");
+
+  SmallVector<LayoutRepresentation> candidates;
+  // Enumerate every factorization sgCount = sgLayout0 * sgLayout1. A
+  // non-positive sgCount yields no candidates.
+  for (int64_t sgLayout0 = 1; sgLayout0 <= sgCount; ++sgLayout0) {
+    if (sgCount % sgLayout0)
+      continue;
+    const int64_t sgLayout1 = sgCount / sgLayout0;
+    // The workgroup shape must split evenly across the subgroups...
+    if (wgShape[0] % sgLayout0 || wgShape[1] % sgLayout1)
+      continue;
+    const int64_t sgData0 = wgShape[0] / sgLayout0;
+    const int64_t sgData1 = wgShape[1] / sgLayout1;
+    // ...and each subgroup tile must be a whole number of instructions.
+    if (sgData0 % instData[0] || sgData1 % instData[1])
+      continue;
+    candidates.emplace_back(sgLayout0, sgLayout1);
+  }
+  // Sort primarily by how balanced they are
+  // (i.e., minimize the absolute difference between the two dimensions), and
+  // secondarily by the first dimension in ascending order. Differences are
+  // kept in int64_t to avoid narrowing.
+  llvm::sort(candidates, [](const LayoutRepresentation &lhs,
+                            const LayoutRepresentation &rhs) {
+    const int64_t diffLhs = std::abs(lhs.first - lhs.second);
+    const int64_t diffRhs = std::abs(rhs.first - rhs.second);
+    if (diffLhs != diffRhs)
+      return diffLhs < diffRhs;
+    return lhs.first < rhs.first;
+  });
+  return candidates;
+}
+
+/// Sets up the anchor layouts for a dpas operands (A, B, and C/D).
+/// The numSg and consumerLayout (optional) are only used by sg layout creation.
+std::optional<
+    std::tuple<xegpu::DistributeLayoutAttr, xegpu::DistributeLayoutAttr,
+               xegpu::DistributeLayoutAttr>>
+xegpu::setupDpasLayout(xegpu::LayoutKind layoutKind, VectorType aTy,
+                       VectorType bTy, VectorType cdTy,
+                       xegpu::DistributeLayoutAttr consumerLayout,
+                       const xegpu::uArch::uArch *uArch, int numSg) {
+  auto context = aTy.getContext();
+  const auto *uArchInstruction =
+      dyn_cast<xegpu::uArch::SubgroupMatrixMultiplyAcc>(uArch->getInstruction(
+          xegpu::uArch::InstructionKind::SubgroupMatrixMultiplyAcc));
+
+  if (layoutKind == xegpu::LayoutKind::Subgroup) {
+    assert(numSg > 0 &&
+           "Number of subgroups must be provided for sg layout creation.");
+    auto instDataLayouts =
+        xegpu::setupDpasLayout(xegpu::LayoutKind::InstData, aTy, bTy, cdTy,
+                               consumerLayout, uArch, numSg);
+    if (!instDataLayouts)
+      return std::nullopt;
+    auto [instLayoutA, instLayoutB, instLayoutCD] = *instDataLayouts;
+    SmallVector<int64_t> instDataA = instLayoutA.getEffectiveInstDataAsInt();
+    SmallVector<int64_t> instDataB = instLayoutB.getEffectiveInstDataAsInt();
+    SmallVector<int64_t> instDataCD = instLayoutCD.getEffectiveInstDataAsInt();
+    assert(instDataA.size() == 2 && instDataB.size() == 2 &&
+           instDataCD.size() == 2 &&
+           "Sg layout creation expects valid 2D inst data");
+
+    std::optional<LayoutRepresentation> layoutDVal = std::nullopt;
+    if (consumerLayout) {
+      SmallVector<int64_t> sgLayoutD =
+          consumerLayout.getEffectiveSgLayoutAsInt();
+      layoutDVal = std::make_pair(sgLayoutD[0], sgLayoutD[1]);
+    }
+
+    auto layoutsA = getValidLayouts(aTy.getShape(), instDataA, numSg);
+    auto layoutsB = getValidLayouts(bTy.getShape(), instDataB, numSg);
+    auto layoutsCD = getValidLayouts(cdTy.getShape(), instDataCD, numSg);
+    if (layoutsA.empty() || layoutsB.empty() || layoutsCD.empty())
+      return std::nullopt;
+
+    // Step 2. If the result D layout can be reused for all operands, that
+    // layout is chosen. Otherwise, pick the most balanced subgroup layout
+    // that is valid for A, B and C (if present) operands
+    llvm::DenseSet<LayoutRepresentation> setA(layoutsA.begin(), layoutsA.end());
+    llvm::DenseSet<LayoutRepresentation> setCD(layoutsCD.begin(),
+                                               layoutsCD.end());
+    std::optional<LayoutRepresentation> bestPick;
+    for (auto &l : layoutsB) {
+      if (setA.contains(l) && setCD.contains(l)) {
+        // Is in (A and B and CD) and matches consumer -> best pick
+        if (layoutDVal.has_value() && l == *layoutDVal) {
+          bestPick = l;
+          break;
+        }
+        // Is in (A and B and CD), balanced layout comes first
+        if (!bestPick)
+          bestPick = l;
+      }
+    }
+    // Step 3. If there is no subgroup layout compatible with A, B and C (if
+    // present) operands, we fail.
+    if (!bestPick)
+      return std::nullopt;
+    SmallVector<int> sgLayout = {static_cast<int>(bestPick->first),
+                                 static_cast<int>(bestPick->second)};
+    SmallVector<int> sgDataA = {
+        static_cast<int>(aTy.getShape()[0] / sgLayout[0]),
+        static_cast<int>(aTy.getShape()[1] / sgLayout[1])};
+    SmallVector<int> sgDataB = {
+        static_cast<int>(bTy.getShape()[0] / sgLayout[0]),
+        static_cast<int>(bTy.getShape()[1] / sgLayout[1])};
+    SmallVector<int> sgDataCD = {
+        static_cast<int>(cdTy.getShape()[0] / sgLayout[0]),
+        static_cast<int>(cdTy.getShape()[1] / sgLayout[1])};
+
+    auto dpasALayout = xegpu::LayoutAttr::get(
+        context, DenseI32ArrayAttr::get(context, sgLayout),
+        DenseI32ArrayAttr::get(context, sgDataA),
+        /*inst_data =*/nullptr, /*lane_layout =*/nullptr,
+        /*lane_data =*/nullptr, /*order =*/nullptr);
+
+    auto dpasBLayout = xegpu::LayoutAttr::get(
+        context, DenseI32ArrayAttr::get(context, sgLayout),
+        DenseI32ArrayAttr::get(context, sgDataB),
+        /*inst_data =*/nullptr, /*lane_layout =*/nullptr,
+        /*lane_data =*/nullptr, /*order =*/nullptr);
+
+    auto dpasCDLayout = xegpu::LayoutAttr::get(
+        context, DenseI32ArrayAttr::get(context, sgLayout),
+        DenseI32ArrayAttr::get(context, sgDataCD),
+        /*inst_data =*/nullptr, /*lane_layout =*/nullptr,
+        /*lane_data =*/nullptr, /*order =*/nullptr);
+    return std::make_tuple(dpasALayout, dpasBLayout, dpasCDLayout);
+  } else if (layoutKind == xegpu::LayoutKind::InstData) {
+    const int subgroupSize = uArch->getSubgroupSize();
+    const unsigned dataALen = aTy.getShape().front();
----------------
akroviakov wrote:

It is used in the context of the uArch's 1D vector lengths:

```cpp
   const unsigned dataALen = aTy.getShape().front();
    auto supportedALen = uArchInstruction->getSupportedM(aTy.getElementType());
    const int maxALen =
        xegpu::getLargestDivisor(dataALen, ArrayRef<unsigned>(supportedALen));
```

https://github.com/llvm/llvm-project/pull/180937


More information about the Mlir-commits mailing list