[Mlir-commits] [mlir] [MLIR][XeGPU] Handle `index` element type in the layout propagation (PR #184322)

Artem Kroviakov llvmlistbot at llvm.org
Wed Mar 11 08:12:07 PDT 2026


https://github.com/akroviakov updated https://github.com/llvm/llvm-project/pull/184322

>From 895bfc04a70678f6df8bc633811fc4e7b6d8567e Mon Sep 17 00:00:00 2001
From: Artem Kroviakov <artem.kroviakov at intel.com>
Date: Wed, 11 Mar 2026 15:11:48 +0000
Subject: [PATCH] [MLIR][XeGPU] Handle `index` element type in the layout
 propagation

---
 .../mlir/Dialect/XeGPU/Transforms/Passes.td     |  6 +++++-
 .../Dialect/XeGPU/Transforms/XeGPULayoutImpl.h  |  3 ++-
 .../Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp |  1 +
 .../XeGPU/Transforms/XeGPUPropagateLayout.cpp   | 17 ++++++++++-------
 .../lib/Dialect/XeGPU/TestXeGPUTransforms.cpp   |  2 +-
 5 files changed, 19 insertions(+), 10 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
index 6d4a568f614bd..4bee1752b271e 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
@@ -55,7 +55,11 @@ def XeGPUPropagateLayout : Pass<"xegpu-propagate-layout"> {
     Option<
     "layoutKind", "layout-kind", "std::string",
     /*default=*/"\"lane\"",
-    "Propagate `subgroup` / `inst` / `lane` level of xegpu layouts.">
+    "Propagate `subgroup` / `inst` / `lane` level of xegpu layouts.">,
+    Option<
+    "indexBitWidth", "index-bitwidth", "unsigned",
+    /*default=*/"64",
+    "Vectors of `index` type should also be distributable, inst-data and lower levels need to know the index size.">
   ];
 }
 
diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h b/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h
index 2ae0ef3ae852d..55b18d4a19c55 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h
@@ -34,7 +34,8 @@ class TensorDescType;
 namespace xegpu {
 
 LogicalResult propagateLayouts(OpBuilder &builder, Operation *target,
-                               LayoutKind layoutKind, bool printOnly = false);
+                               LayoutKind layoutKind, unsigned indexBitWidth,
+                               bool printOnly = false);
 
 LogicalResult resolveLayoutConflicts(Operation *target);
 
diff --git a/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp b/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp
index 2c0346f4b2d56..fbb7bb8aeb4bc 100644
--- a/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp
+++ b/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp
@@ -62,6 +62,7 @@ void buildPreGPUCommonPassPipeline(
 void buildGPUPassPipeline(OpPassManager &pm,
                           const mlir::gpu::GPUToXeVMPipelineOptions &options) {
   xegpu::XeGPUPropagateLayoutOptions laneLayoutOptions;
+  laneLayoutOptions.indexBitWidth = options.use64bitIndex ? 64 : 32;
   laneLayoutOptions.layoutKind = "lane";
   pm.addNestedPass<ModuleOp>(createCSEPass());
   if (options.xegpuOpLevel == "workgroup") {
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index b3fb00f35b167..8bf0f2aca60c5 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -321,6 +321,7 @@ class LayoutInfoPropagation
     : public SparseBackwardDataFlowAnalysis<LayoutInfoLattice> {
 private:
   xegpu::LayoutKind layoutKind;
+  unsigned indexBitWidth;
   void visitDpasOp(xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands,
                    ArrayRef<const LayoutInfoLattice *> results);
 
@@ -396,9 +397,9 @@ class LayoutInfoPropagation
 public:
   LayoutInfoPropagation(DataFlowSolver &solver,
                         SymbolTableCollection &symbolTable,
-                        xegpu::LayoutKind layoutKind)
+                        xegpu::LayoutKind layoutKind, unsigned indexBitWidth)
       : SparseBackwardDataFlowAnalysis(solver, symbolTable),
-        layoutKind(layoutKind) {}
+        layoutKind(layoutKind), indexBitWidth(indexBitWidth) {}
   using SparseBackwardDataFlowAnalysis::SparseBackwardDataFlowAnalysis;
 
   LogicalResult
@@ -1179,11 +1180,12 @@ class RunLayoutInfoPropagation {
 public:
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(RunLayoutInfoPropagation)
 
-  RunLayoutInfoPropagation(Operation *op, xegpu::LayoutKind layoutKind)
+  RunLayoutInfoPropagation(Operation *op, xegpu::LayoutKind layoutKind,
+                           unsigned indexBitWidth)
       : target(op) {
     SymbolTableCollection symbolTable;
     loadBaselineAnalyses(solver);
-    solver.load<LayoutInfoPropagation>(symbolTable, layoutKind);
+    solver.load<LayoutInfoPropagation>(symbolTable, layoutKind, indexBitWidth);
     (void)solver.initializeAndRun(op);
   }
 
@@ -1573,8 +1575,9 @@ struct XeGPUPropagateLayoutPass final
 } // namespace
 
 LogicalResult xegpu::propagateLayouts(OpBuilder &builder, Operation *target,
-                                      LayoutKind layoutKind, bool printOnly) {
-  RunLayoutInfoPropagation analysis(target, layoutKind);
+                                      LayoutKind layoutKind,
+                                      unsigned indexBitWidth, bool printOnly) {
+  RunLayoutInfoPropagation analysis(target, layoutKind, indexBitWidth);
   // Print the analysis result and exit. (for debugging purposes)
   if (printOnly) {
     auto &os = llvm::outs();
@@ -1657,7 +1660,7 @@ void XeGPUPropagateLayoutPass::runOnOperation() {
   }
   OpBuilder builder(&getContext());
   if (failed(xegpu::propagateLayouts(builder, getOperation(), layoutKind,
-                                     this->printOnly))) {
+                                     this->indexBitWidth, this->printOnly))) {
     signalPassFailure();
     return;
   }
diff --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
index b9b38fc1ce88b..ffd091b5154b3 100644
--- a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
+++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
@@ -377,7 +377,7 @@ struct TestXeGPUPropagateLayouts
       signalPassFailure();
       return;
     }
-    if (failed(xegpu::propagateLayouts(builder, getOperation(), kind))) {
+    if (failed(xegpu::propagateLayouts(builder, getOperation(), kind, 32))) {
       signalPassFailure();
     }
   }



More information about the Mlir-commits mailing list