[Mlir-commits] [mlir] [MLIR][XeGPU] Add sg layout propagation (PR #170879)
Artem Kroviakov
llvmlistbot at llvm.org
Tue Dec 9 07:02:58 PST 2025
https://github.com/akroviakov updated https://github.com/llvm/llvm-project/pull/170879
>From ea527e6a1714b579d1906c68f00e75b4e000de12 Mon Sep 17 00:00:00 2001
From: Artem Kroviakov <artem.kroviakov at intel.com>
Date: Fri, 5 Dec 2025 16:08:04 +0000
Subject: [PATCH 1/2] [MLIR][XeGPU] Add sg layout propagation
---
.../mlir/Dialect/XeGPU/Transforms/Passes.td | 6 +-
.../XeGPU/Transforms/XeGPUPropagateLayout.cpp | 59 +++++++++++++++++--
.../Dialect/XeGPU/propagate-layout-sg.mlir | 53 +++++++++++++++++
3 files changed, 112 insertions(+), 6 deletions(-)
create mode 100644 mlir/test/Dialect/XeGPU/propagate-layout-sg.mlir
diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
index 0ca58426ecfcb..c682e6fdad1df 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
@@ -50,6 +50,10 @@ def XeGPUPropagateLayout : Pass<"xegpu-propagate-layout"> {
- `lane`
Propagate the `lane_layout` and `lane_data` fields of the layout attribute.
Default values are selected to align with hardware.
+
+ - `sg`
+ Propagate the `sg_layout` and `sg_data` fields of the layout attribute.
+ Values are derived from the hardware instruction shapes and the subgroup count; the latter requires `known_block_size` on the enclosing `gpu.func`.
}];
let dependentDialects = ["memref::MemRefDialect", "xegpu::XeGPUDialect",
"vector::VectorDialect"];
@@ -60,7 +64,7 @@ def XeGPUPropagateLayout : Pass<"xegpu-propagate-layout"> {
Option<
"layoutKind", "layout-kind", "std::string",
/*default=*/"\"lane\"",
- "Propagate `inst` / `lane` level of xegpu layouts.">
+ "Propagate `sg` / `inst` / `lane` level of xegpu layouts.">
];
}
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 59a1ad9dbe189..6eadce67f8202 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -53,7 +53,7 @@ using namespace mlir::dataflow;
namespace {
-enum class LayoutKind { Lane, InstData };
+/// Which level of the XeGPU layout the pass propagates: `lane_layout` /
+/// `lane_data` (Lane), `inst_data` (InstData), or `sg_layout` / `sg_data`
+/// (Subgroup). Selected by the `layout-kind` pass option.
+enum class LayoutKind { Lane, InstData, Subgroup };
//===----------------------------------------------------------------------===//
// LayoutInfo
@@ -109,6 +109,12 @@ struct LayoutInfo {
SmallVector<int> getInstData() const;
+ SmallVector<int> getSgLayout() const;
+
+ SmallVector<int> getSgData() const;
+
+ SmallVector<int> getOrder() const;
+
bool isSliceLayout() const {
if (!isAssigned())
return false;
@@ -127,8 +133,6 @@ struct LayoutInfo {
+/// Returns the effective lane layout as `int`s. The result is empty when the
+/// layout is unassigned or carries no lane parameters (e.g. a subgroup-level
+/// layout), so callers must check the size before indexing.
SmallVector<int> LayoutInfo::getLaneLayout() const {
if (!isAssigned())
return {};
- assert(storage.getEffectiveLaneLayoutAsInt().size() &&
- "Expected lane layout to be assigned");
return llvm::map_to_vector(storage.getEffectiveLaneLayoutAsInt(),
[](int64_t val) { return static_cast<int>(val); });
}
@@ -136,8 +140,6 @@ SmallVector<int> LayoutInfo::getLaneLayout() const {
+/// Returns the effective lane data as `int`s; empty under the same conditions
+/// as getLaneLayout().
SmallVector<int> LayoutInfo::getLaneData() const {
if (!isAssigned())
return {};
- assert(storage.getEffectiveLaneDataAsInt().size() &&
- "Expected lane data to be assigned");
return llvm::map_to_vector(storage.getEffectiveLaneDataAsInt(),
[](int64_t val) { return static_cast<int>(val); });
}
@@ -149,6 +151,27 @@ SmallVector<int> LayoutInfo::getInstData() const {
[](int64_t val) { return static_cast<int>(val); });
}
+/// Returns the effective subgroup layout (`sg_layout`) as `int`s, or an empty
+/// vector when the lattice value is unassigned or has no subgroup layout.
+SmallVector<int> LayoutInfo::getSgLayout() const {
+  if (isAssigned())
+    return llvm::map_to_vector(
+        storage.getEffectiveSgLayoutAsInt(),
+        [](int64_t extent) { return static_cast<int>(extent); });
+  return {};
+}
+
+/// Returns the effective per-subgroup data shape (`sg_data`) as `int`s, or an
+/// empty vector when the lattice value is unassigned or has no subgroup data.
+SmallVector<int> LayoutInfo::getSgData() const {
+  if (isAssigned())
+    return llvm::map_to_vector(
+        storage.getEffectiveSgDataAsInt(),
+        [](int64_t extent) { return static_cast<int>(extent); });
+  return {};
+}
+
+/// Returns the layout's `order` attribute as `int`s. `order` is optional even
+/// on assigned layouts, so the result is empty when it is absent.
+SmallVector<int> LayoutInfo::getOrder() const {
+  if (isAssigned()) {
+    if (auto orderAttr = storage.getOrder())
+      return llvm::map_to_vector(
+          orderAttr.asArrayRef(),
+          [](int64_t dim) { return static_cast<int>(dim); });
+  }
+  return {};
+}
+
void LayoutInfo::print(raw_ostream &os) const {
if (isAssigned()) {
os << storage;
@@ -188,6 +211,10 @@ LayoutInfo LayoutInfo::transpose(ArrayRef<int64_t> permutation) const {
SmallVector<int32_t> laneLayout;
SmallVector<int32_t> laneData;
SmallVector<int32_t> instData;
+  SmallVector<int32_t> sgLayout;
+  SmallVector<int32_t> sgData;
+  SmallVector<int32_t> order;
+  // The accessors materialize a fresh vector on every call; query the
+  // subgroup-level fields once instead of once per dimension.
+  const SmallVector<int> srcSgLayout = getSgLayout();
+  const SmallVector<int> srcSgData = getSgData();
+  const SmallVector<int> srcOrder = getOrder();
+
   for (int64_t idx : permutation) {
     if (getLaneLayout().size()) {
       laneLayout.push_back(static_cast<int32_t>(getLaneLayout()[idx]));
@@ -195,13 +222,30 @@ LayoutInfo LayoutInfo::transpose(ArrayRef<int64_t> permutation) const {
     }
     if (getInstData().size())
       instData.push_back(static_cast<int32_t>(getInstData()[idx]));
+    // sg_layout and sg_data hold one extent per dimension, so they are
+    // permuted by position like lane_layout and inst_data. Each is checked on
+    // its own so neither is indexed when absent and an sg_layout-only layout
+    // is not dropped.
+    if (!srcSgLayout.empty())
+      sgLayout.push_back(static_cast<int32_t>(srcSgLayout[idx]));
+    if (!srcSgData.empty())
+      sgData.push_back(static_cast<int32_t>(srcSgData[idx]));
   }
+  // `order` lists dimension indices (fastest-varying first), so its entries
+  // keep their positions and are renamed instead: old dimension
+  // `permutation[i]` becomes new dimension `i`. Permuting `order` by position
+  // coincides with this for every 2-D transpose but not in general (e.g.
+  // order [0, 1, 2] under permutation [1, 2, 0] must become [2, 0, 1]).
+  if (srcOrder.size() == permutation.size()) {
+    SmallVector<int32_t> newDimOf(permutation.size());
+    for (size_t newDim = 0, e = permutation.size(); newDim < e; ++newDim)
+      newDimOf[permutation[newDim]] = static_cast<int32_t>(newDim);
+    for (int oldDim : srcOrder)
+      order.push_back(newDimOf[oldDim]);
+  }
+  // Empty fields become null attributes so the LayoutAttr omits them.
+  auto toArrayAttrOrNull = [&](ArrayRef<int32_t> values) -> DenseI32ArrayAttr {
+    if (values.empty())
+      return nullptr;
+    return DenseI32ArrayAttr::get(storage.getContext(), values);
+  };
   xegpu::LayoutAttr layoutAttr;
   if (getLaneLayout().size())
     layoutAttr =
         xegpu::LayoutAttr::get(storage.getContext(), laneLayout, laneData);
   if (getInstData().size())
     layoutAttr = xegpu::LayoutAttr::get(storage.getContext(), instData);
+  if (!sgLayout.empty() || !sgData.empty())
+    layoutAttr = xegpu::LayoutAttr::get(
+        storage.getContext(), toArrayAttrOrNull(sgLayout),
+        toArrayAttrOrNull(sgData),
+        /*inst_data =*/nullptr, /*lane_layout =*/nullptr,
+        /*lane_data =*/nullptr, toArrayAttrOrNull(order));
return LayoutInfo(layoutAttr);
}
@@ -487,6 +531,9 @@ bool LayoutInfoPropagation::hasParamsOfLayoutKind(
} else if (layoutKind == LayoutKind::Lane) {
return !(anchorLayout.getEffectiveLaneLayoutAsInt().empty() ||
anchorLayout.getEffectiveLaneDataAsInt().empty());
+ } else if (layoutKind == LayoutKind::Subgroup) {
+ return !(anchorLayout.getEffectiveSgLayoutAsInt().empty() ||
+ anchorLayout.getEffectiveSgDataAsInt().empty());
}
return false;
}
@@ -1295,6 +1342,8 @@ void XeGPUPropagateLayoutPass::runOnOperation() {
layoutKind = LayoutKind::Lane;
} else if (this->layoutKind == "inst") {
layoutKind = LayoutKind::InstData;
+ } else if (this->layoutKind == "sg") {
+ layoutKind = LayoutKind::Subgroup;
} else {
getOperation()->emitError("Unsupported layout kind option: " +
this->layoutKind);
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout-sg.mlir b/mlir/test/Dialect/XeGPU/propagate-layout-sg.mlir
new file mode 100644
index 0000000000000..5659e9995b22a
--- /dev/null
+++ b/mlir/test/Dialect/XeGPU/propagate-layout-sg.mlir
@@ -0,0 +1,53 @@
+// RUN: mlir-opt -xevm-attach-target='chip=pvc' -xegpu-propagate-layout="layout-kind=sg" -split-input-file %s | FileCheck %s
+
+gpu.module @test {
+ // Only store_nd carries an explicit layout; it must reach load_nd and
+ // create_nd_tdesc unchanged.
+ // CHECK-LABEL: store_nd
+ // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
+ func.func @store_nd(%src: memref<256x128xf32>) {
+ // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]] : memref<256x128xf32>
+ // CHECK-SAME: -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>>
+ // CHECK: %[[LOAD:.*]] = xegpu.load_nd %[[TDESC]] <{layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>}>
+ // CHECK-SAME: {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>}
+ // CHECK-SAME: : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>>
+ // CHECK-SAME: -> vector<256x128xf32>
+ // CHECK: xegpu.store_nd %[[LOAD]], %[[TDESC]] <{layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>}>
+ // CHECK-SAME: : vector<256x128xf32>, !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>>
+ %tdesc = xegpu.create_nd_tdesc %src : memref<256x128xf32> -> !xegpu.tensor_desc<256x128xf32>
+ %load = xegpu.load_nd %tdesc : !xegpu.tensor_desc<256x128xf32> -> vector<256x128xf32>
+ xegpu.store_nd %load, %tdesc {layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>}
+ : vector<256x128xf32>, !xegpu.tensor_desc<256x128xf32>
+ return
+ }
+}
+
+// -----
+
+gpu.module @test {
+  // CHECK-LABEL: vector_transpose
+  // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
+  // CHECK-SAME: %[[ARG_1:.*]]: memref<128x256xf32>
+  func.func @vector_transpose(%src: memref<256x128xf32>, %src1: memref<128x256xf32>) {
+    // CHECK: %[[TDESC_LD:.*]] = xegpu.create_nd_tdesc %[[ARG_0]] : memref<256x128xf32> ->
+    // CHECK-SAME: !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [4, 8], sg_data = [64, 32], order = [0, 1]>>
+    // CHECK: %[[TDESC_ST:.*]] = xegpu.create_nd_tdesc %[[ARG_1]] : memref<128x256xf32> ->
+    // CHECK-SAME: !xegpu.tensor_desc<128x256xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], order = [1, 0]>>
+
+    // CHECK: %[[LOAD:.*]] = xegpu.load_nd %[[TDESC_LD]][0, 0] <{layout = #xegpu.layout<sg_layout = [4, 8], sg_data = [64, 32], order = [0, 1]>}>
+    // CHECK-SAME: {layout_result_0 = #xegpu.layout<sg_layout = [4, 8], sg_data = [64, 32], order = [0, 1]>} :
+    // CHECK-SAME: !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [4, 8], sg_data = [64, 32], order = [0, 1]>> -> vector<256x128xf32>
+
+    // The transpose result carries the explicit store_nd layout; the load
+    // side receives its transposition. Match the load by name, not by number.
+    // CHECK: %[[TRANSPOSED:.*]] = vector.transpose %[[LOAD]], [1, 0]
+    // CHECK-SAME: {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], order = [1, 0]>} : vector<256x128xf32> to vector<128x256xf32>
+
+    // CHECK: xegpu.store_nd %[[TRANSPOSED]], %[[TDESC_ST]][0, 0]
+    // CHECK-SAME: <{layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], order = [1, 0]>}> : vector<128x256xf32>,
+    // CHECK-SAME: !xegpu.tensor_desc<128x256xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], order = [1, 0]>>
+    %tdesc = xegpu.create_nd_tdesc %src : memref<256x128xf32> -> !xegpu.tensor_desc<256x128xf32>
+    %tdesc1 = xegpu.create_nd_tdesc %src1 : memref<128x256xf32> -> !xegpu.tensor_desc<128x256xf32>
+    %load = xegpu.load_nd %tdesc[0, 0] : !xegpu.tensor_desc<256x128xf32> -> vector<256x128xf32>
+    %trans = vector.transpose %load, [1, 0] : vector<256x128xf32> to vector<128x256xf32>
+    xegpu.store_nd %trans, %tdesc1[0, 0] {layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], order = [1, 0]>}
+        : vector<128x256xf32>, !xegpu.tensor_desc<128x256xf32>
+    return
+  }
+}
>From 2925b5cbbe6918ef659f87f652542e64fcde87a6 Mon Sep 17 00:00:00 2001
From: Artem Kroviakov <artem.kroviakov at intel.com>
Date: Tue, 9 Dec 2025 15:02:41 +0000
Subject: [PATCH 2/2] Sg layout creation
---
.../XeGPU/Transforms/XeGPUPropagateLayout.cpp | 189 ++++++++++++++++--
.../Dialect/XeGPU/propagate-layout-sg.mlir | 78 ++++++++
2 files changed, 251 insertions(+), 16 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 6eadce67f8202..254940d0466bc 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -538,6 +538,56 @@ bool LayoutInfoPropagation::hasParamsOfLayoutKind(
return false;
}
+/// Chooses a subgroup layout (`sg_layout`) and per-subgroup data shape
+/// (`sg_data`) that distribute a tensor of shape `big` over exactly `count`
+/// subgroups.
+///
+/// The tensor is first tiled with the instruction-level shape `small`, one
+/// tile per subgroup. While there are more tiles than subgroups, the layout
+/// dimension holding the most tiles (the first one on ties) is halved and the
+/// matching data dimension doubled, so `sg_layout[i] * sg_data[i] == big[i]`
+/// holds throughout.
+///
+/// Returns failure instead of asserting (the checks must also hold in release
+/// builds) when the ranks differ, `count` is not positive, `small` does not
+/// evenly divide `big`, an odd tile count would have to be halved, or the
+/// shape has fewer tiles than `count`.
+static FailureOr<std::pair<SmallVector<int>, SmallVector<int>>>
+chooseLayout(llvm::ArrayRef<int64_t> big, llvm::ArrayRef<int> small,
+             const int64_t count) {
+  if (big.size() != small.size() || count <= 0)
+    return failure();
+  const size_t rank = big.size();
+  // Fill the large shape with instruction-sized tiles.
+  SmallVector<int> layout(rank);
+  for (size_t dim = 0; dim < rank; ++dim) {
+    if (small[dim] <= 0 || big[dim] % small[dim] != 0)
+      return failure();
+    layout[dim] = big[dim] / small[dim];
+  }
+  SmallVector<int> data(small.begin(), small.end());
+  int64_t totalTiles = llvm::product_of(layout);
+  // Merge tiles pairwise along the longest layout dimension until the tile
+  // count no longer exceeds the subgroup count.
+  while (totalTiles > count) {
+    size_t maxDim = 0;
+    for (size_t dim = 1; dim < rank; ++dim)
+      if (layout[dim] > layout[maxDim])
+        maxDim = dim;
+    // Halving an odd tile count (including 1) would leave part of the tensor
+    // uncovered, so give up instead.
+    if (layout[maxDim] % 2 != 0)
+      return failure();
+    layout[maxDim] /= 2;
+    data[maxDim] *= 2;
+    totalTiles = llvm::product_of(layout);
+  }
+  // Fewer tiles than subgroups: the shape cannot occupy every subgroup.
+  if (totalTiles != count)
+    return failure();
+  return std::make_pair(std::move(layout), std::move(data));
+}
+
+/// Returns the number of subgroups in the workgroup executing `op`: the total
+/// work-item count from the enclosing `gpu.func`'s `known_block_size` divided
+/// by `sgSize`.
+///
+/// Returns failure when `op` has no enclosing `gpu.func`, the function does
+/// not declare `known_block_size`, `sgSize` is not positive, or the block does
+/// not consist of a whole number of subgroups.
+static FailureOr<int64_t> getNumSg(Operation *op, const int sgSize) {
+  if (sgSize <= 0)
+    return failure();
+  // Oblivious to the work-item layout; only the total count matters.
+  auto gpuFunc = op->getParentOfType<gpu::GPUFuncOp>();
+  if (!gpuFunc)
+    return failure();
+  auto knownBlockSize = gpuFunc.getKnownBlockSize();
+  if (!knownBlockSize.has_value())
+    return failure();
+  const int64_t flatBlockSize = llvm::product_of(*knownBlockSize);
+  // A partial subgroup would make the computed count (and hence any chosen
+  // sg_layout) silently wrong.
+  if (flatBlockSize <= 0 || flatBlockSize % sgSize != 0)
+    return failure();
+  return flatBlockSize / sgSize;
+}
+
void LayoutInfoPropagation::visitPrefetchNdOp(
xegpu::PrefetchNdOp prefetch, ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results) {
@@ -582,10 +632,33 @@ void LayoutInfoPropagation::visitPrefetchNdOp(
if (layoutKind == LayoutKind::InstData)
prefetchLayout =
LayoutInfo(xegpu::LayoutAttr::get(tdescTy.getContext(), instData));
- else
+ else if (layoutKind == LayoutKind::Lane)
prefetchLayout = getDefaultSIMTLayoutInfo(
tdescTy, uArch, uArchInstruction->getPackedFormatBitSize());
-
+ else { // LayoutKind::Subgroup
+ auto sgSize = uArch->getSubgroupSize();
+ auto numSgOrErr = getNumSg(prefetch, sgSize);
+ if (failed(numSgOrErr)) {
+ prefetch.emitWarning(
+ "Unable to determine the number of subgroups for the operation.");
+ return;
+ }
+ auto layoutDataOrErr =
+ chooseLayout(tdescTy.getShape(), instData, numSgOrErr.value());
+ if (failed(layoutDataOrErr)) {
+ prefetch.emitWarning(
+ "Unable to determine suitable subgroup layout and data.");
+ return;
+ }
+ auto [sgLayout, sgData] = layoutDataOrErr.value();
+
+ prefetchLayout = LayoutInfo(xegpu::LayoutAttr::get(
+ tdescTy.getContext(),
+ DenseI32ArrayAttr::get(tdescTy.getContext(), sgLayout),
+ DenseI32ArrayAttr::get(tdescTy.getContext(), sgData),
+ /*inst_data =*/nullptr, /*lane_layout =*/nullptr,
+ /*lane_data =*/nullptr, /*order =*/nullptr));
+ }
prefetch.setLayoutAttr(
dyn_cast<xegpu::DistributeLayoutAttr>(prefetchLayout.get()));
}
@@ -741,31 +814,91 @@ void LayoutInfoPropagation::visitDpasOp(
LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataA));
dpasBLayout =
LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataB));
- } else {
+ } else if (layoutKind == LayoutKind::Lane) {
dpasALayout = getSIMTLayoutInfoForDPASOperand(
aTy, 0, uArch, uArchInstruction->getPackedFormatBitSizeA());
dpasBLayout = getSIMTLayoutInfoForDPASOperand(
bTy, 1, uArch, uArchInstruction->getPackedFormatBitSizeB());
+ } else {
+ auto numSgOrErr = getNumSg(dpas, subgroupSize);
+ if (failed(numSgOrErr)) {
+ dpas.emitWarning(
+ "Unable to determine the number of subgroups for the operation.");
+ return;
+ }
+ auto layoutDataAOrErr =
+ chooseLayout(aTy.getShape(), instDataA, numSgOrErr.value());
+ if (failed(layoutDataAOrErr)) {
+ dpas.emitWarning(
+ "Unable to determine suitable subgroup layout and data for A.");
+ return;
+ }
+ auto [sgLayoutA, sgDataA] = layoutDataAOrErr.value();
+
+ dpasALayout = LayoutInfo(xegpu::LayoutAttr::get(
+ aTy.getContext(), DenseI32ArrayAttr::get(aTy.getContext(), sgLayoutA),
+ DenseI32ArrayAttr::get(aTy.getContext(), sgDataA),
+ /*inst_data =*/nullptr, /*lane_layout =*/nullptr,
+ /*lane_data =*/nullptr, /*order =*/nullptr));
+
+ auto layoutDataBOrErr =
+ chooseLayout(bTy.getShape(), instDataB, numSgOrErr.value());
+ if (failed(layoutDataBOrErr)) {
+ dpas.emitWarning(
+ "Unable to determine suitable subgroup layout and data for B.");
+ return;
+ }
+ auto [sgLayoutB, sgDataB] = layoutDataBOrErr.value();
+
+ dpasBLayout = LayoutInfo(xegpu::LayoutAttr::get(
+ bTy.getContext(), DenseI32ArrayAttr::get(bTy.getContext(), sgLayoutB),
+ DenseI32ArrayAttr::get(bTy.getContext(), sgDataB),
+ /*inst_data =*/nullptr, /*lane_layout =*/nullptr,
+ /*lane_data =*/nullptr, /*order =*/nullptr));
}
if (operands.size() > 2) {
VectorType cTy = dpas.getAccType();
+    const unsigned dataCLen = bTy.getShape().back();
+    auto supportedCLen =
+        uArchInstruction->getSupportedN(bTy.getElementType());
+    const int maxCLen =
+        xegpu::getLargestDivisor(dataCLen, ArrayRef<unsigned>(supportedCLen));
+    // Only the inst_data and subgroup paths build on the C/D instruction
+    // shape; the lane path derives its layout independently and must not be
+    // aborted by an unsupported N size.
+    if (maxCLen == -1 && layoutKind != LayoutKind::Lane) {
+      dpas.emitWarning(
+          "No suitable instruction multiple found for the given shape.");
+      return;
+    }
+    SmallVector<int> instDataCD = {maxALen, maxCLen};
     if (layoutKind == LayoutKind::InstData) {
-      const unsigned dataCLen = bTy.getShape().back();
-      auto supportedCLen =
-          uArchInstruction->getSupportedN(bTy.getElementType());
-      const int maxCLen = xegpu::getLargestDivisor(
-          dataCLen, ArrayRef<unsigned>(supportedCLen));
-      if (maxCLen == -1)
-        dpas.emitWarning(
-            "No suitable instruction multiple found for the given shape.");
-      SmallVector<int> instDataC = {maxALen, maxCLen};
       dpasCDLayout =
-          LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataC));
-    } else
+          LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataCD));
+    } else if (layoutKind == LayoutKind::Lane) {
       dpasCDLayout = getSIMTLayoutInfoForDPASOperand(
           cTy, 2, uArch, uArchInstruction->getPackedFormatBitSizeB());
-
+    } else { // LayoutKind::Subgroup
+      auto numSgOrErr = getNumSg(dpas, subgroupSize);
+      if (failed(numSgOrErr)) {
+        dpas.emitWarning(
+            "Unable to determine the number of subgroups for the operation.");
+        return;
+      }
+      auto layoutDataCDOrErr =
+          chooseLayout(cTy.getShape(), instDataCD, numSgOrErr.value());
+      if (failed(layoutDataCDOrErr)) {
+        dpas.emitWarning(
+            "Unable to determine suitable subgroup layout and data for CD.");
+        return;
+      }
+      auto [sgLayoutCD, sgDataCD] = layoutDataCDOrErr.value();
+
+      dpasCDLayout = LayoutInfo(xegpu::LayoutAttr::get(
+          cTy.getContext(),
+          DenseI32ArrayAttr::get(cTy.getContext(), sgLayoutCD),
+          DenseI32ArrayAttr::get(cTy.getContext(), sgDataCD),
+          /*inst_data =*/nullptr, /*lane_layout =*/nullptr,
+          /*lane_data =*/nullptr, /*order =*/nullptr));
+    }
dpas.setLayoutCdAttr(
dyn_cast<xegpu::DistributeLayoutAttr>(dpasCDLayout.get()));
}
@@ -823,10 +956,34 @@ void LayoutInfoPropagation::visitStoreNdOp(
if (layoutKind == LayoutKind::InstData)
storeLayout =
LayoutInfo(xegpu::LayoutAttr::get(dataTy.getContext(), instData));
- else
+ else if (layoutKind == LayoutKind::Lane)
storeLayout =
getDefaultSIMTLayoutInfo(store.getValueType(), uArch,
uArchInstruction->getPackedFormatBitSize());
+ else { // LayoutKind::Subgroup
+ auto sgSize = uArch->getSubgroupSize();
+ auto numSgOrErr = getNumSg(store, sgSize);
+ if (failed(numSgOrErr)) {
+ store.emitWarning(
+ "Unable to determine the number of subgroups for the operation.");
+ return;
+ }
+ auto layoutDataOrErr =
+ chooseLayout(dataTy.getShape(), instData, numSgOrErr.value());
+ if (failed(layoutDataOrErr)) {
+ store.emitWarning(
+ "Unable to determine suitable subgroup layout and data.");
+ return;
+ }
+ auto [sgLayout, sgData] = layoutDataOrErr.value();
+
+ storeLayout = LayoutInfo(xegpu::LayoutAttr::get(
+ dataTy.getContext(),
+ DenseI32ArrayAttr::get(dataTy.getContext(), sgLayout),
+ DenseI32ArrayAttr::get(dataTy.getContext(), sgData),
+ /*inst_data =*/nullptr, /*lane_layout =*/nullptr,
+ /*lane_data =*/nullptr, /*order =*/nullptr));
+ }
store.setLayoutAttr(
dyn_cast<xegpu::DistributeLayoutAttr>(storeLayout.get()));
}
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout-sg.mlir b/mlir/test/Dialect/XeGPU/propagate-layout-sg.mlir
index 5659e9995b22a..52e739f8ed8b6 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout-sg.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout-sg.mlir
@@ -51,3 +51,81 @@ gpu.module @test {
return
}
}
+// -----
+
+gpu.module @test {
+  // CHECK-LABEL: vector_transpose
+  // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
+  // CHECK-SAME: %[[ARG_1:.*]]: memref<128x256xf32>
+  gpu.func @vector_transpose(%src: memref<256x128xf32>, %src1: memref<128x256xf32>) kernel attributes
+      {known_block_size = array<i32: 1, 32, 16>} {
+    // CHECK: %[[TDESC_LD:.*]] = xegpu.create_nd_tdesc %[[ARG_0]] : memref<256x128xf32> ->
+    // CHECK-SAME: !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>>
+    // CHECK: %[[TDESC_ST:.*]] = xegpu.create_nd_tdesc %[[ARG_1]] : memref<128x256xf32> ->
+    // CHECK-SAME: !xegpu.tensor_desc<128x256xf32, #xegpu.layout<sg_layout = [4, 8], sg_data = [32, 32]>>
+
+    // CHECK: %[[LOAD:.*]] = xegpu.load_nd %[[TDESC_LD]][0, 0] <{layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>}>
+    // CHECK-SAME: {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>} :
+    // CHECK-SAME: !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>> -> vector<256x128xf32>
+
+    // No op carries an explicit layout: the layout created for store_nd must
+    // be transposed back onto the load. Match the load by name, not by number.
+    // CHECK: %[[TRANSPOSED:.*]] = vector.transpose %[[LOAD]], [1, 0]
+    // CHECK-SAME: {layout_result_0 = #xegpu.layout<sg_layout = [4, 8], sg_data = [32, 32]>} : vector<256x128xf32> to vector<128x256xf32>
+
+    // CHECK: xegpu.store_nd %[[TRANSPOSED]], %[[TDESC_ST]][0, 0]
+    // CHECK-SAME: <{layout = #xegpu.layout<sg_layout = [4, 8], sg_data = [32, 32]>}> : vector<128x256xf32>,
+    // CHECK-SAME: !xegpu.tensor_desc<128x256xf32, #xegpu.layout<sg_layout = [4, 8], sg_data = [32, 32]>>
+    %tdesc = xegpu.create_nd_tdesc %src : memref<256x128xf32> -> !xegpu.tensor_desc<256x128xf32>
+    %tdesc1 = xegpu.create_nd_tdesc %src1 : memref<128x256xf32> -> !xegpu.tensor_desc<128x256xf32>
+    %load = xegpu.load_nd %tdesc[0, 0] : !xegpu.tensor_desc<256x128xf32> -> vector<256x128xf32>
+    %trans = vector.transpose %load, [1, 0] : vector<256x128xf32> to vector<128x256xf32>
+    xegpu.store_nd %trans, %tdesc1[0, 0] : vector<128x256xf32>, !xegpu.tensor_desc<128x256xf32>
+    gpu.return
+  }
+}
+
+// -----
+
+gpu.module @test {
+ // No op carries an explicit layout, so subgroup layouts have to be created
+ // from known_block_size (which fixes the subgroup count) and the DPAS/store
+ // instruction shapes, then reach the loads and all tensor descriptors.
+ // CHECK-LABEL: dpas
+ // CHECK-SAME: %[[A_MEMREF:.*]]: memref<128x128xf16>, %[[B_MEMREF:.*]]: memref<128x128xf16>
+ // CHECK-SAME: %[[CD_MEMREF:.*]]: memref<128x128xf32>
+ gpu.func @dpas(%a: memref<128x128xf16>, %b: memref<128x128xf16>, %d: memref<128x128xf32>) kernel attributes
+ {known_block_size = array<i32: 1, 64, 16>} {
+ // CHECK: %[[TDESC_A:.*]] = xegpu.create_nd_tdesc %[[A_MEMREF]] : memref<128x128xf16> ->
+ // CHECK-SAME: !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>>
+
+ // CHECK: %[[A_LOADED:.*]] = xegpu.load_nd %[[TDESC_A]]
+ // CHECK-SAME: <{layout = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>}>
+ // CHECK-SAME: {layout_result_0 = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>} :
+ // CHECK-SAME: !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>> -> vector<128x128xf16>
+
+ // CHECK: %[[TDESC_B:.*]] = xegpu.create_nd_tdesc %[[B_MEMREF]] : memref<128x128xf16> ->
+ // CHECK-SAME: !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>>
+
+ // CHECK: %[[B_LOADED:.*]] = xegpu.load_nd %[[TDESC_B]] <{layout = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>}>
+ // CHECK-SAME: {layout_result_0 = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>} :
+ // CHECK-SAME: !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>> -> vector<128x128xf16>
+
+ // CHECK: %[[DPAS_RES:.*]] = xegpu.dpas %[[A_LOADED]], %[[B_LOADED]]
+ // CHECK-SAME: {layout_a = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>,
+ // CHECK-SAME: layout_b = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>,
+ // CHECK-SAME: layout_result_0 = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>} :
+ // CHECK-SAME: vector<128x128xf16>, vector<128x128xf16> -> vector<128x128xf32>
+
+ // CHECK: %[[TDESC_ST:.*]] = xegpu.create_nd_tdesc %[[CD_MEMREF]] : memref<128x128xf32> ->
+ // CHECK-SAME: !xegpu.tensor_desc<128x128xf32, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>>
+
+ // CHECK: xegpu.store_nd %[[DPAS_RES]], %[[TDESC_ST]][0, 0]
+ // CHECK-SAME: <{layout = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>}> :
+ // CHECK-SAME: vector<128x128xf32>, !xegpu.tensor_desc<128x128xf32, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>>
+
+ %tdesc_a = xegpu.create_nd_tdesc %a : memref<128x128xf16> -> !xegpu.tensor_desc<128x128xf16>
+ %load_a = xegpu.load_nd %tdesc_a : !xegpu.tensor_desc<128x128xf16> -> vector<128x128xf16>
+ %tdesc_b = xegpu.create_nd_tdesc %b : memref<128x128xf16> -> !xegpu.tensor_desc<128x128xf16>
+ %load_b = xegpu.load_nd %tdesc_b : !xegpu.tensor_desc<128x128xf16> -> vector<128x128xf16>
+ %dpas = xegpu.dpas %load_a, %load_b : vector<128x128xf16>, vector<128x128xf16> -> vector<128x128xf32>
+ %tdesc_cd = xegpu.create_nd_tdesc %d : memref<128x128xf32> -> !xegpu.tensor_desc<128x128xf32>
+ xegpu.store_nd %dpas, %tdesc_cd[0, 0] : vector<128x128xf32>, !xegpu.tensor_desc<128x128xf32>
+ gpu.return
+ }
+}
More information about the Mlir-commits
mailing list