[Mlir-commits] [mlir] [MLIR][XeGPU] Use the `setupDpasLayout` utility for dpas layout propagation (PR #180937)

Artem Kroviakov llvmlistbot at llvm.org
Wed Feb 11 05:31:48 PST 2026


https://github.com/akroviakov updated https://github.com/llvm/llvm-project/pull/180937

>From cb5e7a6d8215dfe9abd31f072def78d7a4692292 Mon Sep 17 00:00:00 2001
From: Artem Kroviakov <artem.kroviakov at intel.com>
Date: Wed, 11 Feb 2026 13:16:22 +0000
Subject: [PATCH] [MLIR][XeGPU] Use the setupDpasLayout utility for dpas
 layout propagation

---
 .../XeGPU/Transforms/XeGPULayoutImpl.h        |   8 +
 .../XeGPU/Transforms/XeGPULayoutImpl.cpp      | 207 ++++++++++++++++-
 .../XeGPU/Transforms/XeGPUPropagateLayout.cpp | 219 +++---------------
 3 files changed, 245 insertions(+), 189 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h b/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h
index 182607c22c584..0d5210f07f05a 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h
@@ -161,6 +161,14 @@ DistributeLayoutAttr setupStoreMatrixAnchorLayout(LayoutKind layoutKind,
                                                   VectorType vectorTy,
                                                   const uArch::uArch *uArch);
 
+/// Sets up the anchor layouts for the DPAS operands (A, B, and C/D); returns
+/// std::nullopt on failure. numSg/consumerLayout only matter for sg layouts.
+std::optional<std::tuple<DistributeLayoutAttr, DistributeLayoutAttr,
+                         DistributeLayoutAttr>>
+setupDpasLayout(LayoutKind layoutKind, VectorType aTy, VectorType bTy,
+                VectorType cdTy, DistributeLayoutAttr consumerLayout,
+                const uArch::uArch *uArch, int numSg);
+
 } // namespace xegpu
 
 } // namespace mlir
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
index a4e47fca96d34..83f0180961e65 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
@@ -852,4 +852,209 @@ xegpu::setupStoreMatrixAnchorLayout(xegpu::LayoutKind layoutKind,
 
   return setupGenericStoreAnchorLayout(layoutKind, context, false, maxChunkSize,
                                        srcShape, subgroupSize);
-}
\ No newline at end of file
+}
+
+// Builds the default lane layout for a 1D or 2D type; lanes span the last dim.
+// - `packingSize` (in bits): each lane owns enough consecutive elements to
+// fill one packed unit; element types at least that wide get one element.
+// - `vnni`: pack along the second-to-last dim instead of the last one
+// (i.e., 2x1xf16 with vnni vs. 1x2xf16 w/o vnni).
+template <typename RankedTy>
+static xegpu::LayoutAttr
+getDefaultLaneLayout(RankedTy ty, const xegpu::uArch::uArch *uArch,
+                     std::optional<unsigned> packingSize = std::nullopt,
+                     bool vnni = false) {
+  const int64_t numDims = ty.getRank();
+  auto elemTy = ty.getElementType();
+  assert(((numDims == 1 && !vnni) || numDims == 2) &&
+         "Expected 1D non-vnni or 2D vector.");
+  assert(elemTy.isIntOrFloat() && "Expected int or float element type.");
+
+  SmallVector<int> lanes(numDims, 1);
+  lanes[numDims - 1] = uArch->getSubgroupSize();
+
+  SmallVector<int> perLane(numDims, 1);
+  if (packingSize) {
+    const unsigned elemBits = elemTy.getIntOrFloatBitWidth();
+    const int64_t packedDim = vnni ? numDims - 2 : numDims - 1;
+    if (elemBits < *packingSize)
+      perLane[packedDim] = *packingSize / elemBits;
+  }
+
+  return xegpu::LayoutAttr::get(ty.getContext(), lanes, perLane);
+}
+
+// Returns all 2D sg layouts (l0, l1) with l0 * l1 == sgCount whose sgData:
+// 1. Evenly divides the wgShape.
+// 2. Is a multiple of instData.
+// Example:
+//   wgShape = [128, 64], instData = [8, 16], sgCount = 32
+// Returns layouts (most balanced first):
+//   [(8,4), (16,2)], which correspond to sgData [16,16] and [8,32].
+using LayoutRepresentation = std::pair<int64_t, int64_t>;
+static SmallVector<LayoutRepresentation>
+getValidLayouts(ArrayRef<int64_t> wgShape, ArrayRef<int64_t> instData,
+                int64_t sgCount) {
+  assert(wgShape.size() == 2 && instData.size() == 2 && "Expected 2D shapes");
+  SmallVector<LayoutRepresentation> candidates;
+  for (int64_t sgLayout0 = 1; sgLayout0 <= sgCount; ++sgLayout0) {
+    if (sgCount % sgLayout0)
+      continue;
+    int64_t sgLayout1 = sgCount / sgLayout0;
+    int64_t sgData0 = wgShape[0] / sgLayout0;
+    int64_t sgData1 = wgShape[1] / sgLayout1;
+    if ((wgShape[0] % sgLayout0 || wgShape[1] % sgLayout1) ||
+        (sgData0 % instData[0] || sgData1 % instData[1]))
+      continue;
+    candidates.emplace_back(sgLayout0, sgLayout1);
+  }
+  // Sort by balance (smallest |sgLayout0 - sgLayout1| first); ties are broken
+  // by the smaller first dimension.
+  llvm::sort(candidates, [](const LayoutRepresentation &lhs,
+                            const LayoutRepresentation &rhs) {
+    int64_t diffLhs = std::abs(lhs.first - lhs.second);
+    int64_t diffRhs = std::abs(rhs.first - rhs.second);
+    if (diffLhs != diffRhs)
+      return diffLhs < diffRhs;
+    return lhs.first < rhs.first;
+  });
+  return candidates;
+}
+
+/// Sets up the anchor layouts for the DPAS operands (A, B, and C/D), or returns
+/// std::nullopt if no valid layout exists for the requested layoutKind.
+/// numSg and the optional consumerLayout are only used for sg layout creation.
+std::optional<
+    std::tuple<xegpu::DistributeLayoutAttr, xegpu::DistributeLayoutAttr,
+               xegpu::DistributeLayoutAttr>>
+xegpu::setupDpasLayout(xegpu::LayoutKind layoutKind, VectorType aTy,
+                       VectorType bTy, VectorType cdTy,
+                       xegpu::DistributeLayoutAttr consumerLayout,
+                       const xegpu::uArch::uArch *uArch, int numSg) {
+  auto context = aTy.getContext();
+  const auto *uArchInstruction =
+      dyn_cast_if_present<xegpu::uArch::SubgroupMatrixMultiplyAcc>(
+          uArch->getInstruction(
+              xegpu::uArch::InstructionKind::SubgroupMatrixMultiplyAcc));
+  if (!uArchInstruction) // uArch lacks a SubgroupMatrixMultiplyAcc entry.
+    return std::nullopt;
+  if (layoutKind == xegpu::LayoutKind::Subgroup) {
+    assert(numSg > 0 &&
+           "Number of subgroups must be provided for sg layout creation.");
+    auto instDataLayouts =
+        xegpu::setupDpasLayout(xegpu::LayoutKind::InstData, aTy, bTy, cdTy,
+                               consumerLayout, uArch, numSg);
+    if (!instDataLayouts)
+      return std::nullopt;
+    auto [instLayoutA, instLayoutB, instLayoutCD] = *instDataLayouts;
+    SmallVector<int64_t> instDataA = instLayoutA.getEffectiveInstDataAsInt();
+    SmallVector<int64_t> instDataB = instLayoutB.getEffectiveInstDataAsInt();
+    SmallVector<int64_t> instDataCD = instLayoutCD.getEffectiveInstDataAsInt();
+    assert(instDataA.size() == 2 && instDataB.size() == 2 &&
+           instDataCD.size() == 2 &&
+           "Sg layout creation expects valid 2D inst data");
+
+    std::optional<LayoutRepresentation> layoutDVal;
+    SmallVector<int64_t> sgLayoutD;
+    if (consumerLayout)
+      sgLayoutD = consumerLayout.getEffectiveSgLayoutAsInt();
+    if (sgLayoutD.size() == 2) // Ignore consumers without a 2D sg layout.
+      layoutDVal = std::make_pair(sgLayoutD[0], sgLayoutD[1]);
+
+    auto layoutsA = getValidLayouts(aTy.getShape(), instDataA, numSg);
+    auto layoutsB = getValidLayouts(bTy.getShape(), instDataB, numSg);
+    auto layoutsCD = getValidLayouts(cdTy.getShape(), instDataCD, numSg);
+    if (layoutsA.empty() || layoutsB.empty() || layoutsCD.empty())
+      return std::nullopt;
+
+    // Prefer the consumer's sg layout if it is valid for A, B and CD;
+    // otherwise pick the most balanced layout that is valid for all three.
+    llvm::DenseSet<LayoutRepresentation> setA(layoutsA.begin(), layoutsA.end());
+    llvm::DenseSet<LayoutRepresentation> setCD(layoutsCD.begin(),
+                                               layoutsCD.end());
+    std::optional<LayoutRepresentation> bestPick;
+    for (const LayoutRepresentation &l : layoutsB) {
+      if (!setA.contains(l) || !setCD.contains(l))
+        continue;
+      if (layoutDVal && l == *layoutDVal) { // Matches the consumer: best pick.
+        bestPick = l;
+        break;
+      }
+      if (!bestPick) // layoutsB is sorted, so the first hit is most balanced.
+        bestPick = l;
+    }
+    // Fail if no subgroup layout is compatible with all operands.
+    if (!bestPick)
+      return std::nullopt;
+    SmallVector<int> sgLayout = {static_cast<int>(bestPick->first),
+                                 static_cast<int>(bestPick->second)};
+    SmallVector<int> sgDataA = {
+        static_cast<int>(aTy.getShape()[0] / sgLayout[0]),
+        static_cast<int>(aTy.getShape()[1] / sgLayout[1])};
+    SmallVector<int> sgDataB = {
+        static_cast<int>(bTy.getShape()[0] / sgLayout[0]),
+        static_cast<int>(bTy.getShape()[1] / sgLayout[1])};
+    SmallVector<int> sgDataCD = {
+        static_cast<int>(cdTy.getShape()[0] / sgLayout[0]),
+        static_cast<int>(cdTy.getShape()[1] / sgLayout[1])};
+
+    auto dpasALayout = xegpu::LayoutAttr::get(
+        context, DenseI32ArrayAttr::get(context, sgLayout),
+        DenseI32ArrayAttr::get(context, sgDataA),
+        /*inst_data =*/nullptr, /*lane_layout =*/nullptr,
+        /*lane_data =*/nullptr, /*order =*/nullptr);
+
+    auto dpasBLayout = xegpu::LayoutAttr::get(
+        context, DenseI32ArrayAttr::get(context, sgLayout),
+        DenseI32ArrayAttr::get(context, sgDataB),
+        /*inst_data =*/nullptr, /*lane_layout =*/nullptr,
+        /*lane_data =*/nullptr, /*order =*/nullptr);
+
+    auto dpasCDLayout = xegpu::LayoutAttr::get(
+        context, DenseI32ArrayAttr::get(context, sgLayout),
+        DenseI32ArrayAttr::get(context, sgDataCD),
+        /*inst_data =*/nullptr, /*lane_layout =*/nullptr,
+        /*lane_data =*/nullptr, /*order =*/nullptr);
+    return std::make_tuple(dpasALayout, dpasBLayout, dpasCDLayout);
+  } else if (layoutKind == xegpu::LayoutKind::InstData) {
+    const int subgroupSize = uArch->getSubgroupSize();
+    const unsigned dataALen = aTy.getShape().front();
+    auto supportedALen = uArchInstruction->getSupportedM(aTy.getElementType());
+    const int maxALen =
+        xegpu::getLargestDivisor(dataALen, ArrayRef<unsigned>(supportedALen));
+
+    const unsigned dataBLen = bTy.getShape().back();
+    auto supportedBLen = uArchInstruction->getSupportedN(bTy.getElementType());
+    const int maxBLen =
+        xegpu::getLargestDivisor(dataBLen, ArrayRef<unsigned>(supportedBLen));
+
+    const unsigned dataCLen = cdTy.getShape().back();
+    auto supportedCLen = uArchInstruction->getSupportedN(cdTy.getElementType());
+    const int maxCLen =
+        xegpu::getLargestDivisor(dataCLen, ArrayRef<unsigned>(supportedCLen));
+    if (maxALen == -1 || maxBLen == -1 || maxCLen == -1)
+      return std::nullopt;
+
+    SmallVector<int> instDataA(aTy.getRank(), 1);
+    instDataA[aTy.getRank() - 2] = maxALen;
+    instDataA[aTy.getRank() - 1] = subgroupSize;
+    SmallVector<int> instDataB(bTy.getRank(), 1);
+    instDataB[bTy.getRank() - 2] = subgroupSize;
+    instDataB[bTy.getRank() - 1] = maxBLen;
+    SmallVector<int> instDataCD(cdTy.getRank(), 1);
+    instDataCD[cdTy.getRank() - 2] = maxALen;
+    instDataCD[cdTy.getRank() - 1] = maxCLen;
+    return std::make_tuple(xegpu::LayoutAttr::get(context, instDataA),
+                           xegpu::LayoutAttr::get(context, instDataB),
+                           xegpu::LayoutAttr::get(context, instDataCD));
+  } else if (layoutKind == xegpu::LayoutKind::Lane) {
+    auto aLayout = getDefaultLaneLayout(
+        aTy, uArch, uArchInstruction->getPackedFormatBitSizeA());
+    auto bLayout = getDefaultLaneLayout(
+        bTy, uArch, uArchInstruction->getPackedFormatBitSizeB(), true);
+    auto cdLayout = getDefaultLaneLayout(
+        cdTy, uArch, uArchInstruction->getPackedFormatBitSizeB());
+    return std::make_tuple(aLayout, bLayout, cdLayout);
+  }
+  return std::nullopt;
+}
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index ccfab7350e351..907a2c9cced90 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -308,33 +308,6 @@ static LayoutInfo getSIMTLayoutInfoBlockIO(Ty ty,
       ty.getContext(), {1, uArch->getSubgroupSize()}, {1, packingFactor}));
 }
 
