[Mlir-commits] [mlir] [MLIR][XeGPU] Handle `index` element type in the layout propagation (PR #184322)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Mar 3 03:14:05 PST 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-gpu
Author: Artem Kroviakov (akroviakov)
<details>
<summary>Changes</summary>
Layout propagation queries the element bitwidth of vectors for certain ops.
In XeGPU, the offsets of memory operations are vectors with `index` element type, and `index` has no fixed bitwidth until it is lowered to a concrete LLVM integer type.
Propagating a layout through such ops when they operate on an offsets vector therefore crashes, because the bitwidth of `index` cannot be queried.
The current solution is verbose, but it requires no extra care or input from users: it fully respects the user's index-bitwidth option.
The verbosity of the underlying implementation is open for discussion (possibly outside this enabling PR), especially the potential need to change the signatures of several `setup...()`/`infer...()` functions.
---
Full diff: https://github.com/llvm/llvm-project/pull/184322.diff
7 Files Affected:
- (modified) mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td (+5-1)
- (modified) mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h (+4-2)
- (modified) mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp (+1)
- (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp (+5-2)
- (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp (+12-8)
- (modified) mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir (+18)
- (modified) mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp (+1-1)
``````````diff
diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
index cb71f19da62f0..4686ec1d81bf1 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
@@ -64,7 +64,11 @@ def XeGPUPropagateLayout : Pass<"xegpu-propagate-layout"> {
Option<
"layoutKind", "layout-kind", "std::string",
/*default=*/"\"lane\"",
- "Propagate `subgroup` / `inst` / `lane` level of xegpu layouts.">
+ "Propagate `subgroup` / `inst` / `lane` level of xegpu layouts.">,
+ Option<
+ "indexBitWidth", "index-bitwidth", "unsigned",
+ /*default=*/"64",
+ "Vectors of `index` type should also be distributable, inst-data and lower levels need to know the index size.">
];
}
diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h b/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h
index 3482d1b9401bb..07dcb35d70f18 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h
@@ -34,7 +34,8 @@ class TensorDescType;
namespace xegpu {
LogicalResult propagateLayouts(OpBuilder &builder, Operation *target,
- LayoutKind layoutKind, bool printOnly = false);
+ LayoutKind layoutKind, unsigned indexBitWidth,
+ bool printOnly = false);
LogicalResult resolveLayoutConflicts(Operation *target);
@@ -134,7 +135,8 @@ DistributeLayoutAttr setupBitCastResultLayout(
/// Lane).
DistributeLayoutAttr setupInsertStridedSliceResultLayout(
LayoutKind layoutKind, VectorType srcVectorTy, VectorType resVectorTy,
- DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch);
+ DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch,
+ unsigned indexBitWidth);
/// Sets up the anchor layout for a load gather operation.
DistributeLayoutAttr
diff --git a/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp b/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp
index d6a4dae4c5c32..af107ebacf41b 100644
--- a/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp
+++ b/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp
@@ -62,6 +62,7 @@ void buildPreGPUCommonPassPipeline(
void buildGPUPassPipeline(OpPassManager &pm,
const mlir::gpu::GPUToXeVMPipelineOptions &options) {
xegpu::XeGPUPropagateLayoutOptions laneLayoutOptions;
+ laneLayoutOptions.indexBitWidth = options.use64bitIndex ? 64 : 32;
laneLayoutOptions.layoutKind = "lane";
pm.addNestedPass<ModuleOp>(createCSEPass());
pm.addNestedPass<ModuleOp>(createGpuXeVMAttachTarget());
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
index 7aa186bb22224..ed44c6ff00338 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
@@ -622,7 +622,7 @@ xegpu::DistributeLayoutAttr xegpu::setupBitCastResultLayout(
xegpu::DistributeLayoutAttr xegpu::setupInsertStridedSliceResultLayout(
xegpu::LayoutKind layoutKind, VectorType srcVectorTy,
VectorType resVectorTy, xegpu::DistributeLayoutAttr consumerLayout,
- const xegpu::uArch::uArch *uArch) {
+ const xegpu::uArch::uArch *uArch, unsigned indexBitWidth) {
xegpu::DistributeLayoutAttr requiredResLayout;
auto subgroupSize = uArch->getSubgroupSize();
@@ -640,7 +640,10 @@ xegpu::DistributeLayoutAttr xegpu::setupInsertStridedSliceResultLayout(
SmallVector<int> laneData(resShapeSize, 1);
const unsigned packingSize{uArch->getGeneralPackedFormatBitSize()};
- unsigned bitwidth = resVectorTy.getElementType().getIntOrFloatBitWidth();
+ unsigned bitwidth =
+ resVectorTy.getElementType().isIndex()
+ ? indexBitWidth
+ : resVectorTy.getElementType().getIntOrFloatBitWidth();
int packingFactor = bitwidth < packingSize ? packingSize / bitwidth : 1;
int packedDataSize = subgroupSize * packingFactor;
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 87835fb191604..1c9f06e92dff2 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -321,6 +321,7 @@ class LayoutInfoPropagation
: public SparseBackwardDataFlowAnalysis<LayoutInfoLattice> {
private:
xegpu::LayoutKind layoutKind;
+ unsigned indexBitWidth;
void visitDpasOp(xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands,
ArrayRef<const LayoutInfoLattice *> results);
@@ -396,9 +397,9 @@ class LayoutInfoPropagation
public:
LayoutInfoPropagation(DataFlowSolver &solver,
SymbolTableCollection &symbolTable,
- xegpu::LayoutKind layoutKind)
+ xegpu::LayoutKind layoutKind, unsigned indexBitWidth)
: SparseBackwardDataFlowAnalysis(solver, symbolTable),
- layoutKind(layoutKind) {}
+ layoutKind(layoutKind), indexBitWidth(indexBitWidth) {}
using SparseBackwardDataFlowAnalysis::SparseBackwardDataFlowAnalysis;
LogicalResult
@@ -958,7 +959,8 @@ void LayoutInfoPropagation::visitInsertStridedSliceOp(
getUArch(xegpu::getChipStr(insertStridedSlice).value_or(""));
auto requiredResLayoutAttr = xegpu::setupInsertStridedSliceResultLayout(
- layoutKind, srcVecType, resVecType, consumerLayoutAttr, uArch);
+ layoutKind, srcVecType, resVecType, consumerLayoutAttr, uArch,
+ indexBitWidth);
xegpu::setTemporaryLayout(insertStridedSlice->getResult(0),
requiredResLayoutAttr);
@@ -1156,11 +1158,12 @@ class RunLayoutInfoPropagation {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(RunLayoutInfoPropagation)
- RunLayoutInfoPropagation(Operation *op, xegpu::LayoutKind layoutKind)
+ RunLayoutInfoPropagation(Operation *op, xegpu::LayoutKind layoutKind,
+ unsigned indexBitWidth)
: target(op) {
SymbolTableCollection symbolTable;
loadBaselineAnalyses(solver);
- solver.load<LayoutInfoPropagation>(symbolTable, layoutKind);
+ solver.load<LayoutInfoPropagation>(symbolTable, layoutKind, indexBitWidth);
(void)solver.initializeAndRun(op);
}
@@ -1544,8 +1547,9 @@ struct XeGPUPropagateLayoutPass final
} // namespace
LogicalResult xegpu::propagateLayouts(OpBuilder &builder, Operation *target,
- LayoutKind layoutKind, bool printOnly) {
- RunLayoutInfoPropagation analysis(target, layoutKind);
+ LayoutKind layoutKind,
+ unsigned indexBitWidth, bool printOnly) {
+ RunLayoutInfoPropagation analysis(target, layoutKind, indexBitWidth);
// Print the analysis result and exit. (for debugging purposes)
if (printOnly) {
auto &os = llvm::outs();
@@ -1628,7 +1632,7 @@ void XeGPUPropagateLayoutPass::runOnOperation() {
}
OpBuilder builder(&getContext());
if (failed(xegpu::propagateLayouts(builder, getOperation(), layoutKind,
- this->printOnly))) {
+ this->indexBitWidth, this->printOnly))) {
signalPassFailure();
return;
}
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir b/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
index 5dd05e6cb0001..340aa92d2d425 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
@@ -334,3 +334,21 @@ gpu.module @test{
return
}
}
+
+
+// -----
+gpu.module @test{
+ // CHECK-LABEL: insert_strided_slice_index_type
+ // CHECK: vector.insert_strided_slice %{{.*}} {layout_result_0 = #xegpu.layout<inst_data = [1, 16]>
+ // CHECK: xegpu.store %{{.*}} <{layout = #xegpu.layout<inst_data = [1, 16]>}>
+ func.func @insert_strided_slice_index_type(%arg0: i64) {
+ %vector_step_2d_placeholder = arith.constant dense<1> : vector<16x16xindex>
+ %vector_step_slice = arith.constant dense<12> : vector<1x16xindex>
+ %v = vector.insert_strided_slice %vector_step_slice, %vector_step_2d_placeholder
+ {offsets = [0, 0], sizes = [1, 16], strides = [1, 1]} : vector<1x16xindex> into vector<16x16xindex>
+ %cst = arith.constant dense<true> : vector<16x16xi1>
+ %data = arith.constant dense<0.000000e+00> : vector<16x16xf16>
+ xegpu.store %data, %arg0[%v], %cst : vector<16x16xf16>, i64, vector<16x16xindex>, vector<16x16xi1>
+ return
+ }
+}
diff --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
index b9b38fc1ce88b..ffd091b5154b3 100644
--- a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
+++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
@@ -377,7 +377,7 @@ struct TestXeGPUPropagateLayouts
signalPassFailure();
return;
}
- if (failed(xegpu::propagateLayouts(builder, getOperation(), kind))) {
+ if (failed(xegpu::propagateLayouts(builder, getOperation(), kind, 32))) {
signalPassFailure();
}
}
``````````
</details>
https://github.com/llvm/llvm-project/pull/184322
More information about the Mlir-commits
mailing list