[Mlir-commits] [mlir] [mlir][xegpu] Add vector layout conflict handling in XeGPU layout propagation pass. (PR #182402)

Charitha Saumya llvmlistbot at llvm.org
Thu Feb 19 15:47:03 PST 2026


https://github.com/charithaintc updated https://github.com/llvm/llvm-project/pull/182402

>From 0f0f0a6129a7e1c63cc7706b4a1b98b737779521 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Wed, 18 Feb 2026 22:28:47 +0000
Subject: [PATCH 1/4] some cleanups

---
 .../XeGPU/Transforms/XeGPUPropagateLayout.cpp | 35 ++++++++++---------
 1 file changed, 19 insertions(+), 16 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index bc309c9029878..d445d3d476efa 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -6,6 +6,8 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include <utility>
+
 #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
 #include "mlir/Analysis/DataFlow/SparseAnalysis.h"
 #include "mlir/Analysis/DataFlow/Utils.h"
@@ -503,10 +505,12 @@ bool LayoutInfoPropagation::hasParamsOfLayoutKind(
   }
   if (layoutKind == xegpu::LayoutKind::InstData) {
     return !(anchorLayout.getEffectiveInstDataAsInt().empty());
-  } else if (layoutKind == xegpu::LayoutKind::Lane) {
+  }
+  if (layoutKind == xegpu::LayoutKind::Lane) {
     return !(anchorLayout.getEffectiveLaneLayoutAsInt().empty() ||
              anchorLayout.getEffectiveLaneDataAsInt().empty());
-  } else if (layoutKind == xegpu::LayoutKind::Subgroup) {
+  }
+  if (layoutKind == xegpu::LayoutKind::Subgroup) {
     return !(anchorLayout.getEffectiveSgLayoutAsInt().empty() ||
              anchorLayout.getEffectiveSgDataAsInt().empty());
   }
@@ -574,7 +578,7 @@ void LayoutInfoPropagation::visitPrefetchNdOp(
     // prefetch.
     auto tdescTy = prefetch.getTensorDescType();
 
-    auto uArch = getUArch(getChipStr(prefetch).value_or(""));
+    const auto *uArch = getUArch(getChipStr(prefetch).value_or(""));
     const auto *uArchInstruction =
         dyn_cast<xegpu::uArch::Subgroup2DBlockPrefetchInstruction>(
             uArch->getInstruction(
@@ -628,7 +632,7 @@ void LayoutInfoPropagation::visitVectorMultiReductionOp(
   VectorType sourceTy = reduction.getSourceVectorType();
   SmallVector<int64_t> reductionDims(reduction.getReductionDims());
 
-  auto uArch = getUArch(xegpu::getChipStr(reduction).value_or(""));
+  const auto *uArch = getUArch(xegpu::getChipStr(reduction).value_or(""));
   auto consumerLayoutAttr =
       dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
 
@@ -683,7 +687,6 @@ void LayoutInfoPropagation::visitVectorBroadCastOp(
       xegpu::inferBroadcastSourceLayout(resultLayoutAttr, resShape, srcShape);
 
   propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
-  return;
 }
 
 void LayoutInfoPropagation::visitShapeCastOp(
@@ -738,7 +741,7 @@ void LayoutInfoPropagation::visitDpasOp(
     dpasBLayout = LayoutInfo(anchorLayoutB);
     dpasCDLayout = LayoutInfo(anchorLayoutCD);
   } else {
-    auto uArch = getUArch(getChipStr(dpas).value_or(""));
+    const auto *uArch = getUArch(getChipStr(dpas).value_or(""));
     VectorType aTy = dpas.getLhsType();
     VectorType bTy = dpas.getRhsType();
     VectorType cdTy = dpas.getResultType();
@@ -794,7 +797,7 @@ void LayoutInfoPropagation::visitStoreNdOp(
   if (hasParamsOfLayoutKind(anchorLayout)) {
     storeLayout = LayoutInfo(anchorLayout);
   } else {
-    auto uArch = getUArch(getChipStr(store).value_or(""));
+    const auto *uArch = getUArch(getChipStr(store).value_or(""));
     const auto *uArchInstruction =
         dyn_cast<xegpu::uArch::Subgroup2DBlockStoreInstruction>(
             uArch->getInstruction(
@@ -923,7 +926,7 @@ void LayoutInfoPropagation::visitVectorBitcastOp(
 
   auto consumerLayoutAttr =
       dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
-  auto uArch = getUArch(xegpu::getChipStr(bitcast).value_or(""));
+  const auto *uArch = getUArch(xegpu::getChipStr(bitcast).value_or(""));
   auto requiredResLayoutAttr = setupBitCastResultLayout(
       layoutKind, srcVecType, resVecType, consumerLayoutAttr, uArch);
 
@@ -953,7 +956,8 @@ void LayoutInfoPropagation::visitInsertStridedSliceOp(
 
   auto consumerLayoutAttr =
       dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
-  auto uArch = getUArch(xegpu::getChipStr(insertStridedSlice).value_or(""));
+  const auto *uArch =
+      getUArch(xegpu::getChipStr(insertStridedSlice).value_or(""));
 
   auto requiredResLayoutAttr = xegpu::setupInsertStridedSliceResultLayout(
       layoutKind, srcVecType, resVecType, consumerLayoutAttr, uArch);
@@ -967,7 +971,6 @@ void LayoutInfoPropagation::visitInsertStridedSliceOp(
   propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
   propagateIfChanged(operands[1],
                      operands[1]->meet(LayoutInfo(requiredResLayoutAttr)));
-  return;
 }
 
 /// Propagate the layout of the result to the tensor descriptor, mask and offset
@@ -977,7 +980,7 @@ void LayoutInfoPropagation::visitLoadGatherOp(
     ArrayRef<const LayoutInfoLattice *> results) {
   xegpu::DistributeLayoutAttr requiredAnchorLayoutAttr;
   xegpu::DistributeLayoutAttr anchorLayoutAttr = load.getLayoutAttr();
-  auto uArch = getUArch(getChipStr(load).value_or(""));
+  const auto *uArch = getUArch(getChipStr(load).value_or(""));
   auto subgroupSize = uArch->getSubgroupSize();
   VectorType resVecTy = load.getValueType();
   int chunkSize = load.getChunkSize().value_or(1);
@@ -1036,7 +1039,7 @@ void LayoutInfoPropagation::visitCreateDescOp(
   // Need the layout of the descriptor to propagate to the operands.
   if (!descLayout.isAssigned())
     return;
-  auto uArch = getUArch(getChipStr(createDesc).value_or(""));
+  const auto *uArch = getUArch(getChipStr(createDesc).value_or(""));
   // For offset operand propagate 1D default layout.
   LayoutInfo layout = getDefaultSIMTLayoutInfo(createDesc->getContext(), 1,
                                                uArch->getSubgroupSize());
@@ -1051,7 +1054,7 @@ void LayoutInfoPropagation::visitStoreScatterOp(
 
   xegpu::DistributeLayoutAttr requiredAnchorLayoutAttr;
   xegpu::DistributeLayoutAttr anchorLayoutAttr = storeScatter.getLayoutAttr();
-  auto uArch = getUArch(getChipStr(storeScatter).value_or(""));
+  const auto *uArch = getUArch(getChipStr(storeScatter).value_or(""));
   auto subgroupSize = uArch->getSubgroupSize();
   VectorType srcVecTy = storeScatter.getValueType();
   int chunkSize = storeScatter.getChunkSize().value_or(1);
@@ -1113,7 +1116,7 @@ void LayoutInfoPropagation::visitLoadMatrixOp(
     VectorType resVecTy =
         llvm::cast<VectorType>(loadMatrixOp.getRes().getType());
     assert(resVecTy.getRank() == 2 && "Expecting 2D vector for store matrix.");
-    auto uArch = getUArch(getChipStr(loadMatrixOp).value_or(""));
+    const auto *uArch = getUArch(getChipStr(loadMatrixOp).value_or(""));
     auto requiredAnchorLayoutAttr = xegpu::setupLoadMatrixAnchorLayout(
         layoutKind, resVecTy, consumerLayoutAttr, uArch);
     loadMatrixOp.setLayoutAttr(requiredAnchorLayoutAttr);
@@ -1132,7 +1135,7 @@ void LayoutInfoPropagation::visitStoreMatrixOp(
     VectorType srcVecTy =
         llvm::cast<VectorType>(storeMatrix.getData().getType());
     assert(srcVecTy.getRank() == 2 && "Expecting 2D vector for store matrix.");
-    auto uArch = getUArch(getChipStr(storeMatrix).value_or(""));
+    const auto *uArch = getUArch(getChipStr(storeMatrix).value_or(""));
     auto requiredAnchorLayoutAttr =
         xegpu::setupStoreMatrixAnchorLayout(layoutKind, srcVecTy, uArch);
     storeMatrix.setLayoutAttr(requiredAnchorLayoutAttr);
@@ -1499,7 +1502,7 @@ struct XeGPUPropagateLayoutPass final
   XeGPUPropagateLayoutPass() = default;
   XeGPUPropagateLayoutPass(const XeGPUPropagateLayoutPass &other) = default;
   XeGPUPropagateLayoutPass(xegpu::XeGPUPropagateLayoutOptions options)
-      : XeGPUPropagateLayoutBase(options) {}
+      : XeGPUPropagateLayoutBase(std::move(options)) {}
   void runOnOperation() override;
 };
 

>From addae57c05b2ece3b36629f419fb8692289fd8b0 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Thu, 19 Feb 2026 23:30:27 +0000
Subject: [PATCH 2/4] add tests

---
 .../XeGPU/Transforms/XeGPUPropagateLayout.cpp | 174 ++++++++++++++---
 .../XeGPU/resolve-layout-conflicts.mlir       | 177 +++++++++++++++---
 .../lib/Dialect/XeGPU/TestXeGPUTransforms.cpp |   3 +-
 3 files changed, 304 insertions(+), 50 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index d445d3d476efa..56b3656fc044c 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -24,7 +24,9 @@
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/Operation.h"
+#include "mlir/IR/Region.h"
 #include "mlir/IR/Value.h"
 #include "mlir/IR/Visitors.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
@@ -1235,6 +1237,115 @@ namespace {
 //===----------------------------------------------------------------------===//
 // ResolveLayoutConflicts
 //===----------------------------------------------------------------------===//
+
+/// Helper to get the defining CreateNdDescOp of a tensor descriptor value. This
+/// function tries to find the defining CreateNdDescOp recursively across
+/// control-flow boundaries.
+static xegpu::CreateNdDescOp getDefiningCreateNdDescOp(Value tdescValue) {
+  // Try to get the defining CreateNdDescOp of the tensor descriptor.
+  auto definingOp = tdescValue.getDefiningOp<xegpu::CreateNdDescOp>();
+  if (definingOp)
+    return definingOp;
+  // If tdescValue is an argument, try to get the tied init value from the
+  // parent loop-like op.
+  if (auto arg = dyn_cast<BlockArgument>(tdescValue)) {
+    auto *parentOp = arg.getOwner()->getParentOp();
+    if (auto loop = dyn_cast<LoopLikeOpInterface>(parentOp)) {
+      OpOperand *tiedInit = loop.getTiedLoopInit(arg);
+      if (tiedInit)
+        return getDefiningCreateNdDescOp(tiedInit->get());
+    }
+  }
+  // If not found, return null.
+  return nullptr;
+}
+
+static xegpu::DistributeLayoutAttr
+getExpectedLayoutAt(OpOperand &operand,
+                    xegpu::DistributeLayoutAttr currLayout) {
+  Operation *op = operand.getOwner();
+  unsigned idx = operand.getOperandNumber();
+
+  // For vector::BroadcastOp, infer the source layout from the result layout.
+  if (auto broadcast = dyn_cast<vector::BroadcastOp>(op)) {
+    auto resLayout = xegpu::getDistributeLayoutAttr(broadcast->getResult(0));
+    if (!resLayout)
+      return xegpu::DistributeLayoutAttr();
+    auto srcTy = dyn_cast<VectorType>(broadcast.getSourceType());
+    if (!srcTy)
+      return xegpu::DistributeLayoutAttr();
+    return xegpu::inferBroadcastSourceLayout(
+        resLayout, broadcast.getResultVectorType().getShape(),
+        srcTy.getShape());
+  }
+
+  // For vector::MultiDimReductionOp, infer source layout from result layout
+  // using reduction dims. Acc operand is expected to have the same layout as
+  // the result.
+  if (auto reduction = dyn_cast<vector::MultiDimReductionOp>(op)) {
+    auto resLayout = xegpu::getDistributeLayoutAttr(reduction->getResult(0));
+    if (!resLayout)
+      return xegpu::DistributeLayoutAttr();
+    if (idx == 0) {
+      SmallVector<int64_t> reductionDims(reduction.getReductionDims());
+      return xegpu::inferMultiReductionSourceLayout(resLayout, reductionDims);
+    }
+    if (idx == 1)
+      return resLayout;
+  }
+
+  // For vector::BitCastOp, infer source layout from result layout using
+  // element type bitwidths.
+  if (auto bitcast = dyn_cast<vector::BitCastOp>(op)) {
+    auto resLayout = xegpu::getDistributeLayoutAttr(bitcast->getResult(0));
+    if (!resLayout)
+      return xegpu::DistributeLayoutAttr();
+    int resElemBitWidth =
+        bitcast.getResultVectorType().getElementType().getIntOrFloatBitWidth();
+    int srcElemBitWidth =
+        bitcast.getSourceVectorType().getElementType().getIntOrFloatBitWidth();
+    return xegpu::inferBitCastSourceLayout(resLayout, resElemBitWidth,
+                                           srcElemBitWidth);
+  }
+
+  // For vector::ShapeCastOp, infer source layout from result layout using
+  // shapes.
+  if (auto shapeCast = dyn_cast<vector::ShapeCastOp>(op)) {
+    auto resLayout = xegpu::getDistributeLayoutAttr(shapeCast->getResult(0));
+    if (!resLayout)
+      return xegpu::DistributeLayoutAttr();
+    return xegpu::inferShapeCastSourceLayout(
+        resLayout, shapeCast.getResultVectorType().getShape(),
+        shapeCast.getSourceVectorType().getShape());
+  }
+
+  // For vector::InsertStridedSliceOp, infer source layout from result layout.
+  // Dest vector must have the same layout as the result.
+  if (auto insertSlice = dyn_cast<vector::InsertStridedSliceOp>(op)) {
+    auto resLayout = xegpu::getDistributeLayoutAttr(insertSlice->getResult(0));
+    if (!resLayout)
+      return xegpu::DistributeLayoutAttr();
+    if (idx == 0)
+      return xegpu::inferInsertStridedSliceSourceLayout(
+          resLayout, insertSlice.getDestVectorType().getShape(),
+          insertSlice.getSourceVectorType().getShape());
+    if (idx == 1)
+      return resLayout;
+  }
+  // For elementwise operations, all operands must have the same layout as the
+  // result.
+  if (OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() == 1) {
+    auto resLayout = xegpu::getDistributeLayoutAttr(op->getResult(0));
+    if (!resLayout)
+      return xegpu::DistributeLayoutAttr();
+    return resLayout;
+  }
+  // TODO: Handle more cases as needed here.
+  // Fall back to the currently assigned layout for all other cases. This
+  // assumes no conflicts.
+  return currLayout;
+}
+
 struct ResolveLayoutConflicts {
   ResolveLayoutConflicts(Operation *parentOp)
       : parentOp(parentOp), builder(parentOp->getContext()) {}
@@ -1259,12 +1370,19 @@ LogicalResult ResolveLayoutConflicts::run() {
       if (isa<xegpu::AnchorLayoutInterface>(op) &&
           isa<xegpu::TensorDescType>(operandType)) {
         auto res = resolveTensorDescConsumer(operand);
-        return succeeded(res) ? WalkResult::advance() : WalkResult::interrupt();
+        if (failed(res)) {
+          DBGS() << "Failed to resolve tensor descriptor consumer: " << *op
+                 << "\n";
+          return WalkResult::interrupt();
+        }
       }
       // Handle conflicts in vector operands.
       if (isa<VectorType>(operandType)) {
         auto res = resolveVectorConsumer(operand);
-        return succeeded(res) ? WalkResult::advance() : WalkResult::interrupt();
+        if (failed(res)) {
+          DBGS() << "Failed to resolve vector consumer: " << *op << "\n";
+          return WalkResult::interrupt();
+        }
       }
     }
     return WalkResult::advance();
@@ -1273,32 +1391,36 @@ LogicalResult ResolveLayoutConflicts::run() {
   return r.wasInterrupted() ? failure() : success();
 }
 
-/// Helper to get the defining CreateNdDescOp of a tensor descriptor value. This
-/// function tries to find the defining CreateNdDescOp recursively accross
-/// control-flow boundaries.
-static xegpu::CreateNdDescOp getDefiningCreateNdDescOp(Value tdescValue) {
-  // Try to get the defining CreateNdDescOp of the tensor descriptor.
-  auto definingOp = tdescValue.getDefiningOp<xegpu::CreateNdDescOp>();
-  if (definingOp)
-    return definingOp;
-  // If tdescValue is an argument, try to get the tied init value from the
-  // parent loop-like op.
-  if (auto arg = dyn_cast<BlockArgument>(tdescValue)) {
-    auto *parentOp = arg.getOwner()->getParentOp();
-    if (auto loop = dyn_cast<LoopLikeOpInterface>(parentOp)) {
-      OpOperand *tiedInit = loop.getTiedLoopInit(arg);
-      if (tiedInit)
-        return getDefiningCreateNdDescOp(tiedInit->get());
-    }
-  }
-  // If not found, return null.
-  return nullptr;
-}
-
 LogicalResult
 ResolveLayoutConflicts::resolveVectorConsumer(OpOperand &operand) {
-  // TODO: Implement vector consumer layout conflict resolution. Requires layout
-  // utilities.
+  Value vectorValue = operand.get();
+  Operation *consumerOp = operand.getOwner();
+  // Get the current layout of the vector value.
+  auto currLayout = xegpu::getDistributeLayoutAttr(vectorValue);
+  if (!currLayout) {
+    consumerOp->emitError("Vector operand has no layout assigned.");
+    return failure();
+  }
+
+  // Get the expected layout at this operand.
+  auto expectedLayout = getExpectedLayoutAt(operand, currLayout);
+  if (!expectedLayout) {
+    consumerOp->emitError("No expected layout found for vector operand.");
+    return failure();
+  }
+
+  // If the layouts are the same, no conflict exists; return success.
+  if (expectedLayout.isEqualTo(currLayout))
+    return success();
+
+  // Insert a convert_layout op to resolve the conflict.
+  builder.setInsertionPointAfterValue(vectorValue);
+  auto convertOp = xegpu::ConvertLayoutOp::create(
+      builder, consumerOp->getLoc(), vectorValue.getType(), vectorValue,
+      currLayout, expectedLayout);
+
+  // Update the operand to use the converted value.
+  operand.set(convertOp.getResult());
   return success();
 }
 
diff --git a/mlir/test/Dialect/XeGPU/resolve-layout-conflicts.mlir b/mlir/test/Dialect/XeGPU/resolve-layout-conflicts.mlir
index d1dbe8bcff509..4b601f9f85b3a 100644
--- a/mlir/test/Dialect/XeGPU/resolve-layout-conflicts.mlir
+++ b/mlir/test/Dialect/XeGPU/resolve-layout-conflicts.mlir
@@ -1,8 +1,13 @@
-// RUN: mlir-opt --test-xegpu-resolve-layout-conflicts -split-input-file %s | FileCheck %s
+// RUN: mlir-opt --test-xegpu-resolve-layout-conflicts="layout-kind=inst" \
+// RUN: -allow-unregistered-dialect -split-input-file %s | FileCheck %s
 
-#load_lo = #xegpu.layout<inst_data = [8, 16]>
-#prefetch_lo = #xegpu.layout<inst_data = [16, 16]>
-#load_lo1 = #xegpu.layout<inst_data = [32, 16]>
+#inst_data_8x16 = #xegpu.layout<inst_data = [8, 16]>
+#inst_data_16x16 = #xegpu.layout<inst_data = [16, 16]>
+#inst_data_32x16 = #xegpu.layout<inst_data = [32, 16]>
+#inst_data_16 = #xegpu.layout<inst_data = [16]>
+#inst_data_32 = #xegpu.layout<inst_data = [32]>
+#inst_data_1x2x16 = #xegpu.layout<inst_data = [1, 2, 16]>
+#inst_data_1x32 = #xegpu.layout<inst_data = [1, 32]>
 gpu.module @test {
 
 // CHECK-LABEL:   func.func @load_nd_with_conflicting_tensor_desc
@@ -15,10 +20,10 @@ gpu.module @test {
 func.func @load_nd_with_conflicting_tensor_desc(%arg0: memref<64x64xf16>) -> vector<16x16xf16> {
   %c0 = arith.constant 0 : index
   %0 = xegpu.create_nd_tdesc %arg0 : memref<64x64xf16>
-    -> !xegpu.tensor_desc<16x16xf16, #prefetch_lo>
-  %1 = xegpu.load_nd %0 [%c0, %c0] {layout = #load_lo} : !xegpu.tensor_desc<16x16xf16, #prefetch_lo>
+    -> !xegpu.tensor_desc<16x16xf16, #inst_data_16x16>
+  %1 = xegpu.load_nd %0 [%c0, %c0] {layout = #inst_data_8x16} : !xegpu.tensor_desc<16x16xf16, #inst_data_16x16>
     -> vector<16x16xf16>
-  xegpu.prefetch_nd %0 [%c0, %c0] {layout = #prefetch_lo} : !xegpu.tensor_desc<16x16xf16, #prefetch_lo>
+  xegpu.prefetch_nd %0 [%c0, %c0] {layout = #inst_data_16x16} : !xegpu.tensor_desc<16x16xf16, #inst_data_16x16>
   return %1 : vector<16x16xf16>
 }
 
@@ -36,17 +41,16 @@ func.func @load_nd_with_conflicting_tensor_desc(%arg0: memref<64x64xf16>) -> vec
 // CHECK-SAME:        !xegpu.tensor_desc<32x16xf16, #xegpu.layout<inst_data = [32, 16]>> -> vector<32x16xf16>
 // CHECK-NEXT:      xegpu.prefetch_nd %[[T1]][%[[C0]], %[[C0]]] <{layout = #xegpu.layout<inst_data = [16, 16]>}> :
 // CHECK-SAME:        !xegpu.tensor_desc<32x16xf16, #xegpu.layout<inst_data = [16, 16]>>
-func.func @multiple_tensor_desc_conflicts(%arg0: memref<64x64xf16>) -> vector<32x16xf16> {
+func.func @multiple_tensor_desc_conflicts(%arg0: memref<64x64xf16>) -> (vector<32x16xf16>, vector<32x16xf16>) {
   %c0 = arith.constant 0 : index
   %tdesc1 = xegpu.create_nd_tdesc %arg0 : memref<64x64xf16>
-    -> !xegpu.tensor_desc<32x16xf16, #load_lo>
-  %load1 = xegpu.load_nd %tdesc1 [%c0, %c0] {layout = #load_lo} : !xegpu.tensor_desc<32x16xf16, #load_lo>
+    -> !xegpu.tensor_desc<32x16xf16, #inst_data_8x16>
+  %load1 = xegpu.load_nd %tdesc1 [%c0, %c0] {layout = #inst_data_8x16} : !xegpu.tensor_desc<32x16xf16, #inst_data_8x16>
     -> vector<32x16xf16>
-  %load2 = xegpu.load_nd %tdesc1 [%c0, %c0] {layout = #load_lo1} : !xegpu.tensor_desc<32x16xf16, #load_lo>
+  %load2 = xegpu.load_nd %tdesc1 [%c0, %c0] {layout = #inst_data_32x16} : !xegpu.tensor_desc<32x16xf16, #inst_data_8x16>
     -> vector<32x16xf16>
-  xegpu.prefetch_nd %tdesc1 [%c0, %c0] {layout = #prefetch_lo} : !xegpu.tensor_desc<32x16xf16, #load_lo>
-  %result = arith.addf %load1, %load2 : vector<32x16xf16>
-  return %result : vector<32x16xf16>
+  xegpu.prefetch_nd %tdesc1 [%c0, %c0] {layout = #inst_data_16x16} : !xegpu.tensor_desc<32x16xf16, #inst_data_8x16>
+  return %load1, %load2 : vector<32x16xf16>, vector<32x16xf16>
 }
 
 // CHECK-LABEL:   func.func @load_nd_with_conflicting_tensor_desc_in_loop
@@ -66,16 +70,145 @@ func.func @load_nd_with_conflicting_tensor_desc_in_loop(%arg0: memref<64x64xf16>
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
   %c4 = arith.constant 4 : index
-  %cst = arith.constant dense<0.0> : vector<16x16xf16>
+  %cst = arith.constant {layout_result_0 = #inst_data_8x16} dense<0.0> : vector<16x16xf16>
   %0 = xegpu.create_nd_tdesc %arg0 : memref<64x64xf16>
-    -> !xegpu.tensor_desc<16x16xf16, #prefetch_lo>
-  %1:2 = scf.for %i = %c0 to %c4 step %c1 iter_args(%acc = %cst, %tdesc = %0) -> (vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16, #prefetch_lo>) {
-    %2 = xegpu.load_nd %tdesc [%c0, %c0] {layout = #load_lo} : !xegpu.tensor_desc<16x16xf16, #prefetch_lo>
+    -> !xegpu.tensor_desc<16x16xf16, #inst_data_16x16>
+  %1:2 = scf.for %i = %c0 to %c4 step %c1 iter_args(%acc = %cst, %tdesc = %0)
+    -> (vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16, #inst_data_16x16>) {
+    %2 = xegpu.load_nd %tdesc [%c0, %c0] {layout = #inst_data_8x16} : !xegpu.tensor_desc<16x16xf16, #inst_data_16x16>
       -> vector<16x16xf16>
-    %3 = arith.addf %acc, %2 : vector<16x16xf16>
-    scf.yield %3, %tdesc : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16, #prefetch_lo>
-  }
-  xegpu.prefetch_nd %0 [%c0, %c0] {layout = #prefetch_lo} : !xegpu.tensor_desc<16x16xf16, #prefetch_lo>
+    %3 = arith.addf %acc, %2 {layout_result_0 = #inst_data_8x16} : vector<16x16xf16>
+    scf.yield %3, %tdesc : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16, #inst_data_16x16>
+  } {layout_result_0 = #inst_data_8x16}
+  xegpu.prefetch_nd %0 [%c0, %c0] {layout = #inst_data_16x16} : !xegpu.tensor_desc<16x16xf16, #inst_data_16x16>
   return %1#0 : vector<16x16xf16>
 }
+
+
+// CHECK-LABEL: func.func @elementwise_conflict
+// CHECK-DAG:     %[[V0:.*]] = "some_op"() {layout_result_0 = #xegpu.layout<inst_data = [8, 16]>} : () -> vector<32x32xf16>
+// CHECK-DAG:     %[[V1:.*]] = "some_op"() {layout_result_0 = #xegpu.layout<inst_data = [32, 16]>} : () -> vector<32x32xf16>
+// CHECK-DAG:     %[[CVT:.*]] = xegpu.convert_layout %[[V1]]
+// CHECK-SAME:      <{input_layout = #xegpu.layout<inst_data = [32, 16]>, target_layout = #xegpu.layout<inst_data = [8, 16]>}>
+// CHECK-SAME:      : vector<32x32xf16>
+// CHECK:         %[[ADD:.*]] = arith.addf %[[V0]], %[[CVT]]
+// CHECK-SAME:      {layout_result_0 = #xegpu.layout<inst_data = [8, 16]>} : vector<32x32xf16>
+// CHECK:         return %[[ADD]] : vector<32x32xf16>
+func.func @elementwise_conflict() -> vector<32x32xf16> {
+  %0 = "some_op"() {layout_result_0 = #inst_data_8x16} : () -> vector<32x32xf16>
+  %1 = "some_op"() {layout_result_0 = #inst_data_32x16} : () -> vector<32x32xf16>
+  %2 = arith.addf %0, %1 {layout_result_0 = #inst_data_8x16} : vector<32x32xf16>
+  return %2 : vector<32x32xf16>
+}
+
+// CHECK-LABEL: func.func @broadcast_source_conflict
+// CHECK:         %[[V0:.*]] = "some_op"() {layout_result_0 = #xegpu.layout<inst_data = [16]>} : () -> vector<16xf16>
+// CHECK:         %[[CVT:.*]] = xegpu.convert_layout %[[V0]]
+// CHECK-SAME:      <{input_layout = #xegpu.layout<inst_data = [16]>, target_layout = #xegpu.slice<#xegpu.layout<inst_data = [16, 16]>, dims = [0]>}>
+// CHECK-SAME:      : vector<16xf16>
+// CHECK:         %[[BC:.*]] = vector.broadcast %[[CVT]]
+// CHECK-SAME:      {layout_result_0 = #xegpu.layout<inst_data = [16, 16]>} : vector<16xf16> to vector<16x16xf16>
+// CHECK:         return %[[BC]] : vector<16x16xf16>
+func.func @broadcast_source_conflict() -> vector<16x16xf16> {
+  %0 = "some_op"() {layout_result_0 = #inst_data_16} : () -> vector<16xf16>
+  %1 = vector.broadcast %0 {layout_result_0 = #inst_data_16x16} : vector<16xf16> to vector<16x16xf16>
+  return %1 : vector<16x16xf16>
+}
+
+// CHECK-LABEL: func.func @shapecast_source_conflict
+// CHECK:         %[[V0:.*]] = "some_op"() {layout_result_0 = #xegpu.layout<inst_data = [1, 2, 16]>} : () -> vector<2x4x32xf16>
+// CHECK:         %[[CVT:.*]] = xegpu.convert_layout %[[V0]]
+// CHECK-SAME:      <{input_layout = #xegpu.layout<inst_data = [1, 2, 16]>, target_layout = #xegpu.layout<inst_data = [1, 1, 32]>}>
+// CHECK-SAME:      : vector<2x4x32xf16>
+// CHECK:         %[[SC:.*]] = vector.shape_cast %[[CVT]]
+// CHECK-SAME:      {layout_result_0 = #xegpu.layout<inst_data = [1, 32]>} : vector<2x4x32xf16> to vector<1x256xf16>
+// CHECK:         return %[[SC]] : vector<1x256xf16>
+func.func @shapecast_source_conflict() -> vector<1x256xf16> {
+  %0 = "some_op"() {layout_result_0 = #inst_data_1x2x16} : () -> vector<2x4x32xf16>
+  %1 = vector.shape_cast %0 {layout_result_0 = #inst_data_1x32}  : vector<2x4x32xf16> to vector<1x256xf16>
+  return %1 : vector<1x256xf16>
+}
+
+// CHECK-LABEL: func.func @bitcast_source_conflict
+// CHECK:         %[[V0:.*]] = "some_op"() {layout_result_0 = #xegpu.layout<inst_data = [1, 16]>} : () -> vector<32x16xf32>
+// CHECK:         %[[CVT:.*]] = xegpu.convert_layout %[[V0]]
+// CHECK-SAME:      <{input_layout = #xegpu.layout<inst_data = [1, 16]>, target_layout = #xegpu.layout<inst_data = [16, 16]>}>
+// CHECK-SAME:      : vector<32x16xf32>
+// CHECK:         %[[BC:.*]] = vector.bitcast %[[CVT]]
+// CHECK-SAME:      {layout_result_0 = #xegpu.layout<inst_data = [16, 32]>} : vector<32x16xf32> to vector<32x32xf16>
+// CHECK:         return %[[BC]] : vector<32x32xf16>
+func.func @bitcast_source_conflict() -> vector<32x32xf16> {
+  %0 = "some_op"() {layout_result_0 = #xegpu.layout<inst_data = [1, 16]>} : () -> vector<32x16xf32>
+  %1 = vector.bitcast %0 {layout_result_0 = #xegpu.layout<inst_data = [16, 32]>} : vector<32x16xf32> to vector<32x32xf16>
+  return %1 : vector<32x32xf16>
+}
+
+// CHECK-LABEL: func.func @multireduction_source_conflict
+// CHECK-DAG:     %[[V0:.*]] = "some_op"() {layout_result_0 = #xegpu.layout<inst_data = [32, 16]>} : () -> vector<32x32xf16>
+// CHECK-DAG:     %[[CVT0:.*]] = xegpu.convert_layout %[[V0]]
+// CHECK-SAME:      <{input_layout = #xegpu.layout<inst_data = [32, 16]>, target_layout = #xegpu.layout<inst_data = [16, 16]>}>
+// CHECK-SAME:      : vector<32x32xf16>
+// CHECK-DAG:     %[[CST:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [32]>}
+// CHECK-SAME:      dense<0.000000e+00> : vector<32xf16>
+// CHECK-DAG:     %[[CVT1:.*]] = xegpu.convert_layout %[[CST]]
+// CHECK-SAME:      <{input_layout = #xegpu.layout<inst_data = [32]>, target_layout = #xegpu.slice<#xegpu.layout<inst_data = [16, 16]>, dims = [0]>}>
+// CHECK-SAME:      : vector<32xf16>
+// CHECK:         %[[MR:.*]] = vector.multi_reduction <add>, %[[CVT0]], %[[CVT1]]
+// CHECK-SAME:      {layout_result_0 = #xegpu.slice<#xegpu.layout<inst_data = [16, 16]>, dims = [0]>}
+// CHECK-SAME:      [0] : vector<32x32xf16> to vector<32xf16>
+// CHECK:         return %[[MR]] : vector<32xf16>
+func.func @multireduction_source_conflict() -> vector<32xf16> {
+  %0 = "some_op"() {layout_result_0 = #inst_data_32x16} : () -> vector<32x32xf16>
+  %acc = arith.constant {layout_result_0 = #inst_data_32} dense<0.0> : vector<32xf16>
+  %1 = vector.multi_reduction <add>, %0, %acc
+    {layout_result_0 = #xegpu.slice<#inst_data_16x16, dims = [0]>}
+    [0] : vector<32x32xf16> to vector<32xf16>
+  return %1 : vector<32xf16>
+}
+
+// CHECK-LABEL: func.func @insert_strided_slice_source_conflict
+// CHECK-DAG:     %[[V0:.*]] = "some_op"() {layout_result_0 = #xegpu.layout<inst_data = [1, 16]>} : () -> vector<16x16xf16>
+// CHECK-DAG:     %[[CVT:.*]] = xegpu.convert_layout %[[V0]]
+// CHECK-SAME:      <{input_layout = #xegpu.layout<inst_data = [1, 16]>, target_layout = #xegpu.layout<inst_data = [16, 16]>}>
+// CHECK-SAME:      : vector<16x16xf16>
+// CHECK-DAG:     %[[CST:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [1, 16, 16]>}
+// CHECK-SAME:      dense<0.000000e+00> : vector<2x32x32xf16>
+// CHECK:         %[[ISS:.*]] = vector.insert_strided_slice %[[CVT]], %[[CST]]
+// CHECK-SAME:      {layout_result_0 = #xegpu.layout<inst_data = [1, 16, 16]>, offsets = [0, 0, 0], strides = [1, 1]}
+// CHECK-SAME:      : vector<16x16xf16> into vector<2x32x32xf16>
+// CHECK:         return %[[ISS]] : vector<2x32x32xf16>
+func.func @insert_strided_slice_source_conflict() -> vector<2x32x32xf16> {
+  %0 = "some_op"()  {layout_result_0 = #xegpu.layout<inst_data = [1, 16]>} : () -> vector<16x16xf16>
+  %1 = arith.constant  { layout_result_0 = #xegpu.layout<inst_data = [1, 16, 16]>}
+    dense<0.0> : vector<2x32x32xf16>
+  %2 = vector.insert_strided_slice %0, %1 {offsets = [0, 0, 0], strides = [1, 1],
+    layout_result_0 = #xegpu.layout<inst_data = [1, 16, 16]>} : vector<16x16xf16> into vector<2x32x32xf16>
+  return %2: vector<2x32x32xf16>
+}
+
+// CHECK-LABEL: func.func @conflict_inside_loop
+// CHECK-DAG:     %[[CST:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [8, 16]>}
+// CHECK-SAME:      dense<0.000000e+00> : vector<16x16xf16>
+// CHECK:         %[[FOR:.*]] = scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ACC:.*]] = %[[CST]]) -> (vector<16x16xf16>) {
+// CHECK:           %[[V1:.*]] = "some_op"() {layout_result_0 = #xegpu.layout<inst_data = [16, 16]>} : () -> vector<16x16xf16>
+// CHECK:           %[[CVT:.*]] = xegpu.convert_layout %[[V1]]
+// CHECK-SAME:        <{input_layout = #xegpu.layout<inst_data = [16, 16]>, target_layout = #xegpu.layout<inst_data = [8, 16]>}>
+// CHECK-SAME:        : vector<16x16xf16>
+// CHECK:           %[[ADD:.*]] = arith.addf %[[ACC]], %[[CVT]]
+// CHECK-SAME:        {layout_result_0 = #xegpu.layout<inst_data = [8, 16]>} : vector<16x16xf16>
+// CHECK:           scf.yield %[[ADD]] : vector<16x16xf16>
+// CHECK:         }
+// CHECK:         return %[[FOR]] : vector<16x16xf16>
+func.func @conflict_inside_loop() -> vector<16x16xf16> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+  %cst = arith.constant {layout_result_0 = #inst_data_8x16} dense<0.0> : vector<16x16xf16>
+  %0 = scf.for %i = %c0 to %c4 step %c1 iter_args(%acc = %cst) -> vector<16x16xf16> {
+    %1 = "some_op"() {layout_result_0 = #inst_data_16x16} : () -> vector<16x16xf16>
+    %2 = arith.addf %acc, %1 {layout_result_0 = #inst_data_8x16} : vector<16x16xf16>
+    scf.yield %2 : vector<16x16xf16>
+  } {layout_result_0 = #inst_data_8x16}
+  return %0 : vector<16x16xf16>
+}
 }
diff --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
index 33af2c5b33d89..dc30110a65133 100644
--- a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
+++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
@@ -424,9 +424,8 @@ struct TestXeGPUResolveLayoutConflicts
       default;
 
   void runOnOperation() override {
-    if (failed(xegpu::resolveLayoutConflicts(getOperation()))) {
+    if (failed(xegpu::resolveLayoutConflicts(getOperation())))
       signalPassFailure();
-    }
   }
 };
 

>From 26680437ee9c69a3f0e3451f469e859fdbbdff01 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Thu, 19 Feb 2026 23:32:07 +0000
Subject: [PATCH 3/4] add tests

---
 mlir/test/Dialect/XeGPU/resolve-layout-conflicts.mlir | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/mlir/test/Dialect/XeGPU/resolve-layout-conflicts.mlir b/mlir/test/Dialect/XeGPU/resolve-layout-conflicts.mlir
index 4b601f9f85b3a..d28698b497f25 100644
--- a/mlir/test/Dialect/XeGPU/resolve-layout-conflicts.mlir
+++ b/mlir/test/Dialect/XeGPU/resolve-layout-conflicts.mlir
@@ -1,5 +1,4 @@
-// RUN: mlir-opt --test-xegpu-resolve-layout-conflicts="layout-kind=inst" \
-// RUN: -allow-unregistered-dialect -split-input-file %s | FileCheck %s
+// RUN: mlir-opt --test-xegpu-resolve-layout-conflicts -allow-unregistered-dialect -split-input-file %s | FileCheck %s
 
 #inst_data_8x16 = #xegpu.layout<inst_data = [8, 16]>
 #inst_data_16x16 = #xegpu.layout<inst_data = [16, 16]>

>From 2b845bfdfbdf97c386f451db5e64c4614ca0ba24 Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Thu, 19 Feb 2026 23:46:45 +0000
Subject: [PATCH 4/4] add tests

---
 mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 56b3656fc044c..66083e7002b0f 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -6,8 +6,6 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include <utility>
-
 #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
 #include "mlir/Analysis/DataFlow/SparseAnalysis.h"
 #include "mlir/Analysis/DataFlow/Utils.h"
@@ -24,9 +22,7 @@
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/Operation.h"
-#include "mlir/IR/Region.h"
 #include "mlir/IR/Value.h"
 #include "mlir/IR/Visitors.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"



More information about the Mlir-commits mailing list