-/// Helper Function to get the expected layouts for DPAS operands. `lane_data`
-/// is set according to the following criteria:
-/// * For A operand, the data must be packed in minimum
-/// `packedSizeInBitsForDefault`
-/// * For B operand, the data must be packed in minimum
-/// `packedSizeInBitsForDpasB`
-static LayoutInfo
-getSIMTLayoutInfoForDPASOperand(VectorType vectorTy, unsigned operandNum,
-                                const xegpu::uArch::uArch *uArch,
-                                unsigned packingSize) {
-  Type elementTy = vectorTy.getElementType();
-  assert(elementTy.isIntOrFloat() &&
-         "Expected int or float type in DPAS operands");
-  SmallVector<int32_t, 2> layout({1, uArch->getSubgroupSize()});
-  // For B operand, data must be packed in minimum `packedDpasBSizeInBits` and
-  // must have the VNNI format.
-  if (operandNum == 1 && elementTy.getIntOrFloatBitWidth() < packingSize) {
-    SmallVector<int32_t, 2> data(
-        {static_cast<int32_t>(packingSize / elementTy.getIntOrFloatBitWidth()),
-         1});
-    return LayoutInfo(
-        xegpu::LayoutAttr::get(vectorTy.getContext(), layout, data));
-  }
-  // Otherwise, return the default layout for the vector type.
-  return getSIMTLayoutInfoBlockIO(vectorTy, uArch, packingSize);
-}
-
 //===----------------------------------------------------------------------===//
 // LayoutInfoPropagation
 //===----------------------------------------------------------------------===//
