[Mlir-commits] [mlir] bba40ab - [MLIR][XeGPU] Decouple `inst_data` and `lane_layout` in propagation (#166941)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Nov 10 05:14:16 PST 2025


Author: Artem Kroviakov
Date: 2025-11-10T14:14:11+01:00
New Revision: bba40ab4bd60e636e77362c46c181eafd377f541

URL: https://github.com/llvm/llvm-project/commit/bba40ab4bd60e636e77362c46c181eafd377f541
DIFF: https://github.com/llvm/llvm-project/commit/bba40ab4bd60e636e77362c46c181eafd377f541.diff

LOG: [MLIR][XeGPU] Decouple `inst_data` and `lane_layout` in propagation (#166941)

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
    mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
    mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp
    mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
    mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
    mlir/test/Dialect/XeGPU/propagate-layout.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index 9c35c07a7e587..3f27d690f949b 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -379,6 +379,17 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttr]> {
   );
 
   let builders = [
+    AttrBuilder<(ins "llvm::ArrayRef<int32_t>": $inst_data),
+      [{
+        auto sg_layout = DenseI32ArrayAttr();
+        auto sg_data = DenseI32ArrayAttr();
+        auto order = DenseI32ArrayAttr();
+        auto lane_layout = DenseI32ArrayAttr();
+        auto lane_data = DenseI32ArrayAttr();
+        return $_get($_ctxt, sg_layout, sg_data,
+                     DenseI32ArrayAttr::get($_ctxt, inst_data),
+                     lane_layout, lane_data, order);
+      }]>,
     AttrBuilder<(ins "llvm::ArrayRef<int32_t>": $inst_data,
                       "llvm::ArrayRef<int32_t>": $lane_layout,
                      "llvm::ArrayRef<int32_t>": $lane_data),

diff  --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
index e42799689e490..12270af870b3b 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
@@ -47,7 +47,7 @@ def XeGPUPropagateLayout : Pass<"xegpu-propagate-layout"> {
     Option<
     "layoutKind", "layout-kind", "std::string",
     /*default=*/"\"lane\"",
-    "Propagate a `sg` / `inst` / `lane` level of xegpu layouts.">
+    "Propagate `inst` / `lane` level of xegpu layouts.">
   ];
 }
 

diff  --git a/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp b/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp
index 1a1485ba2e02c..b097d3a0c9686 100644
--- a/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp
+++ b/mlir/lib/Dialect/GPU/Pipelines/GPUToXeVMPipeline.cpp
@@ -63,13 +63,20 @@ void buildGPUPassPipeline(OpPassManager &pm,
   if (options.xegpuOpLevel == "workgroup") {
     pm.addNestedPass<gpu::GPUModuleOp>(xegpu::createXeGPUWgToSgDistribute());
     pm.addNestedPass<gpu::GPUModuleOp>(createCSEPass());
+    xegpu::XeGPUPropagateLayoutOptions layoutOptions;
+    layoutOptions.layoutKind = "inst";
+    pm.addNestedPass<gpu::GPUModuleOp>(
+        xegpu::createXeGPUPropagateLayout(layoutOptions));
     pm.addNestedPass<gpu::GPUModuleOp>(xegpu::createXeGPUBlocking());
     pm.addNestedPass<gpu::GPUModuleOp>(createCanonicalizerPass());
     pm.addNestedPass<gpu::GPUModuleOp>(createCSEPass());
   }
   if (options.xegpuOpLevel == "subgroup" ||
       options.xegpuOpLevel == "workgroup") {
-    pm.addNestedPass<gpu::GPUModuleOp>(xegpu::createXeGPUPropagateLayout());
+    xegpu::XeGPUPropagateLayoutOptions layoutOptions;
+    layoutOptions.layoutKind = "lane";
+    pm.addNestedPass<gpu::GPUModuleOp>(
+        xegpu::createXeGPUPropagateLayout(layoutOptions));
     pm.addNestedPass<gpu::GPUModuleOp>(xegpu::createXeGPUSubgroupDistribute());
     pm.addNestedPass<gpu::GPUModuleOp>(createCanonicalizerPass());
     pm.addNestedPass<gpu::GPUModuleOp>(createCSEPass());

diff  --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 4e1a539771d2f..b3a780abd3f12 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -53,6 +53,8 @@ using namespace mlir::dataflow;
 
 namespace {
 
+enum class LayoutKind { Lane, InstData };
+
 //===----------------------------------------------------------------------===//
 // LayoutInfo
 //===----------------------------------------------------------------------===//
@@ -166,7 +168,8 @@ LayoutInfo LayoutInfo::join(const LayoutInfo &lhs, const LayoutInfo &rhs) {
   llvm_unreachable("Join should not be triggered by layout propagation.");
 }
 
-/// Construct a new layout with the transposed lane layout and lane data.
+/// Construct a new layout with the transposed inst_data or lane_layout,
+/// lane_data.
 LayoutInfo LayoutInfo::transpose(ArrayRef<int64_t> permutation) const {
   if (!isAssigned())
     return {};
@@ -186,12 +189,20 @@ LayoutInfo LayoutInfo::transpose(ArrayRef<int64_t> permutation) const {
   SmallVector<int32_t> laneData;
   SmallVector<int32_t> instData;
   for (int64_t idx : permutation) {
-    laneLayout.push_back(static_cast<int32_t>(getLaneLayout()[idx]));
-    laneData.push_back(static_cast<int32_t>(getLaneData()[idx]));
-    instData.push_back(static_cast<int32_t>(getInstData()[idx]));
+    if (getLaneLayout().size()) {
+      laneLayout.push_back(static_cast<int32_t>(getLaneLayout()[idx]));
+      laneData.push_back(static_cast<int32_t>(getLaneData()[idx]));
+    }
+    if (getInstData().size())
+      instData.push_back(static_cast<int32_t>(getInstData()[idx]));
   }
-  return LayoutInfo(xegpu::LayoutAttr::get(storage.getContext(), instData,
-                                           laneLayout, laneData));
+  xegpu::LayoutAttr layoutAttr;
+  if (getLaneLayout().size())
+    layoutAttr =
+        xegpu::LayoutAttr::get(storage.getContext(), laneLayout, laneData);
+  if (getInstData().size())
+    layoutAttr = xegpu::LayoutAttr::get(storage.getContext(), instData);
+  return LayoutInfo(layoutAttr);
 }
 
 //===----------------------------------------------------------------------===//
@@ -213,15 +224,14 @@ struct LayoutInfoLattice : public Lattice<LayoutInfo> {
 /// For 2D vector, lane_layout is [1, subgroupSize] and lane_data is [1, 1].
 static LayoutInfo getDefaultSIMTLayoutInfo(mlir::MLIRContext *ctx,
                                            unsigned rank,
-                                           const xegpu::uArch::uArch *uArch,
-                                           ArrayRef<int> instData) {
+                                           const xegpu::uArch::uArch *uArch) {
   assert((rank == 1 || rank == 2) && "Expected 1D or 2D vector.");
   if (rank == 1) {
     return LayoutInfo(
-        xegpu::LayoutAttr::get(ctx, instData, {uArch->getSubgroupSize()}, {1}));
+        xegpu::LayoutAttr::get(ctx, {uArch->getSubgroupSize()}, {1}));
   }
-  return LayoutInfo(xegpu::LayoutAttr::get(
-      ctx, instData, {1, uArch->getSubgroupSize()}, {1, 1}));
+  return LayoutInfo(
+      xegpu::LayoutAttr::get(ctx, {1, uArch->getSubgroupSize()}, {1, 1}));
 }
 
 static LayoutInfo getDefaultSIMTLayoutInfo(mlir::MLIRContext *ctx,
@@ -236,7 +246,6 @@ static LayoutInfo getDefaultSIMTLayoutInfo(mlir::MLIRContext *ctx,
 /// Helper to get the default layout for a vector type.
 static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy,
                                            const xegpu::uArch::uArch *uArch,
-                                           ArrayRef<int> instData,
                                            unsigned packingSize,
                                            bool isScattered = false) {
   // Expecting a 1D or 2D vector.
@@ -247,16 +256,16 @@ static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy,
          "Expected int or float element type.");
   // If the rank is 1, then return default layout for 1D vector.
   if (vectorTy.getRank() == 1)
-    return getDefaultSIMTLayoutInfo(vectorTy.getContext(), 1, uArch, instData);
+    return getDefaultSIMTLayoutInfo(vectorTy.getContext(), 1, uArch);
   // Packing factor is determined by the element type bitwidth.
   unsigned bitwidth = vectorTy.getElementType().getIntOrFloatBitWidth();
   int packingFactor = bitwidth < packingSize ? packingSize / bitwidth : 1;
   if (isScattered) {
-    return LayoutInfo(xegpu::LayoutAttr::get(vectorTy.getContext(), instData,
+    return LayoutInfo(xegpu::LayoutAttr::get(vectorTy.getContext(),
                                              {uArch->getSubgroupSize(), 1},
                                              {1, packingFactor}));
   }
-  return LayoutInfo(xegpu::LayoutAttr::get(vectorTy.getContext(), instData,
+  return LayoutInfo(xegpu::LayoutAttr::get(vectorTy.getContext(),
                                            {1, uArch->getSubgroupSize()},
                                            {1, packingFactor}));
 }
@@ -264,7 +273,6 @@ static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy,
 /// Helper to get the default layout for a vector type.
 static LayoutInfo getDefaultSIMTLayoutInfo(xegpu::TensorDescType tdescTy,
                                            const xegpu::uArch::uArch *uArch,
-                                           ArrayRef<int> instData,
                                            unsigned packingSize,
                                            bool isScattered = false) {
   // Expecting a 1D or 2D vector.
@@ -275,18 +283,18 @@ static LayoutInfo getDefaultSIMTLayoutInfo(xegpu::TensorDescType tdescTy,
          "Expected int or float element type.");
   // If the rank is 1, then return default layout for 1D vector.
   if (tdescTy.getRank() == 1)
-    return getDefaultSIMTLayoutInfo(tdescTy.getContext(), 1, uArch, instData);
+    return getDefaultSIMTLayoutInfo(tdescTy.getContext(), 1, uArch);
   // Packing factor is determined by the element type bitwidth.
   unsigned bitwidth = tdescTy.getElementType().getIntOrFloatBitWidth();
   int subgroupSize = uArch->getSubgroupSize();
   int packingFactor = bitwidth < packingSize ? packingSize / bitwidth : 1;
   if (isScattered) {
     return LayoutInfo(xegpu::LayoutAttr::get(
-        tdescTy.getContext(), instData, {subgroupSize, 1}, {1, packingFactor}));
+        tdescTy.getContext(), {subgroupSize, 1}, {1, packingFactor}));
   }
 
   return LayoutInfo(xegpu::LayoutAttr::get(
-      tdescTy.getContext(), instData, {1, subgroupSize}, {1, packingFactor}));
+      tdescTy.getContext(), {1, subgroupSize}, {1, packingFactor}));
 }
 
 /// Helper Function to get the expected layouts for DPAS operands. `lane_data`
@@ -298,7 +306,7 @@ static LayoutInfo getDefaultSIMTLayoutInfo(xegpu::TensorDescType tdescTy,
 static LayoutInfo
 getSIMTLayoutInfoForDPASOperand(VectorType vectorTy, unsigned operandNum,
                                 const xegpu::uArch::uArch *uArch,
-                                ArrayRef<int> instData, unsigned packingSize) {
+                                unsigned packingSize) {
   Type elementTy = vectorTy.getElementType();
   assert(elementTy.isIntOrFloat() &&
          "Expected int or float type in DPAS operands");
@@ -310,10 +318,10 @@ getSIMTLayoutInfoForDPASOperand(VectorType vectorTy, unsigned operandNum,
         {static_cast<int32_t>(packingSize / elementTy.getIntOrFloatBitWidth()),
          1});
     return LayoutInfo(
-        xegpu::LayoutAttr::get(vectorTy.getContext(), instData, layout, data));
+        xegpu::LayoutAttr::get(vectorTy.getContext(), layout, data));
   }
   // Otherwise, return the default layout for the vector type.
-  return getDefaultSIMTLayoutInfo(vectorTy, uArch, instData, packingSize);
+  return getDefaultSIMTLayoutInfo(vectorTy, uArch, packingSize);
 }
 
 //===----------------------------------------------------------------------===//
@@ -328,6 +336,7 @@ getSIMTLayoutInfoForDPASOperand(VectorType vectorTy, unsigned operandNum,
 class LayoutInfoPropagation
     : public SparseBackwardDataFlowAnalysis<LayoutInfoLattice> {
 private:
+  LayoutKind layoutKind;
   void visitDpasOp(xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands,
                    ArrayRef<const LayoutInfoLattice *> results);
 
@@ -380,8 +389,10 @@ class LayoutInfoPropagation
 
 public:
   LayoutInfoPropagation(DataFlowSolver &solver,
-                        SymbolTableCollection &symbolTable)
-      : SparseBackwardDataFlowAnalysis(solver, symbolTable) {}
+                        SymbolTableCollection &symbolTable,
+                        LayoutKind layoutKind)
+      : SparseBackwardDataFlowAnalysis(solver, symbolTable),
+        layoutKind(layoutKind) {}
   using SparseBackwardDataFlowAnalysis::SparseBackwardDataFlowAnalysis;
 
   LogicalResult
@@ -499,8 +510,14 @@ void LayoutInfoPropagation::visitPrefetchNdOp(
           "No suitable instruction multiple found for the given shape.");
     instData = {instHeight, instWidth};
   }
-  auto prefetchLayout = getDefaultSIMTLayoutInfo(
-      tdescTy, uArch, instData, uArchInstruction->getPackedFormatBitSize());
+  LayoutInfo prefetchLayout;
+  if (layoutKind == LayoutKind::InstData)
+    prefetchLayout =
+        LayoutInfo(xegpu::LayoutAttr::get(tdescTy.getContext(), instData));
+  else
+    prefetchLayout = getDefaultSIMTLayoutInfo(
+        tdescTy, uArch, uArchInstruction->getPackedFormatBitSize());
+
   // Propagate the layout to the source tensor descriptor.
   propagateIfChanged(operands[0], operands[0]->meet(prefetchLayout));
 }
@@ -627,14 +644,24 @@ void LayoutInfoPropagation::visitDpasOp(
   SmallVector<int> instDataA = {maxALen, subgroupSize};
   SmallVector<int> instDataB = {subgroupSize, maxBLen};
 
-  propagateIfChanged(operands[0],
-                     operands[0]->meet(getSIMTLayoutInfoForDPASOperand(
-                         aTy, 0, uArch, instDataA,
-                         uArchInstruction->getPackedFormatBitSizeA())));
-  propagateIfChanged(operands[1],
-                     operands[1]->meet(getSIMTLayoutInfoForDPASOperand(
-                         bTy, 1, uArch, instDataB,
-                         uArchInstruction->getPackedFormatBitSizeB())));
+  LayoutInfo dpasALayout;
+  LayoutInfo dpasBLayout;
+  LayoutInfo dpasCLayout;
+
+  if (layoutKind == LayoutKind::InstData) {
+    dpasALayout =
+        LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataA));
+    dpasBLayout =
+        LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataB));
+  } else {
+    dpasALayout = getSIMTLayoutInfoForDPASOperand(
+        aTy, 0, uArch, uArchInstruction->getPackedFormatBitSizeA());
+    dpasBLayout = getSIMTLayoutInfoForDPASOperand(
+        bTy, 1, uArch, uArchInstruction->getPackedFormatBitSizeB());
+  }
+
+  propagateIfChanged(operands[0], operands[0]->meet(dpasALayout));
+  propagateIfChanged(operands[1], operands[1]->meet(dpasBLayout));
   if (operands.size() > 2) {
     VectorType cTy = dpas.getAccType();
     const unsigned dataCLen = bTy.getShape().back();
@@ -645,10 +672,15 @@ void LayoutInfoPropagation::visitDpasOp(
       dpas.emitWarning(
           "No suitable instruction multiple found for the given shape.");
     SmallVector<int> instDataC = {maxALen, maxCLen};
-    propagateIfChanged(operands[2],
-                       operands[2]->meet(getSIMTLayoutInfoForDPASOperand(
-                           cTy, 2, uArch, instDataC,
-                           uArchInstruction->getPackedFormatBitSizeB())));
+
+    if (layoutKind == LayoutKind::InstData)
+      dpasCLayout =
+          LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataC));
+    else
+      dpasCLayout = getSIMTLayoutInfoForDPASOperand(
+          cTy, 2, uArch, uArchInstruction->getPackedFormatBitSizeB());
+
+    propagateIfChanged(operands[2], operands[2]->meet(dpasCLayout));
   }
 }
 
