[Mlir-commits] [mlir] [MLIR][XeGPU] Consider alignment in dpas sg_layout creation (PR #181141)

Artem Kroviakov llvmlistbot at llvm.org
Tue Mar 3 03:00:49 PST 2026


https://github.com/akroviakov updated https://github.com/llvm/llvm-project/pull/181141

From 6ec977bf76a4b63d02390e7eaff7a42f9561fc8c Mon Sep 17 00:00:00 2001
From: Artem Kroviakov <artem.kroviakov at intel.com>
Date: Tue, 3 Mar 2026 11:00:31 +0000
Subject: [PATCH] [MLIR][XeGPU] Handle `index` element type in the layout
 propagation

---
 .../mlir/Dialect/XeGPU/Transforms/Passes.td   |  6 +++++-
 .../XeGPU/Transforms/XeGPULayoutImpl.h        |  6 ++++--
 .../GPU/Pipelines/GPUToXeVMPipeline.cpp       |  1 +
 .../XeGPU/Transforms/XeGPULayoutImpl.cpp      |  7 +++++--
 .../XeGPU/Transforms/XeGPUPropagateLayout.cpp | 20 +++++++++++--------
 .../XeGPU/propagate-layout-inst-data.mlir     | 18 +++++++++++++++++
 .../lib/Dialect/XeGPU/TestXeGPUTransforms.cpp |  2 +-
 7 files changed, 46 insertions(+), 14 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
index cb71f19da62f0..4686ec1d81bf1 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
@@ -64,7 +64,11 @@ def XeGPUPropagateLayout : Pass<"xegpu-propagate-layout"> {
     Option<
     "layoutKind", "layout-kind", "std::string",
     /*default=*/"\"lane\"",
-    "Propagate `subgroup` / `inst` / `lane` level of xegpu layouts.">
+    "Propagate `subgroup` / `inst` / `lane` level of xegpu layouts.">,
+    Option<
+    "indexBitWidth", "index-bitwidth", "unsigned",
+    /*default=*/"64",
+    "Vectors of `index` type should also be distributable, inst-data and lower levels need to know the index size.">
   ];
 }
 
diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h b/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h
index 3482d1b9401bb..07dcb35d70f18 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h
@@ -34,7 +34,8 @@ class TensorDescType;
 namespace xegpu {
 
 LogicalResult propagateLayouts(OpBuilder &builder, Operation *target,
-                               LayoutKind layoutKind, bool printOnly = false);
+                               LayoutKind layoutKind, unsigned indexBitWidth,
+                               bool printOnly = false);
 
 LogicalResult resolveLayoutConflicts(Operation *target);
 
@@ -134,7 +135,8 @@ DistributeLayoutAttr setupBitCastResultLayout(
 /// Lane).
 DistributeLayoutAttr setupInsertStridedSliceResultLayout(
     LayoutKind layoutKind, VectorType srcVectorTy, VectorType resVectorTy,
-    DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch);
+    DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch,
+    unsigned indexBitWidth);
 
 /// Sets up the anchor layout for a load gather operation.
 DistributeLayoutAttr