@@ -765,180 +738,50 @@ void LayoutInfoPropagation::visitDpasOp(
     dpasBLayout = LayoutInfo(anchorLayoutB);
     dpasCDLayout = LayoutInfo(anchorLayoutCD);
   } else {
+    auto uArch = getUArch(getChipStr(dpas).value_or(""));
     VectorType aTy = dpas.getLhsType();
     VectorType bTy = dpas.getRhsType();
-    VectorType cTy;
-    const bool hasAcc = operands.size() > 2;
-    if (hasAcc)
-      cTy = dpas.getAccType();
-
-    auto uArch = getUArch(getChipStr(dpas).value_or(""));
-    const int subgroupSize = uArch->getSubgroupSize();
-    const auto *uArchInstruction =
-        dyn_cast<xegpu::uArch::SubgroupMatrixMultiplyAcc>(uArch->getInstruction(
-            xegpu::uArch::InstructionKind::SubgroupMatrixMultiplyAcc));
-
-    const unsigned dataALen = aTy.getShape().front();
-    auto supportedALen = uArchInstruction->getSupportedM(aTy.getElementType());
-    const int maxALen =
-        xegpu::getLargestDivisor(dataALen, ArrayRef<unsigned>(supportedALen));
-    if (maxALen == -1)
-      dpas.emitWarning(
-          "No suitable instruction multiple found for the given shape.");
-
-    const unsigned dataBLen = bTy.getShape().back();
-    auto supportedBLen = uArchInstruction->getSupportedN(bTy.getElementType());
-
-    const int maxBLen =
-        xegpu::getLargestDivisor(dataBLen, ArrayRef<unsigned>(supportedBLen));
-
-    if (maxBLen == -1)
-      dpas.emitWarning(
-          "No suitable instruction multiple found for the given shape.");
-    SmallVector<int> instDataA = {maxALen, subgroupSize};
-    SmallVector<int> instDataB = {subgroupSize, maxBLen};
-    SmallVector<int> instDataCD;
-    if (hasAcc) {
-      const unsigned dataCLen = bTy.getShape().back();
-      auto supportedCLen =
-          uArchInstruction->getSupportedN(cTy.getElementType());
-      const int maxCLen =
-          xegpu::getLargestDivisor(dataCLen, ArrayRef<unsigned>(supportedCLen));
-      if (maxCLen == -1) {
-        dpas.emitWarning(
-            "No suitable instruction multiple found for the given shape.");
-        return;
-      }
-      instDataCD = {maxALen, maxCLen};
-    }
-    if (layoutKind == xegpu::LayoutKind::InstData) {
-      dpasALayout =
-          LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataA));
-      dpasBLayout =
-          LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataB));
-      if (hasAcc) {
-        dpasCDLayout =
-            LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataCD));
-      }
-    } else if (layoutKind == xegpu::LayoutKind::Lane) {
-      dpasALayout = getSIMTLayoutInfoForDPASOperand(
-          aTy, 0, uArch, uArchInstruction->getPackedFormatBitSizeA());
-      dpasBLayout = getSIMTLayoutInfoForDPASOperand(
-          bTy, 1, uArch, uArchInstruction->getPackedFormatBitSizeB());
-      if (hasAcc) {
-        dpasCDLayout = getSIMTLayoutInfoForDPASOperand(
-            cTy, 2, uArch, uArchInstruction->getPackedFormatBitSizeB());
-      }
-    } else { // Subgroup
-      auto numSgOrErr = getNumSg(dpas, subgroupSize);
+    VectorType cdTy = dpas.getResultType();
+    LayoutInfo consumerLayout = results[0]->getValue();
+    if (!consumerLayout.isAssigned())
+      dpas.emitWarning("No consumer layout was found for the DPAS result.");
+
+    auto consumerLayoutAttr =
+        dyn_cast_if_present<xegpu::DistributeLayoutAttr>(consumerLayout.get());
+    xegpu::DistributeLayoutAttr requiredCDLayoutAttr, requiredALayout,
+        requiredBLayout;
+
+    int numSg = 0;
+    if (layoutKind == xegpu::LayoutKind::Subgroup) {
+      auto numSgOrErr = getNumSg(dpas, uArch->getSubgroupSize());
       if (failed(numSgOrErr)) {
         dpas.emitWarning(
             "Unable to determine the number of subgroups for the operation.");
         return;
       }
+      numSg = numSgOrErr.value();
+    }
+    auto layouts = xegpu::setupDpasLayout(layoutKind, aTy, bTy, cdTy,
+                                          consumerLayoutAttr, uArch, numSg);
+    if (!layouts.has_value()) {
+      dpas.emitWarning(
+          "Failed to determine required layouts for DPAS operands.");
+      return;
+    }
 
-      // Step 1. Get all valid layouts for A, B, and C operands.
-      // All operands must have at least one valid subgroup layout.
-      LayoutInfo layoutD = results[0]->getValue();
-      SmallVector<int> sgLayoutD = layoutD.getSgLayout();
-      assert(!sgLayoutD.empty() && "Expected layout for DPAS result.");
-      auto layoutDVal = std::make_pair(sgLayoutD[0], sgLayoutD[1]);
-
-      auto layoutsA =
-          getValidLayouts(aTy.getShape(), instDataA, numSgOrErr.value());
-      auto layoutsB =
-          getValidLayouts(bTy.getShape(), instDataB, numSgOrErr.value());
-      SmallVector<std::pair<int, int>> layoutsC;
-      if (hasAcc)
-        layoutsC =
-            getValidLayouts(cTy.getShape(), instDataCD, numSgOrErr.value());
-
-      if (layoutsA.empty() || layoutsB.empty() ||
-          (hasAcc && layoutsC.empty())) {
-        dpas.emitWarning(
-            "Unable to determine suitable subgroup layout for A/B/C matrices.");
-        return;
-      }
-
-      // Step 2. If the result D layout can be reused for all operands, that
-      // layout is chosen. Otherwise, pick the most balanced subgroup layout
-      // that is valid for A, B and C (if present) operands
-      llvm::DenseSet<std::pair<int, int>> setA(layoutsA.begin(),
-                                               layoutsA.end());
-      llvm::DenseSet<std::pair<int, int>> setC;
-      if (hasAcc)
-        setC = llvm::DenseSet<std::pair<int, int>>(layoutsC.begin(),
-                                                   layoutsC.end());
-      std::optional<std::pair<int, int>> bestPick;
-      for (auto &l : layoutsB) {
-        if (setA.contains(l)) {
-          if (hasAcc && !setC.contains(l))
-            continue;
-          // Is in (A and B and C) and matches D -> best pick
-          if (l == layoutDVal) {
-            bestPick = l;
-            break;
-          }
-          // Is in (A and B and C), balanced layout comes first
-          if (!bestPick)
-            bestPick = l;
-        }
-      }
-      // Step 3. If there is no subgroup layout compatible with A, B and C (if
-      // present) operands, we fail.
-      SmallVector<int> sgLayout;
-      if (bestPick) {
-        sgLayout = {bestPick->first, bestPick->second};
-      } else {
-        dpas.emitWarning("Unable to find common subgroup layout for matrices.");
-        return;
-      }
-      SmallVector<int> sgDataA = {
-          static_cast<int>(aTy.getShape()[0]) / sgLayout[0],
-          static_cast<int>(aTy.getShape()[1]) / sgLayout[1]};
-      SmallVector<int> sgDataB = {
-          static_cast<int>(bTy.getShape()[0]) / sgLayout[0],
-          static_cast<int>(bTy.getShape()[1]) / sgLayout[1]};
-      SmallVector<int> sgDataC;
-      if (hasAcc)
-        sgDataC = {
-            static_cast<int>(dpas.getResultType().getShape()[0]) / sgLayout[0],
-            static_cast<int>(dpas.getResultType().getShape()[1]) / sgLayout[1]};
-
-      dpasALayout = LayoutInfo(xegpu::LayoutAttr::get(
-          aTy.getContext(), DenseI32ArrayAttr::get(aTy.getContext(), sgLayout),
-          DenseI32ArrayAttr::get(aTy.getContext(), sgDataA),
-          /*inst_data =*/nullptr, /*lane_layout =*/nullptr,
-          /*lane_data =*/nullptr, /*order =*/nullptr));
+    std::tie(requiredALayout, requiredBLayout, requiredCDLayoutAttr) = *layouts;
 
-      dpasBLayout = LayoutInfo(xegpu::LayoutAttr::get(
-          bTy.getContext(), DenseI32ArrayAttr::get(bTy.getContext(), sgLayout),
-          DenseI32ArrayAttr::get(bTy.getContext(), sgDataB),
-          /*inst_data =*/nullptr, /*lane_layout =*/nullptr,
-          /*lane_data =*/nullptr, /*order =*/nullptr));
-      if (hasAcc) {
-        dpasCDLayout = LayoutInfo(xegpu::LayoutAttr::get(
-            cTy.getContext(),
-            DenseI32ArrayAttr::get(cTy.getContext(), sgLayout),
-            DenseI32ArrayAttr::get(cTy.getContext(), sgDataC),
-            /*inst_data =*/nullptr, /*lane_layout =*/nullptr,
-            /*lane_data =*/nullptr, /*order =*/nullptr));
-      }
-    }
-    dpas.setLayoutAAttr(
-        dyn_cast<xegpu::DistributeLayoutAttr>(dpasALayout.get()));
-    dpas.setLayoutBAttr(
-        dyn_cast<xegpu::DistributeLayoutAttr>(dpasBLayout.get()));
-    if (hasAcc)
-      dpas.setLayoutCdAttr(
-          dyn_cast<xegpu::DistributeLayoutAttr>(dpasCDLayout.get()));
+    dpas.setLayoutAAttr(requiredALayout);
+    dpas.setLayoutBAttr(requiredBLayout);
+    dpas.setLayoutCdAttr(requiredCDLayoutAttr);
+    dpasALayout = LayoutInfo(requiredALayout);
+    dpasBLayout = LayoutInfo(requiredBLayout);
+    dpasCDLayout = LayoutInfo(requiredCDLayoutAttr);
   }
-
   propagateIfChanged(operands[0], operands[0]->meet(dpasALayout));
   propagateIfChanged(operands[1], operands[1]->meet(dpasBLayout));
-  if (operands.size() > 2) {
+  if (operands.size() > 2)
     propagateIfChanged(operands[2], operands[2]->meet(dpasCDLayout));
-  }
 }
 
 /// Set the layout for the value and tensor descriptor operands in StoreNdOp.



More information about the Mlir-commits mailing list