@@ -685,9 +717,15 @@ void LayoutInfoPropagation::visitStoreNdOp(
           "No suitable instruction multiple found for the given shape.");
     instData = {instHeight, instWidth};
   }
-  LayoutInfo storeLayout =
-      getDefaultSIMTLayoutInfo(store.getValueType(), uArch, instData,
-                               uArchInstruction->getPackedFormatBitSize());
+
+  LayoutInfo storeLayout;
+  if (layoutKind == LayoutKind::InstData)
+    storeLayout =
+        LayoutInfo(xegpu::LayoutAttr::get(dataTy.getContext(), instData));
+  else
+    storeLayout =
+        getDefaultSIMTLayoutInfo(store.getValueType(), uArch,
+                                 uArchInstruction->getPackedFormatBitSize());
   // Both operands should have the same layout
   for (LayoutInfoLattice *operand : operands)
     propagateIfChanged(operand, operand->meet(storeLayout));
@@ -818,9 +856,13 @@ void LayoutInfoPropagation::visitLoadGatherOp(
     if (srcTdescTy.getChunkSizeAsInt() > 1)
       instData.push_back(chunkSize);
   }
-  LayoutInfo layout = getDefaultSIMTLayoutInfo(
-      payloadTy, uArch, instData, uArch->getGeneralPackedFormatBitSize(),
-      /*scattered*/ true);
+  LayoutInfo layout;
+  if (layoutKind == LayoutKind::InstData)
+    layout = LayoutInfo(xegpu::LayoutAttr::get(load.getContext(), instData));
+  else
+    layout = getDefaultSIMTLayoutInfo(payloadTy, uArch,
+                                      uArch->getGeneralPackedFormatBitSize(),
+                                      /*scattered*/ true);
 
   // Mask operand should have 1D default layout.
   LayoutInfo maskLayout =
