[Mlir-commits] [mlir] 6d362a9 - [MLIR][XeGPU] Handle `index` element type in the layout propagation (#184322)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Mar 11 08:26:12 PDT 2026
Author: Artem Kroviakov
Date: 2026-03-11T15:25:44Z
New Revision: 6d362a9210a27a7d7572a5b7bf6c9d69115578c2
URL: https://github.com/llvm/llvm-project/commit/6d362a9210a27a7d7572a5b7bf6c9d69115578c2
DIFF: https://github.com/llvm/llvm-project/commit/6d362a9210a27a7d7572a5b7bf6c9d69115578c2.diff
LOG: [MLIR][XeGPU] Handle `index` element type in the layout propagation (#184322)
Added:
Modified:
mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h
mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp
mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
index 6d4a568f614bd..4bee1752b271e 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
@@ -55,7 +55,11 @@ def XeGPUPropagateLayout : Pass<"xegpu-propagate-layout"> {
Option<
"layoutKind", "layout-kind", "std::string",
/*default=*/"\"lane\"",
- "Propagate `subgroup` / `inst` / `lane` level of xegpu layouts.">
+ "Propagate `subgroup` / `inst` / `lane` level of xegpu layouts.">,
+ Option<
+ "indexBitWidth", "index-bitwidth", "unsigned",
+ /*default=*/"64",
+ "Vectors of `index` type should also be distributable, inst-data and lower levels need to know the index size.">
];
}
diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h b/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h
index 2ae0ef3ae852d..55b18d4a19c55 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h
@@ -34,7 +34,8 @@ class TensorDescType;
namespace xegpu {
LogicalResult propagateLayouts(OpBuilder &builder, Operation *target,
- LayoutKind layoutKind, bool printOnly = false);
+ LayoutKind layoutKind, unsigned indexBitWidth,
+ bool printOnly = false);
LogicalResult resolveLayoutConflicts(Operation *target);
diff --git a/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp b/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp
index 2c0346f4b2d56..fbb7bb8aeb4bc 100644
--- a/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp
+++ b/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp
@@ -62,6 +62,7 @@ void buildPreGPUCommonPassPipeline(
void buildGPUPassPipeline(OpPassManager &pm,
const mlir::gpu::GPUToXeVMPipelineOptions &options) {
xegpu::XeGPUPropagateLayoutOptions laneLayoutOptions;
+ laneLayoutOptions.indexBitWidth = options.use64bitIndex ? 64 : 32;
laneLayoutOptions.layoutKind = "lane";
pm.addNestedPass<ModuleOp>(createCSEPass());
if (options.xegpuOpLevel == "workgroup") {
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index b3fb00f35b167..8bf0f2aca60c5 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -321,6 +321,7 @@ class LayoutInfoPropagation
: public SparseBackwardDataFlowAnalysis<LayoutInfoLattice> {
private:
xegpu::LayoutKind layoutKind;
+ unsigned indexBitWidth;
void visitDpasOp(xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results);
@@ -396,9 +397,9 @@ class LayoutInfoPropagation
public:
LayoutInfoPropagation(DataFlowSolver &solver,
SymbolTableCollection &symbolTable,
- xegpu::LayoutKind layoutKind)
+ xegpu::LayoutKind layoutKind, unsigned indexBitWidth)
: SparseBackwardDataFlowAnalysis(solver, symbolTable),
- layoutKind(layoutKind) {}
+ layoutKind(layoutKind), indexBitWidth(indexBitWidth) {}
using SparseBackwardDataFlowAnalysis::SparseBackwardDataFlowAnalysis;
LogicalResult
@@ -1179,11 +1180,12 @@ class RunLayoutInfoPropagation {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(RunLayoutInfoPropagation)
- RunLayoutInfoPropagation(Operation *op, xegpu::LayoutKind layoutKind)
+ RunLayoutInfoPropagation(Operation *op, xegpu::LayoutKind layoutKind,
+ unsigned indexBitWidth)
: target(op) {
SymbolTableCollection symbolTable;
loadBaselineAnalyses(solver);
- solver.load<LayoutInfoPropagation>(symbolTable, layoutKind);
+ solver.load<LayoutInfoPropagation>(symbolTable, layoutKind, indexBitWidth);
(void)solver.initializeAndRun(op);
}
@@ -1573,8 +1575,9 @@ struct XeGPUPropagateLayoutPass final
} // namespace
LogicalResult xegpu::propagateLayouts(OpBuilder &builder, Operation *target,
- LayoutKind layoutKind, bool printOnly) {
- RunLayoutInfoPropagation analysis(target, layoutKind);
+ LayoutKind layoutKind,
+ unsigned indexBitWidth, bool printOnly) {
+ RunLayoutInfoPropagation analysis(target, layoutKind, indexBitWidth);
// Print the analysis result and exit. (for debugging purposes)
if (printOnly) {
auto &os = llvm::outs();
@@ -1657,7 +1660,7 @@ void XeGPUPropagateLayoutPass::runOnOperation() {
}
OpBuilder builder(&getContext());
if (failed(xegpu::propagateLayouts(builder, getOperation(), layoutKind,
- this->printOnly))) {
+ this->indexBitWidth, this->printOnly))) {
signalPassFailure();
return;
}
diff --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
index b9b38fc1ce88b..ffd091b5154b3 100644
--- a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
+++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
@@ -377,7 +377,7 @@ struct TestXeGPUPropagateLayouts
signalPassFailure();
return;
}
- if (failed(xegpu::propagateLayouts(builder, getOperation(), kind))) {
+ if (failed(xegpu::propagateLayouts(builder, getOperation(), kind, 32))) {
signalPassFailure();
}
}
More information about the Mlir-commits mailing list