[Mlir-commits] [mlir] [MLIR][XeGPU] Use the `setupDpasLayout` utility for dpas layout propagation (PR #180937)

Artem Kroviakov llvmlistbot at llvm.org
Thu Feb 12 03:04:25 PST 2026


================
@@ -852,4 +852,209 @@ xegpu::setupStoreMatrixAnchorLayout(xegpu::LayoutKind layoutKind,
 
   return setupGenericStoreAnchorLayout(layoutKind, context, false, maxChunkSize,
                                        srcShape, subgroupSize);
-}
\ No newline at end of file
+}
+
+// This function returns the default lane layout for a given vector type.
+// - `packingSize` means multiple consecutive elements can be accessed together
+// as a single unit.
+// - `vnni` means data packing is column-wise (i.e., 2x1xf16 with vnni vs.
+// 1x2xf16 w/o vnni).
+template <typename RankedTy>
+static xegpu::LayoutAttr
+getDefaultLaneLayout(RankedTy ty, const xegpu::uArch::uArch *uArch,
+                     std::optional<unsigned> packingSize = std::nullopt,
+                     bool vnni = false) {
+  // Expecting a 1D or 2D vector.
+  assert(((ty.getRank() == 1 && !vnni) || ty.getRank() == 2) &&
+         "Expected 1D non-vnni or 2D vector.");
+  // Expecting int or float element type.
+  assert(ty.getElementType().isIntOrFloat() &&
+         "Expected int or float element type.");
+
+  auto context = ty.getContext();
+  auto rank = ty.getRank();
+  SmallVector<int> laneLayout(rank, 1);
+  SmallVector<int> laneData(rank, 1);
+  if (packingSize.has_value()) {
+    unsigned bitwidth = ty.getElementType().getIntOrFloatBitWidth();
+    int &laneDataPos = vnni ? laneData[rank - 2] : laneData.back();
+    laneDataPos = bitwidth < *packingSize ? *packingSize / bitwidth : 1;
+  }
+  laneLayout.back() = uArch->getSubgroupSize();
+  return xegpu::LayoutAttr::get(context, laneLayout, laneData);
+}
+
+// This function returns all layouts for the given sgCount, whose sgData:
+// 1. Evenly divides the wgShape.
+// 2. Is a multiple of instData.
+// Example:
+//   wgShape = [128, 64], instData = [8, 16], sgCount = 32
+// Returns layouts:
+//   [(8,4), (16,2)], which correspond to sgData [16,16] and [8,32].
+using LayoutRepresentation = std::pair<int64_t, int64_t>;
+SmallVector<LayoutRepresentation> getValidLayouts(ArrayRef<int64_t> wgShape,
----------------
akroviakov wrote:

fixed

https://github.com/llvm/llvm-project/pull/180937


More information about the Mlir-commits mailing list