@@ -864,33 +906,36 @@ void LayoutInfoPropagation::visitStoreScatterOp(
     storeScatter.emitWarning("Not propagating, non-vector payload supplied.");
     return;
   }
+  LayoutInfo payloadLayout;
   auto uArch = getUArch(getChipStr(storeScatter).value_or(""));
   const int subgroupSize = uArch->getSubgroupSize();
 
-  auto payloadShape = payloadTy.getShape();
-  if (payloadShape.size() > 1)
-    assert(
-        payloadShape[0] == subgroupSize &&
-        "Expected the first dimension of 2D tensor descriptor to be equal to "
-        "subgroup size.");
-
-  SmallVector<int> instData{subgroupSize};
-  if (auto chunkSize = storeScatter.getChunkSize().value_or(0); chunkSize > 1)
-    instData.push_back(chunkSize);
-  else if (auto dstTdescTy =
-               dyn_cast<xegpu::TensorDescType>(storeScatter.getDestType())) {
-    if (dstTdescTy.getChunkSizeAsInt() > 1)
-      instData.push_back(chunkSize);
-  }
-
-  LayoutInfo payloadLayout;
-
   if (auto layout = storeScatter.getLayoutAttr()) {
     payloadLayout = LayoutInfo(layout);
   } else {
-    payloadLayout = getDefaultSIMTLayoutInfo(
-        payloadTy, uArch, instData, uArch->getGeneralPackedFormatBitSize(),
-        /*scattered=*/true);
+    if (layoutKind == LayoutKind::InstData) {
+      SmallVector<int> instData{subgroupSize};
+      if (auto chunkSize = storeScatter.getChunkSize().value_or(0);
+          chunkSize > 1)
+        instData.push_back(chunkSize);
+      else if (auto dstTdescTy = dyn_cast<xegpu::TensorDescType>(
+                   storeScatter.getDestType())) {
+        if (dstTdescTy.getChunkSizeAsInt() > 1)
+          instData.push_back(chunkSize);
+      }
+      payloadLayout = LayoutInfo(
+          xegpu::LayoutAttr::get(storeScatter.getContext(), instData));
+    } else {
+      auto payloadShape = payloadTy.getShape();
+      if (payloadShape.size() > 1)
+        assert(payloadShape[0] == subgroupSize &&
+               "Expected the first dimension of 2D tensor descriptor to be "
+               "equal to "
+               "subgroup size.");
+      payloadLayout = getDefaultSIMTLayoutInfo(
+          payloadTy, uArch, uArch->getGeneralPackedFormatBitSize(),
+          /*scattered=*/true);
+    }
   }
 
   LayoutInfo maskLayout =
