[Mlir-commits] [mlir] 6d362a9 - [MLIR][XeGPU] Handle `index` element type in the layout propagation (#184322)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Mar 11 08:26:12 PDT 2026


Author: Artem Kroviakov
Date: 2026-03-11T15:25:44Z
New Revision: 6d362a9210a27a7d7572a5b7bf6c9d69115578c2

URL: https://github.com/llvm/llvm-project/commit/6d362a9210a27a7d7572a5b7bf6c9d69115578c2
DIFF: https://github.com/llvm/llvm-project/commit/6d362a9210a27a7d7572a5b7bf6c9d69115578c2.diff

LOG: [MLIR][XeGPU] Handle `index` element type in the layout propagation (#184322)

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
    mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h
    mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp
    mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
    mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
index 6d4a568f614bd..4bee1752b271e 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
@@ -55,7 +55,11 @@ def XeGPUPropagateLayout : Pass<"xegpu-propagate-layout"> {
     Option<
     "layoutKind", "layout-kind", "std::string",
     /*default=*/"\"lane\"",
-    "Propagate `subgroup` / `inst` / `lane` level of xegpu layouts.">
+    "Propagate `subgroup` / `inst` / `lane` level of xegpu layouts.">,
+    Option<
+    "indexBitWidth", "index-bitwidth", "unsigned",
+    /*default=*/"64",
+    "Vectors of `index` type should also be distributable, inst-data and lower levels need to know the index size.">
   ];
 }
 

diff  --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h b/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h
index 2ae0ef3ae852d..55b18d4a19c55 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h
@@ -34,7 +34,8 @@ class TensorDescType;
 namespace xegpu {
 
 LogicalResult propagateLayouts(OpBuilder &builder, Operation *target,
-                               LayoutKind layoutKind, bool printOnly = false);
+                               LayoutKind layoutKind, unsigned indexBitWidth,
+                               bool printOnly = false);
 
 LogicalResult resolveLayoutConflicts(Operation *target);
 

diff  --git a/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp b/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp
index 2c0346f4b2d56..fbb7bb8aeb4bc 100644
--- a/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp
+++ b/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp
@@ -62,6 +62,7 @@ void buildPreGPUCommonPassPipeline(
 void buildGPUPassPipeline(OpPassManager &pm,
                           const mlir::gpu::GPUToXeVMPipelineOptions &options) {
   xegpu::XeGPUPropagateLayoutOptions laneLayoutOptions;
+  laneLayoutOptions.indexBitWidth = options.use64bitIndex ? 64 : 32;
   laneLayoutOptions.layoutKind = "lane";
   pm.addNestedPass<ModuleOp>(createCSEPass());
   if (options.xegpuOpLevel == "workgroup") {

diff  --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index b3fb00f35b167..8bf0f2aca60c5 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -321,6 +321,7 @@ class LayoutInfoPropagation
     : public SparseBackwardDataFlowAnalysis<LayoutInfoLattice> {
 private:
   xegpu::LayoutKind layoutKind;
+  unsigned indexBitWidth;
   void visitDpasOp(xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands,
                    ArrayRef<const LayoutInfoLattice *> results);
 
@@ -396,9 +397,9 @@ class LayoutInfoPropagation
 public:
   LayoutInfoPropagation(DataFlowSolver &solver,
                         SymbolTableCollection &symbolTable,
-                        xegpu::LayoutKind layoutKind)
+                        xegpu::LayoutKind layoutKind, unsigned indexBitWidth)
       : SparseBackwardDataFlowAnalysis(solver, symbolTable),
-        layoutKind(layoutKind) {}
+        layoutKind(layoutKind), indexBitWidth(indexBitWidth) {}
   using SparseBackwardDataFlowAnalysis::SparseBackwardDataFlowAnalysis;
 
   LogicalResult
@@ -1179,11 +1180,12 @@ class RunLayoutInfoPropagation {
 public:
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(RunLayoutInfoPropagation)
 
-  RunLayoutInfoPropagation(Operation *op, xegpu::LayoutKind layoutKind)
+  RunLayoutInfoPropagation(Operation *op, xegpu::LayoutKind layoutKind,
+                           unsigned indexBitWidth)
       : target(op) {
     SymbolTableCollection symbolTable;
     loadBaselineAnalyses(solver);
-    solver.load<LayoutInfoPropagation>(symbolTable, layoutKind);
+    solver.load<LayoutInfoPropagation>(symbolTable, layoutKind, indexBitWidth);
     (void)solver.initializeAndRun(op);
   }
 
@@ -1573,8 +1575,9 @@ struct XeGPUPropagateLayoutPass final
 } // namespace
 
 LogicalResult xegpu::propagateLayouts(OpBuilder &builder, Operation *target,
-                                      LayoutKind layoutKind, bool printOnly) {
-  RunLayoutInfoPropagation analysis(target, layoutKind);
+                                      LayoutKind layoutKind,
+                                      unsigned indexBitWidth, bool printOnly) {
+  RunLayoutInfoPropagation analysis(target, layoutKind, indexBitWidth);
   // Print the analysis result and exit. (for debugging purposes)
   if (printOnly) {
     auto &os = llvm::outs();
@@ -1657,7 +1660,7 @@ void XeGPUPropagateLayoutPass::runOnOperation() {
   }
   OpBuilder builder(&getContext());
   if (failed(xegpu::propagateLayouts(builder, getOperation(), layoutKind,
-                                     this->printOnly))) {
+                                     this->indexBitWidth, this->printOnly))) {
     signalPassFailure();
     return;
   }

diff  --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
index b9b38fc1ce88b..ffd091b5154b3 100644
--- a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
+++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
@@ -377,7 +377,7 @@ struct TestXeGPUPropagateLayouts
       signalPassFailure();
       return;
     }
-    if (failed(xegpu::propagateLayouts(builder, getOperation(), kind))) {
+    if (failed(xegpu::propagateLayouts(builder, getOperation(), kind, 32))) {
       signalPassFailure();
     }
   }


        


More information about the Mlir-commits mailing list