[Mlir-commits] [mlir] [MLIR][XeGPU] Add simple rank-based sg layout creation (PR #172867)
Artem Kroviakov
llvmlistbot at llvm.org
Thu Jan 22 02:52:59 PST 2026
https://github.com/akroviakov updated https://github.com/llvm/llvm-project/pull/172867
>From e6ac0d92c57b0ae1df5c0f25ef53b0a3a4925170 Mon Sep 17 00:00:00 2001
From: Artem Kroviakov <artem.kroviakov at intel.com>
Date: Thu, 18 Dec 2025 16:25:37 +0000
Subject: [PATCH 1/3] [MLIR][XeGPU] Add simple rank-based sg layout creation
---
.../XeGPU/Transforms/XeGPUPropagateLayout.cpp | 164 ++++++++++++++++--
.../XeGPU/propagate-layout-subgroup.mlir | 77 ++++++++
2 files changed, 225 insertions(+), 16 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 1341fc21e7fd4..1215f567d2cd9 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -531,6 +531,55 @@ bool LayoutInfoPropagation::hasParamsOfLayoutKind(
return false;
}
+/// Chooses a subgroup layout (`sg_layout`) and per-subgroup tile (`sg_data`)
+/// for distributing the workgroup-level shape `wgShape` over `sgCount`
+/// subgroups.
+///
+/// `sgCount` is split into its prime factors; the two smallest factors are
+/// fused repeatedly until at most `rank` factors remain, and the layout is
+/// padded with leading 1s when `sgCount` has fewer prime factors than `rank`
+/// (e.g. a single subgroup, or a prime subgroup count for a 2-D shape).
+/// Factor `i` is assigned to dimension `i`, and `sg_data[i]` is
+/// `wgShape[i] / sg_layout[i]`.
+///
+/// Fails for 0-d shapes, a non-positive `sgCount`, or when some dimension is
+/// not divisible by its layout factor.
+FailureOr<std::pair<SmallVector<int>, SmallVector<int>>>
+chooseLayout(llvm::ArrayRef<int64_t> wgShape, int64_t sgCount) {
+  const size_t rank = wgShape.size();
+  // With rank 0 the fuse loop below would read layout[1] out of bounds.
+  if (rank == 0 || sgCount <= 0)
+    return failure();
+
+  // Step 1. Factorize sgCount into prime factors.
+  SmallVector<int> layout;
+  int64_t temp = sgCount;
+  for (int64_t i = 2; i * i <= temp; ++i) {
+    while (temp % i == 0) {
+      layout.push_back(static_cast<int>(i));
+      temp /= i;
+    }
+  }
+  if (temp > 1)
+    layout.push_back(static_cast<int>(temp));
+
+  // Step 2. Fuse the two smallest factors until at most `rank` remain.
+  // Loop entry implies layout.size() >= 2 because rank >= 1.
+  while (layout.size() > rank) {
+    llvm::sort(layout);
+    const int fused = layout[0] * layout[1];
+    layout.erase(layout.begin());
+    layout[0] = fused;
+  }
+
+  // Step 3. Pad with 1s so unit/prime subgroup counts still yield `rank`
+  // factors instead of failing.
+  layout.insert(layout.begin(), rank - layout.size(), 1);
+
+  SmallVector<int> data;
+  data.reserve(rank);
+  for (auto [i, dim] : llvm::enumerate(layout)) {
+    if (wgShape[i] % dim != 0)
+      return failure();
+    data.push_back(static_cast<int>(wgShape[i] / dim));
+  }
+  return std::make_pair(layout, data);
+}
+
+/// Returns the number of subgroups in the workgroup enclosing `op`, computed
+/// from the `known_block_size` of the parent `gpu.func` and `sgSize`
+/// work-items per subgroup. Only the total work-item count matters; how the
+/// work-items are arranged across dimensions is ignored.
+///
+/// Fails if `op` is not nested in a `gpu.func`, the block size is unknown,
+/// `sgSize` is not positive, or the flat block size is not a positive
+/// multiple of `sgSize` (a partial subgroup cannot be given a tile, and
+/// truncating the division would silently under-count).
+static FailureOr<int64_t> getNumSg(Operation *op, const int sgSize) {
+  if (sgSize <= 0)
+    return failure();
+  auto gpuFunc = op->getParentOfType<gpu::GPUFuncOp>();
+  if (!gpuFunc)
+    return failure();
+  auto knownBlockSize = gpuFunc.getKnownBlockSize();
+  if (!knownBlockSize.has_value())
+    return failure();
+  const int64_t flatBlockSize = llvm::product_of(knownBlockSize.value());
+  if (flatBlockSize < sgSize || flatBlockSize % sgSize != 0)
+    return failure();
+  return flatBlockSize / sgSize;
+}
+
void LayoutInfoPropagation::visitPrefetchNdOp(
xegpu::PrefetchNdOp prefetch, ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results) {
@@ -746,30 +795,89 @@ void LayoutInfoPropagation::visitDpasOp(
LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataA));
dpasBLayout =
LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataB));
- } else {
+ } else if (layoutKind == LayoutKind::Lane) {
dpasALayout = getSIMTLayoutInfoForDPASOperand(
aTy, 0, uArch, uArchInstruction->getPackedFormatBitSizeA());
dpasBLayout = getSIMTLayoutInfoForDPASOperand(
bTy, 1, uArch, uArchInstruction->getPackedFormatBitSizeB());
+ } else {
+ auto numSgOrErr = getNumSg(dpas, subgroupSize);
+ if (failed(numSgOrErr)) {
+ dpas.emitWarning(
+ "Unable to determine the number of subgroups for the operation.");
+ return;
+ }
+ auto layoutDataAOrErr = chooseLayout(aTy.getShape(), numSgOrErr.value());
+ if (failed(layoutDataAOrErr)) {
+ dpas.emitWarning(
+ "Unable to determine suitable subgroup layout and data for A.");
+ return;
+ }
+ auto [sgLayoutA, sgDataA] = layoutDataAOrErr.value();
+
+ dpasALayout = LayoutInfo(xegpu::LayoutAttr::get(
+ aTy.getContext(), DenseI32ArrayAttr::get(aTy.getContext(), sgLayoutA),
+ DenseI32ArrayAttr::get(aTy.getContext(), sgDataA),
+ /*inst_data =*/nullptr, /*lane_layout =*/nullptr,
+ /*lane_data =*/nullptr, /*order =*/nullptr));
+
+ auto layoutDataBOrErr = chooseLayout(bTy.getShape(), numSgOrErr.value());
+ if (failed(layoutDataBOrErr)) {
+ dpas.emitWarning(
+ "Unable to determine suitable subgroup layout and data for B.");
+ return;
+ }
+ auto [sgLayoutB, sgDataB] = layoutDataBOrErr.value();
+
+ dpasBLayout = LayoutInfo(xegpu::LayoutAttr::get(
+ bTy.getContext(), DenseI32ArrayAttr::get(bTy.getContext(), sgLayoutB),
+ DenseI32ArrayAttr::get(bTy.getContext(), sgDataB),
+ /*inst_data =*/nullptr, /*lane_layout =*/nullptr,
+ /*lane_data =*/nullptr, /*order =*/nullptr));
}
if (operands.size() > 2) {
VectorType cTy = dpas.getAccType();
+ const unsigned dataCLen = bTy.getShape().back();
+ auto supportedCLen =
+ uArchInstruction->getSupportedN(bTy.getElementType());
+ const int maxCLen =
+ xegpu::getLargestDivisor(dataCLen, ArrayRef<unsigned>(supportedCLen));
+ if (maxCLen == -1) {
+ dpas.emitWarning(
+ "No suitable instruction multiple found for the given shape.");
+ return;
+ }
+ SmallVector<int> instDataCD = {maxALen, maxCLen};
if (layoutKind == LayoutKind::InstData) {
- const unsigned dataCLen = bTy.getShape().back();
- auto supportedCLen =
- uArchInstruction->getSupportedN(bTy.getElementType());
- const int maxCLen = xegpu::getLargestDivisor(
- dataCLen, ArrayRef<unsigned>(supportedCLen));
- if (maxCLen == -1)
- dpas.emitWarning(
- "No suitable instruction multiple found for the given shape.");
- SmallVector<int> instDataC = {maxALen, maxCLen};
dpasCDLayout =
- LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataC));
- } else
+ LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataCD));
+ } else if (layoutKind == LayoutKind::Lane) {
dpasCDLayout = getSIMTLayoutInfoForDPASOperand(
cTy, 2, uArch, uArchInstruction->getPackedFormatBitSizeB());
+ } else {
+ auto numSgOrErr = getNumSg(dpas, subgroupSize);
+ if (failed(numSgOrErr)) {
+ dpas.emitWarning(
+ "Unable to determine the number of subgroups for the operation.");
+ return;
+ }
+ auto layoutDataAOrErr =
+ chooseLayout(cTy.getShape(), numSgOrErr.value());
+ if (failed(layoutDataAOrErr)) {
+ dpas.emitWarning(
+ "Unable to determine suitable subgroup layout and data for A.");
+ return;
+ }
+ auto [sgLayoutCD, sgDataCD] = layoutDataAOrErr.value();
+
+ dpasCDLayout = LayoutInfo(xegpu::LayoutAttr::get(
+ cTy.getContext(),
+ DenseI32ArrayAttr::get(cTy.getContext(), sgLayoutCD),
+ DenseI32ArrayAttr::get(cTy.getContext(), sgDataCD),
+ /*inst_data =*/nullptr, /*lane_layout =*/nullptr,
+ /*lane_data =*/nullptr, /*order =*/nullptr));
+ }
dpas.setLayoutCdAttr(
dyn_cast<xegpu::DistributeLayoutAttr>(dpasCDLayout.get()));
@@ -828,10 +936,34 @@ void LayoutInfoPropagation::visitStoreNdOp(
if (layoutKind == LayoutKind::InstData)
storeLayout =
LayoutInfo(xegpu::LayoutAttr::get(dataTy.getContext(), instData));
- else
- storeLayout = getSIMTLayoutInforForBlockIO(
- store.getValueType(), uArch,
- uArchInstruction->getPackedFormatBitSize());
+ else if (layoutKind == LayoutKind::Lane)
+ storeLayout =
+ getSIMTLayoutInforForBlockIO(store.getValueType(), uArch,
+ uArchInstruction->getPackedFormatBitSize());
+ else { // LayoutKind::Subgroup
+ auto sgSize = uArch->getSubgroupSize();
+ auto numSgOrErr = getNumSg(store, sgSize);
+ if (failed(numSgOrErr)) {
+ store.emitWarning(
+ "Unable to determine the number of subgroups for the operation.");
+ return;
+ }
+ auto layoutDataOrErr =
+ chooseLayout(dataTy.getShape(), numSgOrErr.value());
+ if (failed(layoutDataOrErr)) {
+ store.emitWarning(
+ "Unable to determine suitable subgroup layout and data.");
+ return;
+ }
+ auto [sgLayout, sgData] = layoutDataOrErr.value();
+
+ storeLayout = LayoutInfo(xegpu::LayoutAttr::get(
+ dataTy.getContext(),
+ DenseI32ArrayAttr::get(dataTy.getContext(), sgLayout),
+ DenseI32ArrayAttr::get(dataTy.getContext(), sgData),
+ /*inst_data =*/nullptr, /*lane_layout =*/nullptr,
+ /*lane_data =*/nullptr, /*order =*/nullptr));
+ }
store.setLayoutAttr(
dyn_cast<xegpu::DistributeLayoutAttr>(storeLayout.get()));
}
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir b/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir
index 092a4cf442782..b6dfbe79e2712 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir
@@ -49,3 +49,80 @@ gpu.module @test {
return
}
}
+
+// -----
+gpu.module @test {
+ // Subgroup layouts through a transpose: the expected load layout is the
+ // transpose of the store layout ([8, 4] -> [4, 8]).
+ // CHECK-LABEL: vector_transpose
+ // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
+ // CHECK-SAME: %[[ARG_1:.*]]: memref<128x256xf32>
+ gpu.func @vector_transpose(%src: memref<256x128xf32>, %src1: memref<128x256xf32>) kernel attributes
+ {known_block_size = array<i32: 1, 32, 16>} {
+ // CHECK: %[[TDESC_LD:.*]] = xegpu.create_nd_tdesc %[[ARG_0]] : memref<256x128xf32> ->
+ // CHECK-SAME: !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [4, 8], sg_data = [64, 16]>>
+ // CHECK: %[[TDESC_ST:.*]] = xegpu.create_nd_tdesc %[[ARG_1]] : memref<128x256xf32> ->
+ // CHECK-SAME: !xegpu.tensor_desc<128x256xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 64]>>
+
+ // CHECK: %[[LOAD:.*]] = xegpu.load_nd %[[TDESC_LD]][0, 0] <{layout = #xegpu.layout<sg_layout = [4, 8], sg_data = [64, 16]>}>
+ // CHECK-SAME: {layout_result_0 = #xegpu.layout<sg_layout = [4, 8], sg_data = [64, 16]>} :
+ // CHECK-SAME: !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [4, 8], sg_data = [64, 16]>> -> vector<256x128xf32>
+
+ // CHECK: %[[TRANSPOSED:.*]] = vector.transpose %[[LOAD]], [1, 0]
+ // CHECK-SAME: {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 64]>} : vector<256x128xf32> to vector<128x256xf32>
+
+ // CHECK: xegpu.store_nd %[[TRANSPOSED]], %[[TDESC_ST]][0, 0]
+ // CHECK-SAME: <{layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 64]>}> : vector<128x256xf32>,
+ // CHECK-SAME: !xegpu.tensor_desc<128x256xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 64]>>
+ %tdesc = xegpu.create_nd_tdesc %src : memref<256x128xf32> -> !xegpu.tensor_desc<256x128xf32>
+ %tdesc1 = xegpu.create_nd_tdesc %src1 : memref<128x256xf32> -> !xegpu.tensor_desc<128x256xf32>
+ %load = xegpu.load_nd %tdesc[0, 0] : !xegpu.tensor_desc<256x128xf32> -> vector<256x128xf32>
+ %trans = vector.transpose %load, [1, 0] : vector<256x128xf32> to vector<128x256xf32>
+ xegpu.store_nd %trans, %tdesc1[0, 0] : vector<128x256xf32>, !xegpu.tensor_desc<128x256xf32>
+ gpu.return
+ }
+}
+
+// -----
+gpu.module @test {
+ // Subgroup-level layout assignment for a 128x128x128 dpas without an
+ // accumulator operand. known_block_size = [1, 64, 16] is 1024 work-items;
+ // every expected sg_layout below is a 16x4 grid (64 subgroups), which
+ // presumes a 16-wide subgroup on the target -- confirm against the uArch.
+ // CHECK-LABEL: dpas
+ // CHECK-SAME: %[[A_MEMREF:.*]]: memref<128x128xf16>, %[[B_MEMREF:.*]]: memref<128x128xf16>
+ // CHECK-SAME: %[[CD_MEMREF:.*]]: memref<128x128xf32>
+ gpu.func @dpas(%a: memref<128x128xf16>, %b: memref<128x128xf16>, %d: memref<128x128xf32>) kernel attributes
+ {known_block_size = array<i32: 1, 64, 16>} {
+ // CHECK: %[[TDESC_A:.*]] = xegpu.create_nd_tdesc %[[A_MEMREF]] : memref<128x128xf16> ->
+ // CHECK-SAME: !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [16, 4], sg_data = [8, 32]>>
+
+ // CHECK: %[[A_LOADED:.*]] = xegpu.load_nd %[[TDESC_A]]
+ // CHECK-SAME: <{layout = #xegpu.layout<sg_layout = [16, 4], sg_data = [8, 32]>}>
+ // CHECK-SAME: {layout_result_0 = #xegpu.layout<sg_layout = [16, 4], sg_data = [8, 32]>} :
+ // CHECK-SAME: !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [16, 4], sg_data = [8, 32]>> -> vector<128x128xf16>
+
+ // CHECK: %[[TDESC_B:.*]] = xegpu.create_nd_tdesc %[[B_MEMREF]] : memref<128x128xf16> ->
+ // CHECK-SAME: !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [16, 4], sg_data = [8, 32]>>
+
+ // CHECK: %[[B_LOADED:.*]] = xegpu.load_nd %[[TDESC_B]] <{layout = #xegpu.layout<sg_layout = [16, 4], sg_data = [8, 32]>}>
+ // CHECK-SAME: {layout_result_0 = #xegpu.layout<sg_layout = [16, 4], sg_data = [8, 32]>} :
+ // CHECK-SAME: !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [16, 4], sg_data = [8, 32]>> -> vector<128x128xf16>
+
+ // CHECK: %[[DPAS_RES:.*]] = xegpu.dpas %[[A_LOADED]], %[[B_LOADED]]
+ // CHECK-SAME: {layout_a = #xegpu.layout<sg_layout = [16, 4], sg_data = [8, 32]>,
+ // CHECK-SAME: layout_b = #xegpu.layout<sg_layout = [16, 4], sg_data = [8, 32]>,
+ // CHECK-SAME: layout_result_0 = #xegpu.layout<sg_layout = [16, 4], sg_data = [8, 32]>} :
+ // CHECK-SAME: vector<128x128xf16>, vector<128x128xf16> -> vector<128x128xf32>
+
+ // CHECK: %[[TDESC_ST:.*]] = xegpu.create_nd_tdesc %[[CD_MEMREF]] : memref<128x128xf32> ->
+ // CHECK-SAME: !xegpu.tensor_desc<128x128xf32, #xegpu.layout<sg_layout = [16, 4], sg_data = [8, 32]>>
+
+ // CHECK: xegpu.store_nd %[[DPAS_RES]], %[[TDESC_ST]][0, 0]
+ // CHECK-SAME: <{layout = #xegpu.layout<sg_layout = [16, 4], sg_data = [8, 32]>}> :
+ // CHECK-SAME: vector<128x128xf32>, !xegpu.tensor_desc<128x128xf32, #xegpu.layout<sg_layout = [16, 4], sg_data = [8, 32]>>
+
+ %tdesc_a = xegpu.create_nd_tdesc %a : memref<128x128xf16> -> !xegpu.tensor_desc<128x128xf16>
+ %load_a = xegpu.load_nd %tdesc_a : !xegpu.tensor_desc<128x128xf16> -> vector<128x128xf16>
+ %tdesc_b = xegpu.create_nd_tdesc %b : memref<128x128xf16> -> !xegpu.tensor_desc<128x128xf16>
+ %load_b = xegpu.load_nd %tdesc_b : !xegpu.tensor_desc<128x128xf16> -> vector<128x128xf16>
+ %dpas = xegpu.dpas %load_a, %load_b : vector<128x128xf16>, vector<128x128xf16> -> vector<128x128xf32>
+ %tdesc_cd = xegpu.create_nd_tdesc %d : memref<128x128xf32> -> !xegpu.tensor_desc<128x128xf32>
+ xegpu.store_nd %dpas, %tdesc_cd[0, 0] : vector<128x128xf32>, !xegpu.tensor_desc<128x128xf32>
+ gpu.return
+ }
+}
>From b4d6bc31428ca310902110c7c9de1c9122881a70 Mon Sep 17 00:00:00 2001
From: Artem Kroviakov <artem.kroviakov at intel.com>
Date: Fri, 19 Dec 2025 13:07:25 +0000
Subject: [PATCH 2/3] Adjust dpas propagation
---
.../XeGPU/Transforms/XeGPUPropagateLayout.cpp | 40 +++++++++----------
.../XeGPU/propagate-layout-subgroup.mlir | 18 ++++-----
2 files changed, 27 insertions(+), 31 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 1215f567d2cd9..6b99f2cd43c71 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -800,37 +800,36 @@ void LayoutInfoPropagation::visitDpasOp(
aTy, 0, uArch, uArchInstruction->getPackedFormatBitSizeA());
dpasBLayout = getSIMTLayoutInfoForDPASOperand(
bTy, 1, uArch, uArchInstruction->getPackedFormatBitSizeB());
- } else {
+ } else { // Subgroup
auto numSgOrErr = getNumSg(dpas, subgroupSize);
if (failed(numSgOrErr)) {
dpas.emitWarning(
"Unable to determine the number of subgroups for the operation.");
return;
}
- auto layoutDataAOrErr = chooseLayout(aTy.getShape(), numSgOrErr.value());
- if (failed(layoutDataAOrErr)) {
+ auto layoutDataCDOrErr =
+ chooseLayout(dpas.getResultType().getShape(), numSgOrErr.value());
+ if (failed(layoutDataCDOrErr)) {
dpas.emitWarning(
- "Unable to determine suitable subgroup layout and data for A.");
+ "Unable to determine suitable subgroup layout and data for C/D.");
return;
}
- auto [sgLayoutA, sgDataA] = layoutDataAOrErr.value();
+ auto [sgLayoutCD, sgDataCD] = layoutDataCDOrErr.value();
+ auto sgDataA = sgDataCD;
+ sgDataA[1] = aTy.getShape()[1];
+ auto sgDataB = sgDataCD;
+ sgDataB[0] = bTy.getShape()[0];
dpasALayout = LayoutInfo(xegpu::LayoutAttr::get(
- aTy.getContext(), DenseI32ArrayAttr::get(aTy.getContext(), sgLayoutA),
+ aTy.getContext(),
+ DenseI32ArrayAttr::get(aTy.getContext(), sgLayoutCD),
DenseI32ArrayAttr::get(aTy.getContext(), sgDataA),
/*inst_data =*/nullptr, /*lane_layout =*/nullptr,
/*lane_data =*/nullptr, /*order =*/nullptr));
- auto layoutDataBOrErr = chooseLayout(bTy.getShape(), numSgOrErr.value());
- if (failed(layoutDataBOrErr)) {
- dpas.emitWarning(
- "Unable to determine suitable subgroup layout and data for B.");
- return;
- }
- auto [sgLayoutB, sgDataB] = layoutDataBOrErr.value();
-
dpasBLayout = LayoutInfo(xegpu::LayoutAttr::get(
- bTy.getContext(), DenseI32ArrayAttr::get(bTy.getContext(), sgLayoutB),
+ bTy.getContext(),
+ DenseI32ArrayAttr::get(bTy.getContext(), sgLayoutCD),
DenseI32ArrayAttr::get(bTy.getContext(), sgDataB),
/*inst_data =*/nullptr, /*lane_layout =*/nullptr,
/*lane_data =*/nullptr, /*order =*/nullptr));
@@ -862,15 +861,14 @@ void LayoutInfoPropagation::visitDpasOp(
"Unable to determine the number of subgroups for the operation.");
return;
}
- auto layoutDataAOrErr =
- chooseLayout(cTy.getShape(), numSgOrErr.value());
- if (failed(layoutDataAOrErr)) {
+ auto layoutDataCDOrErr =
+ chooseLayout(dpas.getResultType().getShape(), numSgOrErr.value());
+ if (failed(layoutDataCDOrErr)) {
dpas.emitWarning(
- "Unable to determine suitable subgroup layout and data for A.");
+ "Unable to determine suitable subgroup layout and data for C/D.");
return;
}
- auto [sgLayoutCD, sgDataCD] = layoutDataAOrErr.value();
-
+ auto [sgLayoutCD, sgDataCD] = layoutDataCDOrErr.value();
dpasCDLayout = LayoutInfo(xegpu::LayoutAttr::get(
cTy.getContext(),
DenseI32ArrayAttr::get(cTy.getContext(), sgLayoutCD),
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir b/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir
index b6dfbe79e2712..13f92f12331a2 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir
@@ -89,23 +89,21 @@ gpu.module @test {
gpu.func @dpas(%a: memref<128x128xf16>, %b: memref<128x128xf16>, %d: memref<128x128xf32>) kernel attributes
{known_block_size = array<i32: 1, 64, 16>} {
// CHECK: %[[TDESC_A:.*]] = xegpu.create_nd_tdesc %[[A_MEMREF]] : memref<128x128xf16> ->
- // CHECK-SAME: !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [16, 4], sg_data = [8, 32]>>
+ // CHECK-SAME: !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [16, 4], sg_data = [8, 128]>>
// CHECK: %[[A_LOADED:.*]] = xegpu.load_nd %[[TDESC_A]]
- // CHECK-SAME: <{layout = #xegpu.layout<sg_layout = [16, 4], sg_data = [8, 32]>}>
- // CHECK-SAME: {layout_result_0 = #xegpu.layout<sg_layout = [16, 4], sg_data = [8, 32]>} :
- // CHECK-SAME: !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [16, 4], sg_data = [8, 32]>> -> vector<128x128xf16>
+ // CHECK-SAME: <{layout = #xegpu.layout<sg_layout = [16, 4], sg_data = [8, 128]>}>
+ // CHECK-SAME: : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [16, 4], sg_data = [8, 128]>> -> vector<128x128xf16>
// CHECK: %[[TDESC_B:.*]] = xegpu.create_nd_tdesc %[[B_MEMREF]] : memref<128x128xf16> ->
- // CHECK-SAME: !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [16, 4], sg_data = [8, 32]>>
+ // CHECK-SAME: !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [16, 4], sg_data = [128, 32]>>
- // CHECK: %[[B_LOADED:.*]] = xegpu.load_nd %[[TDESC_B]] <{layout = #xegpu.layout<sg_layout = [16, 4], sg_data = [8, 32]>}>
- // CHECK-SAME: {layout_result_0 = #xegpu.layout<sg_layout = [16, 4], sg_data = [8, 32]>} :
- // CHECK-SAME: !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [16, 4], sg_data = [8, 32]>> -> vector<128x128xf16>
+ // CHECK: %[[B_LOADED:.*]] = xegpu.load_nd %[[TDESC_B]] <{layout = #xegpu.layout<sg_layout = [16, 4], sg_data = [128, 32]>}>
+ // CHECK-SAME: : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [16, 4], sg_data = [128, 32]>> -> vector<128x128xf16>
// CHECK: %[[DPAS_RES:.*]] = xegpu.dpas %[[A_LOADED]], %[[B_LOADED]]
- // CHECK-SAME: {layout_a = #xegpu.layout<sg_layout = [16, 4], sg_data = [8, 32]>,
- // CHECK-SAME: layout_b = #xegpu.layout<sg_layout = [16, 4], sg_data = [8, 32]>,
+ // CHECK-SAME: {layout_a = #xegpu.layout<sg_layout = [16, 4], sg_data = [8, 128]>,
+ // CHECK-SAME: layout_b = #xegpu.layout<sg_layout = [16, 4], sg_data = [128, 32]>,
// CHECK-SAME: layout_result_0 = #xegpu.layout<sg_layout = [16, 4], sg_data = [8, 32]>} :
// CHECK-SAME: vector<128x128xf16>, vector<128x128xf16> -> vector<128x128xf32>
>From a0ea277f2a65b69e6fc8b8f5929d6bfd3e162d39 Mon Sep 17 00:00:00 2001
From: Artem Kroviakov <artem.kroviakov at intel.com>
Date: Thu, 22 Jan 2026 10:52:38 +0000
Subject: [PATCH 3/3] Try to pick a common layout if possible
---
.../XeGPU/Transforms/XeGPUPropagateLayout.cpp | 212 ++++++++++++------
1 file changed, 148 insertions(+), 64 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 6b99f2cd43c71..15f189a7a77d3 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -531,6 +531,38 @@ bool LayoutInfoPropagation::hasParamsOfLayoutKind(
return false;
}
+/// Enumerates 2-D subgroup layouts `(sgLayout0, sgLayout1)` with
+/// `sgLayout0 * sgLayout1 == sgCount` such that `wgShape` splits evenly
+/// across the subgroup grid and each resulting per-subgroup tile is a whole
+/// multiple of `instData`.
+///
+/// Candidates are ordered by balance (smallest |sgLayout0 - sgLayout1|
+/// first), ties broken by ascending sgLayout0. Returns an empty vector when
+/// no candidate exists or the inputs are unusable: `wgShape` or `instData`
+/// is not 2-D, an `instData` entry is not positive, or `sgCount` is not
+/// positive. Callers treat an empty result as "no valid layout".
+static SmallVector<std::pair<int, int>>
+getValidLayouts(ArrayRef<int64_t> wgShape, ArrayRef<int> instData,
+                int64_t sgCount) {
+  SmallVector<std::pair<int, int>> candidates;
+  // Only 2-D shapes are handled; bail out instead of indexing past the end of
+  // the shape/instData arrays (e.g. for 1-D store_nd values) or taking a
+  // modulo by zero.
+  if (wgShape.size() != 2 || instData.size() != 2 || sgCount <= 0 ||
+      instData[0] <= 0 || instData[1] <= 0)
+    return candidates;
+
+  for (int sgLayout0 = 1; sgLayout0 <= sgCount; ++sgLayout0) {
+    if (sgCount % sgLayout0 != 0)
+      continue;
+    const int sgLayout1 = static_cast<int>(sgCount / sgLayout0);
+    // The workgroup shape must split evenly across the subgroup grid...
+    if (wgShape[0] % sgLayout0 != 0 || wgShape[1] % sgLayout1 != 0)
+      continue;
+    const int64_t sgData0 = wgShape[0] / sgLayout0;
+    const int64_t sgData1 = wgShape[1] / sgLayout1;
+    // ...and each subgroup tile must be a whole number of instruction tiles.
+    if (sgData0 % instData[0] != 0 || sgData1 % instData[1] != 0)
+      continue;
+    candidates.emplace_back(sgLayout0, sgLayout1);
+  }
+
+  // Prefer balanced grids; break ties by the smaller first dimension so the
+  // order is deterministic.
+  llvm::sort(candidates, [](const std::pair<int, int> &lhs,
+                            const std::pair<int, int> &rhs) {
+    const int diffLhs = std::abs(lhs.first - lhs.second);
+    const int diffRhs = std::abs(rhs.first - rhs.second);
+    if (diffLhs != diffRhs)
+      return diffLhs < diffRhs;
+    return lhs.first < rhs.first;
+  });
+  return candidates;
+}
+
FailureOr<std::pair<SmallVector<int>, SmallVector<int>>>
chooseLayout(llvm::ArrayRef<int64_t> wgShape, int64_t sgCount) {
const size_t rank = wgShape.size();
@@ -763,6 +795,10 @@ void LayoutInfoPropagation::visitDpasOp(
VectorType aTy = dpas.getLhsType();
VectorType bTy = dpas.getRhsType();
+ VectorType cTy;
+ const bool hasAcc = operands.size() > 2;
+ if (hasAcc)
+ cTy = dpas.getAccType();
auto uArch = getUArch(getChipStr(dpas).value_or(""));
const int subgroupSize = uArch->getSubgroupSize();
@@ -789,17 +825,38 @@ void LayoutInfoPropagation::visitDpasOp(
"No suitable instruction multiple found for the given shape.");
SmallVector<int> instDataA = {maxALen, subgroupSize};
SmallVector<int> instDataB = {subgroupSize, maxBLen};
-
+ SmallVector<int> instDataCD;
+ if (hasAcc) {
+ const unsigned dataCLen = bTy.getShape().back();
+ auto supportedCLen =
+ uArchInstruction->getSupportedN(cTy.getElementType());
+ const int maxCLen =
+ xegpu::getLargestDivisor(dataCLen, ArrayRef<unsigned>(supportedCLen));
+ if (maxCLen == -1) {
+ dpas.emitWarning(
+ "No suitable instruction multiple found for the given shape.");
+ return;
+ }
+ instDataCD = {maxALen, maxCLen};
+ }
if (layoutKind == LayoutKind::InstData) {
dpasALayout =
LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataA));
dpasBLayout =
LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataB));
+ if (hasAcc) {
+ dpasCDLayout =
+ LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataCD));
+ }
} else if (layoutKind == LayoutKind::Lane) {
dpasALayout = getSIMTLayoutInfoForDPASOperand(
aTy, 0, uArch, uArchInstruction->getPackedFormatBitSizeA());
dpasBLayout = getSIMTLayoutInfoForDPASOperand(
bTy, 1, uArch, uArchInstruction->getPackedFormatBitSizeB());
+ if (hasAcc) {
+ dpasCDLayout = getSIMTLayoutInfoForDPASOperand(
+ cTy, 2, uArch, uArchInstruction->getPackedFormatBitSizeB());
+ }
} else { // Subgroup
auto numSgOrErr = getNumSg(dpas, subgroupSize);
if (failed(numSgOrErr)) {
@@ -807,83 +864,108 @@ void LayoutInfoPropagation::visitDpasOp(
"Unable to determine the number of subgroups for the operation.");
return;
}
- auto layoutDataCDOrErr =
- chooseLayout(dpas.getResultType().getShape(), numSgOrErr.value());
- if (failed(layoutDataCDOrErr)) {
+ // Step 1. Retrieve D layout. Get all valid layouts for A and B
+ LayoutInfo layoutD = results[0]->getValue();
+ SmallVector<int> sgLayoutD = layoutD.getSgLayout();
+ assert(!sgLayoutD.empty() && "Expected layout for DPAS result.");
+ auto layoutDVal = std::make_pair(sgLayoutD[0], sgLayoutD[1]);
+
+ auto layoutsA =
+ getValidLayouts(aTy.getShape(), instDataA, numSgOrErr.value());
+ auto layoutsB =
+ getValidLayouts(bTy.getShape(), instDataB, numSgOrErr.value());
+ SmallVector<std::pair<int, int>> layoutsC;
+ if (hasAcc)
+ layoutsC =
+ getValidLayouts(cTy.getShape(), instDataCD, numSgOrErr.value());
+
+ if (layoutsA.empty() || layoutsB.empty() ||
+ (hasAcc && layoutsC.empty())) {
dpas.emitWarning(
- "Unable to determine suitable subgroup layout and data for C/D.");
+ "Unable to determine suitable subgroup layout for A/B/C matrices.");
return;
}
- auto [sgLayoutCD, sgDataCD] = layoutDataCDOrErr.value();
- auto sgDataA = sgDataCD;
- sgDataA[1] = aTy.getShape()[1];
- auto sgDataB = sgDataCD;
- sgDataB[0] = bTy.getShape()[0];
+
+ // Step 2. Find common layouts.
+ // Ideally, we want to find given layout D in all A, B and C candidates.
+
+ // Ensure D layout matches one of C layouts.
+ if (hasAcc && llvm::find(layoutsC, layoutDVal) == layoutsC.end()) {
+ dpas.emitWarning("Subgroup layout for D does not match any valid C "
+ "subgroup layout.");
+ return;
+ }
+ // The best pick is layout D. If not found, we will pick any common layout
+ // between A and B.
+ std::optional<std::pair<int, int>> bestPick;
+ llvm::DenseSet<std::pair<int, int>> setA(layoutsA.begin(),
+ layoutsA.end());
+ SmallVector<std::pair<int, int>> common;
+ for (auto &l : layoutsB) {
+ if (setA.contains(l)) {
+ if (l == layoutDVal) {
+ bestPick = l;
+ break;
+ }
+ common.push_back(l);
+ }
+ }
+ // Step 3. The best pick either matches D or is any common layout between
+ // A and B. If no common layout, warn and pick any valid layout.
+ SmallVector<int> sgLayoutA;
+ SmallVector<int> sgLayoutB;
+ if (!bestPick && !common.empty())
+ bestPick = common[0];
+ if (bestPick) {
+ sgLayoutA = {bestPick->first, bestPick->second};
+ sgLayoutB = sgLayoutA;
+ } else {
+ dpas.emitWarning(
+ "Unable to find common subgroup layout for matrices matching "
+ "layout of result. Picking any valid layout.");
+ sgLayoutA = {layoutsA[0].first, layoutsA[0].second};
+ sgLayoutB = {layoutsB[0].first, layoutsB[0].second};
+ }
+ SmallVector<int> sgDataA = {
+ static_cast<int>(aTy.getShape()[0]) / sgLayoutA[0],
+ static_cast<int>(aTy.getShape()[1]) / sgLayoutA[1]};
+ SmallVector<int> sgDataB = {
+ static_cast<int>(bTy.getShape()[0]) / sgLayoutB[0],
+ static_cast<int>(bTy.getShape()[1]) / sgLayoutB[1]};
+ SmallVector<int> sgDataC;
+ if (hasAcc)
+ sgDataC = {static_cast<int>(dpas.getResultType().getShape()[0]) /
+ sgLayoutD[0],
+ static_cast<int>(dpas.getResultType().getShape()[1]) /
+ sgLayoutD[1]};
dpasALayout = LayoutInfo(xegpu::LayoutAttr::get(
- aTy.getContext(),
- DenseI32ArrayAttr::get(aTy.getContext(), sgLayoutCD),
+ aTy.getContext(), DenseI32ArrayAttr::get(aTy.getContext(), sgLayoutA),
DenseI32ArrayAttr::get(aTy.getContext(), sgDataA),
/*inst_data =*/nullptr, /*lane_layout =*/nullptr,
/*lane_data =*/nullptr, /*order =*/nullptr));
dpasBLayout = LayoutInfo(xegpu::LayoutAttr::get(
- bTy.getContext(),
- DenseI32ArrayAttr::get(bTy.getContext(), sgLayoutCD),
+ bTy.getContext(), DenseI32ArrayAttr::get(bTy.getContext(), sgLayoutB),
DenseI32ArrayAttr::get(bTy.getContext(), sgDataB),
/*inst_data =*/nullptr, /*lane_layout =*/nullptr,
/*lane_data =*/nullptr, /*order =*/nullptr));
- }
-
- if (operands.size() > 2) {
- VectorType cTy = dpas.getAccType();
- const unsigned dataCLen = bTy.getShape().back();
- auto supportedCLen =
- uArchInstruction->getSupportedN(bTy.getElementType());
- const int maxCLen =
- xegpu::getLargestDivisor(dataCLen, ArrayRef<unsigned>(supportedCLen));
- if (maxCLen == -1) {
- dpas.emitWarning(
- "No suitable instruction multiple found for the given shape.");
- return;
- }
- SmallVector<int> instDataCD = {maxALen, maxCLen};
- if (layoutKind == LayoutKind::InstData) {
- dpasCDLayout =
- LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataCD));
- } else if (layoutKind == LayoutKind::Lane) {
- dpasCDLayout = getSIMTLayoutInfoForDPASOperand(
- cTy, 2, uArch, uArchInstruction->getPackedFormatBitSizeB());
- } else {
- auto numSgOrErr = getNumSg(dpas, subgroupSize);
- if (failed(numSgOrErr)) {
- dpas.emitWarning(
- "Unable to determine the number of subgroups for the operation.");
- return;
- }
- auto layoutDataCDOrErr =
- chooseLayout(dpas.getResultType().getShape(), numSgOrErr.value());
- if (failed(layoutDataCDOrErr)) {
- dpas.emitWarning(
- "Unable to determine suitable subgroup layout and data for C/D.");
- return;
- }
- auto [sgLayoutCD, sgDataCD] = layoutDataCDOrErr.value();
+ if (hasAcc) {
dpasCDLayout = LayoutInfo(xegpu::LayoutAttr::get(
cTy.getContext(),
- DenseI32ArrayAttr::get(cTy.getContext(), sgLayoutCD),
- DenseI32ArrayAttr::get(cTy.getContext(), sgDataCD),
+ DenseI32ArrayAttr::get(cTy.getContext(), sgLayoutD),
+ DenseI32ArrayAttr::get(cTy.getContext(), sgDataC),
/*inst_data =*/nullptr, /*lane_layout =*/nullptr,
/*lane_data =*/nullptr, /*order =*/nullptr));
}
-
- dpas.setLayoutCdAttr(
- dyn_cast<xegpu::DistributeLayoutAttr>(dpasCDLayout.get()));
}
dpas.setLayoutAAttr(
dyn_cast<xegpu::DistributeLayoutAttr>(dpasALayout.get()));
dpas.setLayoutBAttr(
dyn_cast<xegpu::DistributeLayoutAttr>(dpasBLayout.get()));
+ if (hasAcc)
+ dpas.setLayoutCdAttr(
+ dyn_cast<xegpu::DistributeLayoutAttr>(dpasCDLayout.get()));
}
propagateIfChanged(operands[0], operands[0]->meet(dpasALayout));
@@ -935,9 +1017,9 @@ void LayoutInfoPropagation::visitStoreNdOp(
storeLayout =
LayoutInfo(xegpu::LayoutAttr::get(dataTy.getContext(), instData));
else if (layoutKind == LayoutKind::Lane)
- storeLayout =
- getSIMTLayoutInforForBlockIO(store.getValueType(), uArch,
- uArchInstruction->getPackedFormatBitSize());
+ storeLayout = getSIMTLayoutInforForBlockIO(
+ store.getValueType(), uArch,
+ uArchInstruction->getPackedFormatBitSize());
else { // LayoutKind::Subgroup
auto sgSize = uArch->getSubgroupSize();
auto numSgOrErr = getNumSg(store, sgSize);
@@ -946,15 +1028,17 @@ void LayoutInfoPropagation::visitStoreNdOp(
"Unable to determine the number of subgroups for the operation.");
return;
}
- auto layoutDataOrErr =
- chooseLayout(dataTy.getShape(), numSgOrErr.value());
- if (failed(layoutDataOrErr)) {
+ auto sgLayouts = getValidLayouts(store.getValueType().getShape(),
+ instData, numSgOrErr.value());
+ if (sgLayouts.empty()) {
store.emitWarning(
- "Unable to determine suitable subgroup layout and data.");
+ "Unable to determine suitable subgroup layout for store value.");
return;
}
- auto [sgLayout, sgData] = layoutDataOrErr.value();
-
+ SmallVector<int> sgLayout = {sgLayouts[0].first, sgLayouts[0].second};
+ SmallVector<int> sgData = {
+ static_cast<int>(dataTy.getShape()[0]) / sgLayout[0],
+ static_cast<int>(dataTy.getShape()[1]) / sgLayout[1]};
storeLayout = LayoutInfo(xegpu::LayoutAttr::get(
dataTy.getContext(),
DenseI32ArrayAttr::get(dataTy.getContext(), sgLayout),
More information about the Mlir-commits
mailing list