@@ -916,10 +961,10 @@ class RunLayoutInfoPropagation {
 public:
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(RunLayoutInfoPropagation)
 
-  RunLayoutInfoPropagation(Operation *op) : target(op) {
+  RunLayoutInfoPropagation(Operation *op, LayoutKind layoutKind) : target(op) {
     SymbolTableCollection symbolTable;
     loadBaselineAnalyses(solver);
-    solver.load<LayoutInfoPropagation>(symbolTable);
+    solver.load<LayoutInfoPropagation>(symbolTable, layoutKind);
     (void)solver.initializeAndRun(op);
   }
 
@@ -1159,7 +1204,18 @@ struct XeGPUPropagateLayoutPass final
 } // namespace
 
 void XeGPUPropagateLayoutPass::runOnOperation() {
-  auto &analysis = getAnalysis<RunLayoutInfoPropagation>();
+  LayoutKind layoutKind;
+  if (this->layoutKind == "lane") {
+    layoutKind = LayoutKind::Lane;
+  } else if (this->layoutKind == "inst") {
+    layoutKind = LayoutKind::InstData;
+  } else {
+    getOperation()->emitError("Unsupported layout kind option: " +
+                              this->layoutKind);
+    signalPassFailure();
+    return;
+  }
+  RunLayoutInfoPropagation analysis(getOperation(), layoutKind);
   // Print the analysis result and exit. (for debugging purposes)
   if (printOnly) {
     auto &os = llvm::outs();
@@ -1173,8 +1229,6 @@ void XeGPUPropagateLayoutPass::runOnOperation() {
       return {};
     xegpu::DistributeLayoutAttr layoutAttr =
         cast<xegpu::DistributeLayoutAttr>(layout.get());
-    if (this->layoutKind == "lane")
-      layoutAttr = layoutAttr.dropInstData();
     if (layout.isSliceLayout())
       return cast<xegpu::SliceAttr>(layoutAttr);
     return cast<xegpu::LayoutAttr>(layoutAttr);

diff  --git a/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir b/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
index 58461b8be52c4..c31ef323a94d2 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
@@ -2,17 +2,17 @@
 
 // CHECK-LABEL: func.func @dpas_f16(
 // CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<8x16xf16>, %[[ARG1:[0-9a-zA-Z]+]]: memref<16x16xf16>, %[[ARG2:[0-9a-zA-Z]+]]: memref<8x16xf32>) {
-// CHECK: %[[CST:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [8, 16], lane_layout = [1, 16], lane_data = [1, 1]>} dense<0.000000e+00> : vector<8x16xf32>
-// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][{{.*}}] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16, #xegpu.layout<inst_data = [8, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
-// CHECK: %[[T1:.*]] = xegpu.create_nd_tdesc %[[ARG1]][{{.*}}] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<inst_data = [16, 16], lane_layout = [1, 16], lane_data = [2, 1]>>
-// CHECK: %[[T2:.*]] = xegpu.load_nd %[[T0]]  {layout_result_0 = #xegpu.layout<inst_data = [8, 16], lane_layout = [1, 16], lane_data = [1, 1]>} :
-// CHECK-SAME: !xegpu.tensor_desc<8x16xf16, #xegpu.layout<inst_data = [8, 16], lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<8x16xf16>
-// CHECK: %[[T3:.*]] = xegpu.load_nd %[[T1]]  {layout_result_0 = #xegpu.layout<inst_data = [16, 16], lane_layout = [1, 16], lane_data = [2, 1]>} :
-// CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<inst_data = [16, 16], lane_layout = [1, 16], lane_data = [2, 1]>> -> vector<16x16xf16>
-// CHECK: %[[T4:.*]] = xegpu.dpas %[[T2]], %[[T3]], %[[CST]] {layout_result_0 = #xegpu.layout<inst_data = [8, 16], lane_layout = [1, 16], lane_data = [1, 1]>} :
+// CHECK: %[[CST:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [8, 16]>} dense<0.000000e+00> : vector<8x16xf32>
+// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][{{.*}}] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16, #xegpu.layout<inst_data = [8, 16]>
+// CHECK: %[[T1:.*]] = xegpu.create_nd_tdesc %[[ARG1]][{{.*}}] : memref<16x16xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<inst_data = [16, 16]>>
+// CHECK: %[[T2:.*]] = xegpu.load_nd %[[T0]]  {layout_result_0 = #xegpu.layout<inst_data = [8, 16]>} :
+// CHECK-SAME: !xegpu.tensor_desc<8x16xf16, #xegpu.layout<inst_data = [8, 16]>> -> vector<8x16xf16>
+// CHECK: %[[T3:.*]] = xegpu.load_nd %[[T1]]  {layout_result_0 = #xegpu.layout<inst_data = [16, 16]>} :
+// CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<inst_data = [16, 16]>> -> vector<16x16xf16>
+// CHECK: %[[T4:.*]] = xegpu.dpas %[[T2]], %[[T3]], %[[CST]] {layout_result_0 = #xegpu.layout<inst_data = [8, 16]>} :
 // CHECK-SAME: vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
-// CHECK: %[[T5:.*]] = xegpu.create_nd_tdesc %[[ARG2]][{{.*}}] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.layout<inst_data = [8, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
-// CHECK: xegpu.store_nd %[[T4]], %[[T5]] : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.layout<inst_data = [8, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
+// CHECK: %[[T5:.*]] = xegpu.create_nd_tdesc %[[ARG2]][{{.*}}] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.layout<inst_data = [8, 16]>
+// CHECK: xegpu.store_nd %[[T4]], %[[T5]] : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.layout<inst_data = [8, 16]>>
 gpu.module @test {
 
 func.func @dpas_f16(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<8x16xf32>) {
@@ -46,18 +46,18 @@ gpu.module @test_kernel {
     %out:3 = scf.for %k = %c0 to %c1024 step %c32
       iter_args(%arg0 = %a_tdesc, %arg1 = %b_tdesc, %arg2 = %c_tdesc)
       -> (!xegpu.tensor_desc<16x32xf16>, !xegpu.tensor_desc<16x32xf16>, !xegpu.tensor_desc<16x32xf16>) {
-      //CHECK: xegpu.load_nd {{.*}} {layout_result_0 = #xegpu.layout<inst_data = [8, 16], lane_layout = [1, 16], lane_data = [1, 1]>} :
-      //CHECK-SAME: !xegpu.tensor_desc<16x32xf16, #xegpu.layout<inst_data = [8, 16], lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<16x32xf16>
+      //CHECK: xegpu.load_nd {{.*}} {layout_result_0 = #xegpu.layout<inst_data = [8, 16]>} :
+      //CHECK-SAME: !xegpu.tensor_desc<16x32xf16, #xegpu.layout<inst_data = [8, 16]>> -> vector<16x32xf16>
       %a = xegpu.load_nd %arg0 : !xegpu.tensor_desc<16x32xf16> -> vector<16x32xf16>
       %b = xegpu.load_nd %arg1 : !xegpu.tensor_desc<16x32xf16> -> vector<16x32xf16>
 
-      //CHECK-COUNT: arith.addf {{.*}} {layout_result_0 = #xegpu.layout<inst_data = [8, 16], lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16x32xf16>
+      //CHECK-COUNT: arith.addf {{.*}} {layout_result_0 = #xegpu.layout<inst_data = [8, 16]>} : vector<16x32xf16>
       %c = arith.addf %a, %b : vector<16x32xf16>
 
-      //CHECK-COUNT: xegpu.store_nd {{.*}} : vector<16x32xf16>, !xegpu.tensor_desc<16x32xf16, #xegpu.layout<inst_data = [8, 16], lane_layout = [1, 16], lane_data = [1, 1]>>>
+      //CHECK-COUNT: xegpu.store_nd {{.*}} : vector<16x32xf16>, !xegpu.tensor_desc<16x32xf16, #xegpu.layout<inst_data = [8, 16]>>
       xegpu.store_nd %c, %arg2: vector<16x32xf16>, !xegpu.tensor_desc<16x32xf16>
 
-      //CHECK-COUNT: xegpu.update_nd_offset {{.*}} : !xegpu.tensor_desc<16x32xf16, #xegpu.layout<inst_data = [8, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
+      //CHECK-COUNT: xegpu.update_nd_offset {{.*}} : !xegpu.tensor_desc<16x32xf16, #xegpu.layout<inst_data = [8, 16]>>
       %a_next_tdesc = xegpu.update_nd_offset %arg0, [%c0, %c32] : !xegpu.tensor_desc<16x32xf16>
       %b_next_tdesc = xegpu.update_nd_offset %arg1, [%c0, %c32] : !xegpu.tensor_desc<16x32xf16>
       %c_next_tdesc = xegpu.update_nd_offset %arg2, [%c0, %c32] : !xegpu.tensor_desc<16x32xf16>
@@ -85,18 +85,18 @@ gpu.module @test_kernel {
     %out:3 = scf.for %k = %c0 to %c1024 step %c32
       iter_args(%arg0 = %a_tdesc, %arg1 = %b_tdesc, %arg2 = %c_tdesc)
       -> (!xegpu.tensor_desc<12x32xf16>, !xegpu.tensor_desc<12x32xf16>, !xegpu.tensor_desc<12x32xf16>) {
-      //CHECK: xegpu.load_nd {{.*}} {layout_result_0 = #xegpu.layout<inst_data = [4, 16], lane_layout = [1, 16], lane_data = [1, 1]>} :
-      //CHECK-SAME: !xegpu.tensor_desc<12x32xf16, #xegpu.layout<inst_data = [4, 16], lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<12x32xf16>
+      //CHECK: xegpu.load_nd {{.*}} {layout_result_0 = #xegpu.layout<inst_data = [4, 16]>} :
+      //CHECK-SAME: !xegpu.tensor_desc<12x32xf16, #xegpu.layout<inst_data = [4, 16]>> -> vector<12x32xf16>
       %a = xegpu.load_nd %arg0 : !xegpu.tensor_desc<12x32xf16> -> vector<12x32xf16>
       %b = xegpu.load_nd %arg1 : !xegpu.tensor_desc<12x32xf16> -> vector<12x32xf16>
 
-      //CHECK-COUNT: arith.addf {{.*}} {layout_result_0 = #xegpu.layout<inst_data = [4, 16], lane_layout = [1, 16], lane_data = [1, 1]>} : vector<12x32xf16>
+      //CHECK-COUNT: arith.addf {{.*}} {layout_result_0 = #xegpu.layout<inst_data = [4, 16]>} : vector<12x32xf16>
       %c = arith.addf %a, %b : vector<12x32xf16>
 
-      //CHECK-COUNT: xegpu.store_nd {{.*}} : vector<12x32xf16>, !xegpu.tensor_desc<12x32xf16, #xegpu.layout<inst_data = [4, 16], lane_layout = [1, 16], lane_data = [1, 1]>>>
+      //CHECK-COUNT: xegpu.store_nd {{.*}} : vector<12x32xf16>, !xegpu.tensor_desc<12x32xf16, #xegpu.layout<inst_data = [4, 16]>>
       xegpu.store_nd %c, %arg2: vector<12x32xf16>, !xegpu.tensor_desc<12x32xf16>
 
-      //CHECK-COUNT: xegpu.update_nd_offset {{.*}} : !xegpu.tensor_desc<12x32xf16, #xegpu.layout<inst_data = [4, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
+      //CHECK-COUNT: xegpu.update_nd_offset {{.*}} : !xegpu.tensor_desc<12x32xf16, #xegpu.layout<inst_data = [4, 16]>>
       %a_next_tdesc = xegpu.update_nd_offset %arg0, [%c0, %c32] : !xegpu.tensor_desc<12x32xf16>
       %b_next_tdesc = xegpu.update_nd_offset %arg1, [%c0, %c32] : !xegpu.tensor_desc<12x32xf16>
       %c_next_tdesc = xegpu.update_nd_offset %arg2, [%c0, %c32] : !xegpu.tensor_desc<12x32xf16>
@@ -114,7 +114,7 @@ gpu.module @test {
 // CHECK: %{{.*}} = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<16xi1>
 // CHECK: %{{.*}} = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<12> : vector<16xindex>
 // CHECK: %{{.*}} = xegpu.load %[[ARG0]][%{{.*}}], %{{.*}} <{chunk_size = 8 : i64}>
-// CHECK-SAME: {layout_result_0 = #xegpu.layout<inst_data = [16, 8], lane_layout = [16, 1], lane_data = [1, 2]>} : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
+// CHECK-SAME: {layout_result_0 = #xegpu.layout<inst_data = [16, 8]>} : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
 // CHECK: xegpu.store %0, %[[ARG0]][%{{.*}}], %{{.*}} <{chunk_size = 8 : i64}> : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
 func.func @scatter_ops_chunksize(%src: memref<256xf16>) {
   %1 = arith.constant dense<1>: vector<16xi1>

diff  --git a/mlir/test/Dialect/XeGPU/propagate-layout.mlir b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
index 61e315d0d2080..eb004932af4be 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -xevm-attach-target='chip=pvc' -xegpu-propagate-layout -split-input-file %s | FileCheck %s
+// RUN: mlir-opt -xevm-attach-target='chip=pvc' -xegpu-propagate-layout="layout-kind=lane" -split-input-file %s | FileCheck %s
 
 gpu.module @test {
 // CHECK-LABEL: func.func @dpas_f16(


        


More information about the Mlir-commits mailing list