[Mlir-commits] [mlir] [MLIR][XeGPU] Use the `setupDpasLayout` utility for dpas layout propagation (PR #180937)
Artem Kroviakov
llvmlistbot at llvm.org
Thu Feb 12 03:04:25 PST 2026
================
@@ -852,4 +852,209 @@ xegpu::setupStoreMatrixAnchorLayout(xegpu::LayoutKind layoutKind,
return setupGenericStoreAnchorLayout(layoutKind, context, false, maxChunkSize,
srcShape, subgroupSize);
-}
\ No newline at end of file
+}
+
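+// This function returns the default lane layout for a given 1D or 2D vector
+// type with an int or float element type. The subgroup size is assigned to the
+// innermost (last) dimension of the lane layout; all other entries are 1.
+// - `packingSize`, when provided, is a width in bits: for elements narrower
+// than `packingSize`, `packingSize / bitwidth` consecutive elements are
+// accessed together as a single unit. Wider or equal elements are not packed.
+// - `vnni` means data packing is column-wise (i.e., 2x1xf16 with vnni vs.
+// 1x2xf16 w/o vnni).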
+template <typename RankedTy>
+static xegpu::LayoutAttr
+getDefaultLaneLayout(RankedTy ty, const xegpu::uArch::uArch *uArch,
+ std::optional<unsigned> packingSize = std::nullopt,
+ bool vnni = false) {
+ // Rank must be 1 or 2, and `vnni` is only accepted together with rank 2.
+ assert(((ty.getRank() == 1 && !vnni) || ty.getRank() == 2) &&
+ "Expected 1D non-vnni or 2D vector.");
+ // Int or float element type is required: its bit width is queried below.
+ assert(ty.getElementType().isIntOrFloat() &&
+ "Expected int or float element type.");
+
+ auto context = ty.getContext();
+ auto rank = ty.getRank();
+ SmallVector<int> laneLayout(rank, 1); // All ones; last entry is set below.
+ SmallVector<int> laneData(rank, 1); // All ones unless packing applies.
+ if (packingSize.has_value()) {
+ unsigned bitwidth = ty.getElementType().getIntOrFloatBitWidth();
+ int &laneDataPos = vnni ? laneData[rank - 2] : laneData.back(); // vnni: dim rank-2, else last dim.
+ laneDataPos = bitwidth < *packingSize ? *packingSize / bitwidth : 1; // Stays 1 when not narrower than packingSize.
+ }
+ laneLayout.back() = uArch->getSubgroupSize(); // Subgroup size goes on the last dimension.
+ return xegpu::LayoutAttr::get(context, laneLayout, laneData);
+}
+
+// This function returns all layouts for the given sgCount, whose sgData:
+// 1. Evenly divides the wgShape.
+// 2. Is a multiple of instData.
+// Example:
+// wgShape = [128, 64], instData = [8, 16], sgCount = 32
+// Returns layouts:
+// [(8,4), (16,2)], which correspond to sgData [16,16] and [8,32].
+using LayoutRepresentation = std::pair<int64_t, int64_t>;
+SmallVector<LayoutRepresentation> getValidLayouts(ArrayRef<int64_t> wgShape,
----------------
akroviakov wrote:
fixed
https://github.com/llvm/llvm-project/pull/180937
More information about the Mlir-commits
mailing list