[Mlir-commits] [mlir] [MLIR][XeGPU] Consider alignment in dpas sg_layout creation (PR #181141)
Artem Kroviakov
llvmlistbot at llvm.org
Tue Mar 3 03:00:49 PST 2026
https://github.com/akroviakov updated https://github.com/llvm/llvm-project/pull/181141
From 6ec977bf76a4b63d02390e7eaff7a42f9561fc8c Mon Sep 17 00:00:00 2001
From: Artem Kroviakov <artem.kroviakov at intel.com>
Date: Tue, 3 Mar 2026 11:00:31 +0000
Subject: [PATCH] [MLIR][XeGPU] Handle `index` element type in the layout propagation
---
.../mlir/Dialect/XeGPU/Transforms/Passes.td | 6 +++++-
.../XeGPU/Transforms/XeGPULayoutImpl.h | 6 ++++--
.../GPU/Pipelines/GPUToXeVMPipeline.cpp | 1 +
.../XeGPU/Transforms/XeGPULayoutImpl.cpp | 7 +++++--
.../XeGPU/Transforms/XeGPUPropagateLayout.cpp | 20 +++++++++++--------
.../XeGPU/propagate-layout-inst-data.mlir | 18 +++++++++++++++++
.../lib/Dialect/XeGPU/TestXeGPUTransforms.cpp | 2 +-
7 files changed, 46 insertions(+), 14 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
index cb71f19da62f0..4686ec1d81bf1 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
@@ -64,7 +64,11 @@ def XeGPUPropagateLayout : Pass<"xegpu-propagate-layout"> {
Option<
"layoutKind", "layout-kind", "std::string",
/*default=*/"\"lane\"",
- "Propagate `subgroup` / `inst` / `lane` level of xegpu layouts.">
+ "Propagate `subgroup` / `inst` / `lane` level of xegpu layouts.">,
+ Option<
+ "indexBitWidth", "index-bitwidth", "unsigned",
+ /*default=*/"64",
+ "Vectors of `index` type should also be distributable, inst-data and lower levels need to know the index size.">
];
}
diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h b/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h
index 3482d1b9401bb..07dcb35d70f18 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h
@@ -34,7 +34,8 @@ class TensorDescType;
namespace xegpu {
LogicalResult propagateLayouts(OpBuilder &builder, Operation *target,
- LayoutKind layoutKind, bool printOnly = false);
+ LayoutKind layoutKind, unsigned indexBitWidth,
+ bool printOnly = false);
LogicalResult resolveLayoutConflicts(Operation *target);
@@ -134,7 +135,8 @@ DistributeLayoutAttr setupBitCastResultLayout(
/// Lane).
DistributeLayoutAttr setupInsertStridedSliceResultLayout(
LayoutKind layoutKind, VectorType srcVectorTy, VectorType resVectorTy,
- DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch);
+ DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch,
+ unsigned indexBitWidth);
/// Sets up the anchor layout for a load gather operation.
DistributeLayoutAttr
diff --git a/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp b/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp
index d6a4dae4c5c32..af107ebacf41b 100644
--- a/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp
+++ b/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp
@@ -62,6 +62,7 @@ void buildPreGPUCommonPassPipeline(
void buildGPUPassPipeline(OpPassManager &pm,
const mlir::gpu::GPUToXeVMPipelineOptions &options) {
xegpu::XeGPUPropagateLayoutOptions laneLayoutOptions;
+ laneLayoutOptions.indexBitWidth = options.use64bitIndex ? 64 : 32;
laneLayoutOptions.layoutKind = "lane";
pm.addNestedPass<ModuleOp>(createCSEPass());
pm.addNestedPass<ModuleOp>(createGpuXeVMAttachTarget());
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
index 7aa186bb22224..ed44c6ff00338 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
@@ -622,7 +622,7 @@ xegpu::DistributeLayoutAttr xegpu::setupBitCastResultLayout(
xegpu::DistributeLayoutAttr xegpu::setupInsertStridedSliceResultLayout(
xegpu::LayoutKind layoutKind, VectorType srcVectorTy,
VectorType resVectorTy, xegpu::DistributeLayoutAttr consumerLayout,
- const xegpu::uArch::uArch *uArch) {
+ const xegpu::uArch::uArch *uArch, unsigned indexBitWidth) {
xegpu::DistributeLayoutAttr requiredResLayout;
auto subgroupSize = uArch->getSubgroupSize();
@@ -640,7 +640,10 @@ xegpu::DistributeLayoutAttr xegpu::setupInsertStridedSliceResultLayout(
SmallVector<int> laneData(resShapeSize, 1);
const unsigned packingSize{uArch->getGeneralPackedFormatBitSize()};
- unsigned bitwidth = resVectorTy.getElementType().getIntOrFloatBitWidth();
+ unsigned bitwidth =
+ resVectorTy.getElementType().isIndex()
+ ? indexBitWidth
+ : resVectorTy.getElementType().getIntOrFloatBitWidth();
int packingFactor = bitwidth < packingSize ? packingSize / bitwidth : 1;
int packedDataSize = subgroupSize * packingFactor;
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 87835fb191604..1c9f06e92dff2 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -321,6 +321,7 @@ class LayoutInfoPropagation
: public SparseBackwardDataFlowAnalysis<LayoutInfoLattice> {
private:
xegpu::LayoutKind layoutKind;
+ unsigned indexBitWidth;
void visitDpasOp(xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results);
@@ -396,9 +397,9 @@ class LayoutInfoPropagation
public:
LayoutInfoPropagation(DataFlowSolver &solver,
SymbolTableCollection &symbolTable,
- xegpu::LayoutKind layoutKind)
+ xegpu::LayoutKind layoutKind, unsigned indexBitWidth)
: SparseBackwardDataFlowAnalysis(solver, symbolTable),
- layoutKind(layoutKind) {}
+ layoutKind(layoutKind), indexBitWidth(indexBitWidth) {}
using SparseBackwardDataFlowAnalysis::SparseBackwardDataFlowAnalysis;
LogicalResult
@@ -958,7 +959,8 @@ void LayoutInfoPropagation::visitInsertStridedSliceOp(
getUArch(xegpu::getChipStr(insertStridedSlice).value_or(""));
auto requiredResLayoutAttr = xegpu::setupInsertStridedSliceResultLayout(
- layoutKind, srcVecType, resVecType, consumerLayoutAttr, uArch);
+ layoutKind, srcVecType, resVecType, consumerLayoutAttr, uArch,
+ indexBitWidth);
xegpu::setTemporaryLayout(insertStridedSlice->getResult(0),
requiredResLayoutAttr);
@@ -1156,11 +1158,12 @@ class RunLayoutInfoPropagation {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(RunLayoutInfoPropagation)
- RunLayoutInfoPropagation(Operation *op, xegpu::LayoutKind layoutKind)
+ RunLayoutInfoPropagation(Operation *op, xegpu::LayoutKind layoutKind,
+ unsigned indexBitWidth)
: target(op) {
SymbolTableCollection symbolTable;
loadBaselineAnalyses(solver);
- solver.load<LayoutInfoPropagation>(symbolTable, layoutKind);
+ solver.load<LayoutInfoPropagation>(symbolTable, layoutKind, indexBitWidth);
(void)solver.initializeAndRun(op);
}
@@ -1544,8 +1547,9 @@ struct XeGPUPropagateLayoutPass final
} // namespace
LogicalResult xegpu::propagateLayouts(OpBuilder &builder, Operation *target,
- LayoutKind layoutKind, bool printOnly) {
- RunLayoutInfoPropagation analysis(target, layoutKind);
+ LayoutKind layoutKind,
+ unsigned indexBitWidth, bool printOnly) {
+ RunLayoutInfoPropagation analysis(target, layoutKind, indexBitWidth);
// Print the analysis result and exit. (for debugging purposes)
if (printOnly) {
auto &os = llvm::outs();
@@ -1628,7 +1632,7 @@ void XeGPUPropagateLayoutPass::runOnOperation() {
}
OpBuilder builder(&getContext());
if (failed(xegpu::propagateLayouts(builder, getOperation(), layoutKind,
- this->printOnly))) {
+ this->indexBitWidth, this->printOnly))) {
signalPassFailure();
return;
}
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir b/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
index 5dd05e6cb0001..340aa92d2d425 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
@@ -334,3 +334,21 @@ gpu.module @test{
return
}
}
+
+
+// -----
+gpu.module @test{
+ // CHECK-LABEL: insert_strided_slice_index_type
+ // CHECK: vector.insert_strided_slice %{{.*}} {layout_result_0 = #xegpu.layout<inst_data = [1, 16]>
+ // CHECK: xegpu.store %{{.*}} <{layout = #xegpu.layout<inst_data = [1, 16]>}>
+ func.func @insert_strided_slice_index_type(%arg0: i64) {
+ %vector_step_2d_placeholder = arith.constant dense<1> : vector<16x16xindex>
+ %vector_step_slice = arith.constant dense<12> : vector<1x16xindex>
+ %v = vector.insert_strided_slice %vector_step_slice, %vector_step_2d_placeholder
+ {offsets = [0, 0], sizes = [1, 16], strides = [1, 1]} : vector<1x16xindex> into vector<16x16xindex>
+ %cst = arith.constant dense<true> : vector<16x16xi1>
+ %data = arith.constant dense<0.000000e+00> : vector<16x16xf16>
+ xegpu.store %data, %arg0[%v], %cst : vector<16x16xf16>, i64, vector<16x16xindex>, vector<16x16xi1>
+ return
+ }
+}
diff --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
index b9b38fc1ce88b..ffd091b5154b3 100644
--- a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
+++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
@@ -377,7 +377,7 @@ struct TestXeGPUPropagateLayouts
signalPassFailure();
return;
}
- if (failed(xegpu::propagateLayouts(builder, getOperation(), kind))) {
+ if (failed(xegpu::propagateLayouts(builder, getOperation(), kind, 32))) {
signalPassFailure();
}
}
More information about the Mlir-commits mailing list