diff --git a/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp b/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp
index d6a4dae4c5c32..af107ebacf41b 100644
--- a/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp
+++ b/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp
@@ -62,6 +62,7 @@ void buildPreGPUCommonPassPipeline(
 void buildGPUPassPipeline(OpPassManager &pm,
                           const mlir::gpu::GPUToXeVMPipelineOptions &options) {
   xegpu::XeGPUPropagateLayoutOptions laneLayoutOptions;
+  laneLayoutOptions.indexBitWidth = options.use64bitIndex ? 64 : 32;
   laneLayoutOptions.layoutKind = "lane";
   pm.addNestedPass<ModuleOp>(createCSEPass());
   pm.addNestedPass<ModuleOp>(createGpuXeVMAttachTarget());
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
index 7aa186bb22224..ed44c6ff00338 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
@@ -622,7 +622,7 @@ xegpu::DistributeLayoutAttr xegpu::setupBitCastResultLayout(
 xegpu::DistributeLayoutAttr xegpu::setupInsertStridedSliceResultLayout(
     xegpu::LayoutKind layoutKind, VectorType srcVectorTy,
     VectorType resVectorTy, xegpu::DistributeLayoutAttr consumerLayout,
-    const xegpu::uArch::uArch *uArch) {
+    const xegpu::uArch::uArch *uArch, unsigned indexBitWidth) {
 
   xegpu::DistributeLayoutAttr requiredResLayout;
   auto subgroupSize = uArch->getSubgroupSize();
@@ -640,7 +640,10 @@ xegpu::DistributeLayoutAttr xegpu::setupInsertStridedSliceResultLayout(
   SmallVector<int> laneData(resShapeSize, 1);
 
   const unsigned packingSize{uArch->getGeneralPackedFormatBitSize()};
-  unsigned bitwidth = resVectorTy.getElementType().getIntOrFloatBitWidth();
+  unsigned bitwidth =
+      resVectorTy.getElementType().isIndex()
+          ? indexBitWidth
+          : resVectorTy.getElementType().getIntOrFloatBitWidth();
   int packingFactor = bitwidth < packingSize ? packingSize / bitwidth : 1;
   int packedDataSize = subgroupSize * packingFactor;
 
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 87835fb191604..1c9f06e92dff2 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -321,6 +321,7 @@ class LayoutInfoPropagation
     : public SparseBackwardDataFlowAnalysis<LayoutInfoLattice> {
 private:
   xegpu::LayoutKind layoutKind;
+  unsigned indexBitWidth;
   void visitDpasOp(xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands,
                    ArrayRef<const LayoutInfoLattice *> results);
 
@@ -396,9 +397,9 @@ class LayoutInfoPropagation
 public:
   LayoutInfoPropagation(DataFlowSolver &solver,
                         SymbolTableCollection &symbolTable,
-                        xegpu::LayoutKind layoutKind)
+                        xegpu::LayoutKind layoutKind, unsigned indexBitWidth)
       : SparseBackwardDataFlowAnalysis(solver, symbolTable),
-        layoutKind(layoutKind) {}
+        layoutKind(layoutKind), indexBitWidth(indexBitWidth) {}
   using SparseBackwardDataFlowAnalysis::SparseBackwardDataFlowAnalysis;
 
   LogicalResult
@@ -958,7 +959,8 @@ void LayoutInfoPropagation::visitInsertStridedSliceOp(
       getUArch(xegpu::getChipStr(insertStridedSlice).value_or(""));
 
   auto requiredResLayoutAttr = xegpu::setupInsertStridedSliceResultLayout(
-      layoutKind, srcVecType, resVecType, consumerLayoutAttr, uArch);
+      layoutKind, srcVecType, resVecType, consumerLayoutAttr, uArch,
+      indexBitWidth);
 
   xegpu::setTemporaryLayout(insertStridedSlice->getResult(0),
                             requiredResLayoutAttr);
@@ -1156,11 +1158,12 @@ class RunLayoutInfoPropagation {
 public:
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(RunLayoutInfoPropagation)
 
-  RunLayoutInfoPropagation(Operation *op, xegpu::LayoutKind layoutKind)
+  RunLayoutInfoPropagation(Operation *op, xegpu::LayoutKind layoutKind,
+                           unsigned indexBitWidth)
       : target(op) {
     SymbolTableCollection symbolTable;
     loadBaselineAnalyses(solver);
-    solver.load<LayoutInfoPropagation>(symbolTable, layoutKind);
+    solver.load<LayoutInfoPropagation>(symbolTable, layoutKind, indexBitWidth);
     (void)solver.initializeAndRun(op);
   }
 
@@ -1544,8 +1547,9 @@ struct XeGPUPropagateLayoutPass final
 } // namespace
 
 LogicalResult xegpu::propagateLayouts(OpBuilder &builder, Operation *target,
-                                      LayoutKind layoutKind, bool printOnly) {
-  RunLayoutInfoPropagation analysis(target, layoutKind);
+                                      LayoutKind layoutKind,
+                                      unsigned indexBitWidth, bool printOnly) {
+  RunLayoutInfoPropagation analysis(target, layoutKind, indexBitWidth);
   // Print the analysis result and exit. (for debugging purposes)
   if (printOnly) {
     auto &os = llvm::outs();
@@ -1628,7 +1632,7 @@ void XeGPUPropagateLayoutPass::runOnOperation() {
   }
   OpBuilder builder(&getContext());
   if (failed(xegpu::propagateLayouts(builder, getOperation(), layoutKind,
-                                     this->printOnly))) {
+                                     this->indexBitWidth, this->printOnly))) {
     signalPassFailure();
     return;
   }
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir b/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
index 5dd05e6cb0001..340aa92d2d425 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
@@ -334,3 +334,21 @@ gpu.module @test{
     return
   }
 }
+
+
+// -----
+gpu.module @test{
+  // CHECK-LABEL: insert_strided_slice_index_type
+  // CHECK: vector.insert_strided_slice %{{.*}} {layout_result_0 = #xegpu.layout<inst_data = [1, 16]>
+  // CHECK: xegpu.store %{{.*}} <{layout = #xegpu.layout<inst_data = [1, 16]>}>
+  func.func @insert_strided_slice_index_type(%arg0: i64) {
+    %vector_step_2d_placeholder = arith.constant dense<1> : vector<16x16xindex>
+    %vector_step_slice = arith.constant dense<12> : vector<1x16xindex>
+    %v = vector.insert_strided_slice %vector_step_slice, %vector_step_2d_placeholder
+      {offsets = [0, 0], sizes = [1, 16], strides = [1, 1]} : vector<1x16xindex> into vector<16x16xindex>
+    %cst = arith.constant dense<true> : vector<16x16xi1>
+    %data = arith.constant dense<0.000000e+00> : vector<16x16xf16>
+    xegpu.store %data, %arg0[%v], %cst : vector<16x16xf16>, i64, vector<16x16xindex>, vector<16x16xi1>
+    return
+  }
+}
diff --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
index b9b38fc1ce88b..ffd091b5154b3 100644
--- a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
+++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
@@ -377,7 +377,7 @@ struct TestXeGPUPropagateLayouts
       signalPassFailure();
       return;
     }
-    if (failed(xegpu::propagateLayouts(builder, getOperation(), kind))) {
+    if (failed(xegpu::propagateLayouts(builder, getOperation(), kind, 32))) {
       signalPassFailure();
     }
   }



More information about the Mlir-commits mailing list