[Mlir-commits] [mlir] [MLIR][XeGPU] Recover temporary layout from Anchor Layout (PR #191947)

Jianhui Li llvmlistbot at llvm.org
Mon Apr 20 20:23:40 PDT 2026


https://github.com/Jianhui-Li updated https://github.com/llvm/llvm-project/pull/191947

>From 3a6c2fe41fa7953ca42e94e5663231b33052ce00 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Thu, 2 Apr 2026 22:42:50 +0000
Subject: [PATCH 01/21] initial implementation

---
 .../XeGPU/Transforms/XeGPULayoutImpl.h        |   5 +-
 .../XeGPU/Transforms/XeGPULayoutImpl.cpp      | 473 ++++++++++++++++--
 mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp   |  77 ++-
 3 files changed, 499 insertions(+), 56 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h b/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h
index 9cf9a8705209b..5f46eab7b74c7 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h
@@ -183,10 +183,13 @@ setupDpasLayout(LayoutKind layoutKind, VectorType aTy, VectorType bTy,
                 VectorType cdTy, DistributeLayoutAttr consumerLayout, int numSg,
                 const uArch::uArch *uArch);
 
+DistributeLayoutAttr
+inferSourceLayoutFromResult(OpOperand &operand, DistributeLayoutAttr resLayout);
+
 /// Gets the expected layout for a given consumer operand. This will check if
 /// the owning operation of the consumer operand is one of the special layout
 /// users and determine the expected layout accordingly.
-xegpu::DistributeLayoutAttr getConsumerLayoutAt(OpOperand &operand);
+DistributeLayoutAttr getConsumerLayoutAt(OpOperand &operand);
 
 } // namespace xegpu
 
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
index 55cd6ec04970c..06cd0eaa0059e 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
@@ -18,16 +18,22 @@
 #include "mlir/Dialect/LLVMIR/XeVMDialect.h"
 #include "mlir/Dialect/SCF/Transforms/Patterns.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/ValueRange.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/LoopLikeInterface.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "llvm/Support/Debug.h"
 #include "llvm/Support/FormatVariadic.h"
 #include <cstdint>
 #include <numeric>
 
+#define DEBUG_TYPE "xegpu-layout-recovery"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+
 using namespace mlir;
 
 void xegpu::recoverTemporaryLayoutsDeprecated(Operation *op) {
@@ -80,32 +86,330 @@ xegpu::dropInstDataOnAttrs(ArrayRef<NamedAttribute> attrs) {
   return out;
 }
 
-// Attach layout attributes to all vector-type operands of operations within
-// the given operation's region. Reports an error if any vector operand lacks
-// a layout attribute.
-bool xegpu::recoverTemporaryLayouts(Operation *rootOp) {
-  auto result = rootOp->walk([&](Operation *op) {
-    for (OpOperand &operand : op->getOpOperands()) {
-      // Layouts are needed for vector type only.
-      if (!isa<VectorType>(operand.get().getType()))
-        continue;
-      // Skip block arguments since they don't have defining ops to attach
-      // layout attributes to.
-      if (isa<BlockArgument>(operand.get()))
+// Prerequisite for Layout Recovery
+// It relies on the following invariants:
+// 1. there is no layout conflict between different uses of the same definition.
+// 2. each definition has a well-defined layout requirement at its use point.
+//     - Every definition must have at least one use that appears after it in
+//     topological order.
+//     - If a definition has no such use (e.g., a loop result or region output),
+//     an explicit convert_layout operation is inserted to create a use.
+//     - Only the result of convert_layout is permitted to have no subsequent
+//     use.
+
+// The recovery proceeds by scanning the operation in reverse topological order
+// as follows:
+//    For regular operations: First the result layouts are propagated from uses.
+//      Then the result layouts are propagated to operands.
+//
+//    For region operations (e.g., loops):
+//       - When backward propagation reaches a region op, it sets the layout of
+//       the region op’s results according to use points like regular ops.
+//       - Then, the result layouts (such as a loop output) are propagated to
+//       their corresponding operands in the yield.
+//       - When backward propagation reaches the first operation inside the
+//       region, the pass examines the region op’s initialization list,
+//       propagating from region arguments to the corresponding initialization
+//       operands.
+//       - This ensures that layouts are consistently propagated
+//       across region boundaries while preserving a single well-defined use for
+//       each definition at the region-op level.
+
+// The inner function for recoverTemporaryLayouts is a recursive function.
+// The input rootOp is the function operation, which is also a region op.
+// It recursively processes the region op in reverse topological order.
+
+static void walkRegionBackward(Region &region,
+                               llvm::function_ref<void(Operation *)> visit) {
+  // blocks: back -> front
+  for (Block &block : llvm::reverse(region)) {
+    // ops: back -> front; plain reverse range, so visit() must not erase ops
+    for (Operation &op : llvm::reverse(block)) {
+      // make sure we first visit inside the region op (so yield op first)
+      // and then move to region op itself
+      for (Region &nested : llvm::reverse(op.getRegions()))
+        walkRegionBackward(nested, visit);
+
+      visit(&op);
+    }
+  }
+}
+
+static xegpu::DistributeLayoutAttr getLayoutFromUsePoints(Value result) {
+  xegpu::DistributeLayoutAttr layout = nullptr;
+  for (OpOperand &use : result.getUses()) {
+    if (auto tmpLayout = xegpu::getDistributeLayoutAttr(use)) {
+      // debug print the use and op, and the tmpLayout
+      LLVM_DEBUG({
+        DBGS() << "      use: " << use.getOwner()->getName() << use.getOwner();
+        llvm::dbgs() << ", tmpLayout=" << tmpLayout << "\n";
+      });
+      // under debug mode, we want to check all the use points to make sure
+      // there is no conflict, so we do not break here. In release mode, we can
+      // break at the first use
+#ifndef NDEBUG
+      assert(!layout || layout == tmpLayout);
+      layout = tmpLayout;
+#else
+      layout = tmpLayout;
+      break;
+#endif
+    }
+  }
+  return layout;
+}
+
+// For regular operations: First the result layouts are propagated from uses.
+// Then the result layouts are propagated to the op's operands.
+static void propagateResultsToRegularOperands(Operation *op) {
+  LLVM_DEBUG(DBGS() << "propagateResultsToRegularOperands: " << op->getName()
+                    << " (" << op->getNumOperands() << " operands, "
+                    << op->getNumResults() << " results)\n");
+
+  if (op->getNumResults() == 0) {
+    LLVM_DEBUG(DBGS() << "  skipping (no results)\n");
+    return;
+  }
+
+  Value result = op->getResult(0);
+  xegpu::DistributeLayoutAttr resLayout =
+      getLayoutFromUsePoints(op->getResult(0));
+  Type resultType = result.getType();
+
+  // recover layout for tensor Descriptor type, which is a special case since
+  // its layout is not stored as an attribute but encoded in the type itself.
+  // For vector type, we attach the layout as an attribute to op.
+  if (auto tensorDescTy = dyn_cast<xegpu::TensorDescType>(resultType)) {
+    auto typeWithLayout = xegpu::TensorDescType::get(
+        tensorDescTy.getContext(), tensorDescTy.getShape(),
+        tensorDescTy.getElementType(), tensorDescTy.getEncoding(), resLayout);
+    result.setType(typeWithLayout);
+  }
+
+  for (OpOperand &opr : op->getOpOperands()) {
+    // Layouts are needed for vector type only.
+    xegpu::DistributeLayoutAttr operandLayout =
+        xegpu::inferSourceLayoutFromResult(opr, resLayout);
+    if (!isa<VectorType>(opr.get().getType())) {
+      LLVM_DEBUG(DBGS() << "  operand #" << opr.getOperandNumber()
+                        << ": skipped (non-vector type: " << opr.get().getType()
+                        << ")\n");
+      continue;
+    }
+
+    xegpu::setTemporaryLayout(opr, operandLayout);
+    // debug print op
+    LLVM_DEBUG(DBGS() << "after propagateResultsToRegularOperands  op: "
+                      << op->getName() << op << "  operand #"
+                      << opr.getOperandNumber()
+                      << ": type=" << opr.get().getType());
+    llvm::dbgs() << ", temp Layout=" << xegpu::getTemporaryLayout(opr);
+    llvm::dbgs() << "\n";
+  }
+}
+
+static void propagateRegionResultsToYieldOperands(
+    mlir::RegionBranchTerminatorOpInterface yieldOp) {
+  LLVM_DEBUG(DBGS() << "propagateRegionResultsToYieldOperands: "
+                    << yieldOp->getName() << " (" << yieldOp->getNumOperands()
+                    << " operands), parent="
+                    << yieldOp->getParentOp()->getName() << "\n");
+
+  if (func::FuncOp func = dyn_cast<func::FuncOp>(yieldOp->getParentOp())) {
+    LLVM_DEBUG(DBGS() << "  skipping (parent is FuncOp)\n");
+    return;
+  }
+  llvm::SmallVector<mlir::RegionSuccessor> successors;
+  llvm::SmallVector<mlir::Attribute> operands(yieldOp->getNumOperands(),
+                                              nullptr);
+  yieldOp.getSuccessorRegions(operands, successors);
+
+  auto regionBranchOp = cast<RegionBranchOpInterface>(yieldOp->getParentOp());
+
+  LLVM_DEBUG(DBGS() << "  found " << successors.size() << " successors\n");
+  for (mlir::RegionSuccessor &successor : successors) {
+    // debug print out successorr
+    LLVM_DEBUG({
+      DBGS() << "  successor: ";
+      if (successor.isParent()) {
+        DBGS() << "(parent operation)";
+      } else {
+        DBGS() << "region with " << successor.getSuccessor()->getNumArguments()
+               << " arguments";
+      }
+      DBGS() << "\n";
+    });
+    // find out the successor which is the parent region of yieldOp
+    // if (successor.getSuccessor() != yieldOp->getParentRegion()) {
+    //   LLVM_DEBUG(DBGS() << "  skipping successor (not parent region)\n");
+    //   continue;
+    // }
+    if (!successor.isParent())
+      continue;
+    // propagate the layout from region result to yield operands
+    ValueRange successorInputs = regionBranchOp.getSuccessorInputs(successor);
+    LLVM_DEBUG(DBGS() << "  propagating " << successorInputs.size()
+                      << " region results to yield operands\n");
+    for (unsigned i = 0; i < successorInputs.size(); ++i) {
+      Value regionResult = successorInputs[i];
+
+      // debug print regionResult
+      LLVM_DEBUG({
+        DBGS() << " before propagateRegionResultsToYieldOperands, Region IR:";
+        DBGS() << "    region result #" << i
+               << ": type=" << regionResult.getType();
+        llvm::dbgs() << regionResult;
+        llvm::dbgs() << "\n";
+      });
+      // find all the use of region result, and propagate the layout to the
+      // corresponding yield operand for all use of region result, get its
+      // layout from temporary operand layout if any of these use have it
+      xegpu::DistributeLayoutAttr layout = getLayoutFromUsePoints(regionResult);
+
+      // auto layout = xegpu::getDistributeLayoutAttr(regionResult);
+      if (layout == nullptr) {
+        LLVM_DEBUG(DBGS() << "    region result #" << i
+                          << ": skipped (no layout)\n");
         continue;
-      auto layout = xegpu::getDistributeLayoutAttr(operand.get());
-      if (!layout) {
-        op->emitWarning("Could not find layout attribute for operand ")
-            << operand.getOperandNumber() << " of operation " << op->getName();
+      }
+      assert(
+          layout &&
+          "region result layout must be defined before propagating to yield");
+
+      if (auto opResult = dyn_cast<OpResult>(regionResult))
+        xegpu::setTemporaryLayout(opResult, layout);
+      xegpu::setTemporaryLayout(yieldOp->getOpOperand(i), layout);
+
+      LLVM_DEBUG({
+        DBGS() << " after propagateRegionResultsToYieldOperands, Region IR:";
+        regionResult.print(llvm::dbgs());
+        if (Operation *defOp = regionResult.getDefiningOp())
+          defOp->print(llvm::dbgs(), OpPrintingFlags().printGenericOpForm());
+        llvm::dbgs() << "\n";
+      });
+    }
+  }
+}
+
+static void propagateRegionArgsToInits(mlir::RegionBranchOpInterface regionOp) {
+  LLVM_DEBUG(DBGS() << "propagateRegionArgsToInits: " << regionOp->getName()
+                    << " (" << regionOp->getNumOperands() << " operands, "
+                    << regionOp->getNumRegions() << " regions)\n");
+  DBGS() << " before propagateRegionArgsToInits, Region IR:";
+  regionOp.print(llvm::dbgs());
+  DBGS() << " complex debug Region IR:";
+  regionOp.print(llvm::dbgs(), OpPrintingFlags().printGenericOpForm());
+  // Get entry successors (regions that can be entered initially)
+  SmallVector<RegionSuccessor> successors;
+  regionOp.getEntrySuccessorRegions(/*operands=*/ArrayRef<Attribute>(),
+                                    successors);
+
+  LLVM_DEBUG(DBGS() << "  found " << successors.size()
+                    << " entry successors\n");
+  // For each possible entry region, get the operands forwarded to it
+  for (RegionSuccessor &successor : successors) {
+    OperandRange initOperands = regionOp.getEntrySuccessorOperands(successor);
+    unsigned beginIdx = initOperands.getBeginOperandIndex();
+    unsigned numArgs = successor.getSuccessor()->getNumArguments();
+    LLVM_DEBUG(DBGS() << "  successor region: " << numArgs
+                      << " args, initOperands beginIdx=" << beginIdx
+                      << ", count=" << initOperands.size() << "\n");
+    // initOperands are the initialization arguments for this successor
+    // iterate the region arguments
+    for (unsigned i = 0; i < numArgs; ++i) {
+      Value regionArg =
+          successor.getSuccessor()->getArgument(i); // region argument
+      auto layout = xegpu::getDistributeLayoutAttr(regionArg);
+      if (layout == nullptr) {
+        LLVM_DEBUG(DBGS() << "    region argument #" << i
+                          << ": skipped (no layout)\n");
         continue;
       }
-      xegpu::setTemporaryLayout(operand, layout);
+      assert(
+          layout &&
+          "region argument layout must be defined before propagating to init");
+      LLVM_DEBUG(DBGS() << "    regionArg #" << i << ": type="
+                        << regionArg.getType() << ", layout=" << layout
+                        << " -> init operand #" << (beginIdx + i) << "\n");
+      xegpu::setTemporaryLayout(regionOp->getOpOperand(beginIdx + i), layout);
     }
-    return WalkResult::advance();
+  }
+  DBGS() << " after propagateRegionArgsToInits, Region IR:";
+  regionOp.print(llvm::dbgs());
+}
+
+bool xegpu::recoverTemporaryLayouts(Operation *rootOp) {
+  LLVM_DEBUG(DBGS() << "=== recoverTemporaryLayouts START ===\n");
+
+  auto processFunc = [&](Region &body, StringRef funcName) {
+    LLVM_DEBUG(DBGS() << "Processing func: " << funcName << "\n");
+    walkRegionBackward(body, [&](Operation *op) {
+      LLVM_DEBUG(DBGS() << "Visiting op: " << op->getName());
+      if (op->getNumResults() > 0) {
+        LLVM_DEBUG(llvm::dbgs() << " [results: " << op->getNumResults());
+        for (OpResult res : op->getResults()) {
+          auto layout = xegpu::getDistributeLayoutAttr(res);
+          LLVM_DEBUG(llvm::dbgs() << " r#" << res.getResultNumber() << "="
+                                  << (layout ? layout : nullptr));
+        }
+        LLVM_DEBUG(llvm::dbgs() << "]");
+      }
+      LLVM_DEBUG(llvm::dbgs() << "\n");
+      if (auto regionOp = dyn_cast<mlir::RegionBranchOpInterface>(op)) {
+        // hit the region op after visiting inside region
+        LLVM_DEBUG(DBGS() << "  -> dispatching as RegionBranchOp\n");
+        propagateRegionArgsToInits(regionOp);
+      } else if (auto yieldOp =
+                     dyn_cast<mlir::RegionBranchTerminatorOpInterface>(op)) {
+        // yield op inside region op
+        LLVM_DEBUG(DBGS() << "  -> dispatching as YieldOp\n");
+        propagateRegionResultsToYieldOperands(yieldOp);
+      } else if (!dyn_cast<xegpu::AnchorLayoutInterface>(op)) {
+        // if the op is regular op, calling propagateResultsToRegularOperands
+        LLVM_DEBUG(DBGS() << "  -> dispatching as regular op\n");
+        propagateResultsToRegularOperands(op);
+      }
+    });
+  };
+
+  rootOp->walk([&](func::FuncOp func) {
+    processFunc(func.getBody(), func.getSymName());
   });
-  return !result.wasInterrupted();
+  rootOp->walk([&](gpu::GPUFuncOp func) {
+    processFunc(func.getBody(), func.getName());
+  });
+
+  LLVM_DEBUG(DBGS() << "=== recoverTemporaryLayouts END ===\n");
+  return true;
 }
 
+// // Attach layout attributes to all vector-type operands of operations within
+// // the given operation's region. Reports an error if any vector operand lacks
+// // a layout attribute.
+// bool xegpu::recoverTemporaryLayouts(Operation *rootOp) {
+//   auto result = rootOp->walk([&](Operation *op) {
+//     for (OpOperand &operand : op->getOpOperands()) {
+//       // Layouts are needed for vector type only.
+//       if (!isa<VectorType>(operand.get().getType()))
+//         continue;
+//       // Skip block arguments since they don't have defining ops to attach
+//       // layout attributes to.
+//       if (isa<BlockArgument>(operand.get()))
+//         continue;
+//       auto layout = xegpu::getDistributeLayoutAttr(operand.get());
+//       if (!layout) {
+//         op->emitWarning("Could not find layout attribute for operand ")
+//             << operand.getOperandNumber() << " of operation " <<
+//             op->getName();
+//         xegpu::setTemporaryLayout(operand, layout);
+//         continue;
+//       }
+//     }
+//     return WalkResult::advance();
+//   });
+//   return !result.wasInterrupted();
+// }
+
 template <typename T, typename>
 void xegpu::removeLayoutAttr(const T &operandOrResult) {
   Operation *owner = operandOrResult.getOwner();
@@ -1108,99 +1412,178 @@ xegpu::setupDpasLayout(xegpu::LayoutKind layoutKind, VectorType aTy,
   return std::nullopt;
 }
 
-xegpu::DistributeLayoutAttr xegpu::getConsumerLayoutAt(OpOperand &operand) {
+xegpu::DistributeLayoutAttr
+xegpu::inferSourceLayoutFromResult(OpOperand &operand,
+                                   xegpu::DistributeLayoutAttr resLayout) {
   Operation *op = operand.getOwner();
   unsigned idx = operand.getOperandNumber();
-  xegpu::DistributeLayoutAttr resLayout;
-  if (op->getNumResults() == 1)
-    resLayout = xegpu::getDistributeLayoutAttr(op->getResult(0));
 
   // For vector::BroadcastOp, infer the source layout from the result layout.
   if (auto broadcast = dyn_cast<vector::BroadcastOp>(op)) {
-    if (!resLayout)
+    LLVM_DEBUG(DBGS() << "  -> BroadcastOp\n");
+    if (!resLayout) {
+      LLVM_DEBUG(DBGS() << "     no resLayout, returning null\n");
       return xegpu::DistributeLayoutAttr();
+    }
     auto srcTy = dyn_cast<VectorType>(broadcast.getSourceType());
-    if (!srcTy)
+    if (!srcTy) {
+      LLVM_DEBUG(DBGS() << "     source is not VectorType, returning null\n");
       return xegpu::DistributeLayoutAttr();
-    return xegpu::inferBroadcastSourceLayout(
+    }
+    auto inferred = xegpu::inferBroadcastSourceLayout(
         resLayout, broadcast.getResultVectorType().getShape(),
         srcTy.getShape());
+    LLVM_DEBUG(DBGS() << "     inferred=" << inferred << "\n");
+    return inferred;
   }
 
   // For vector::MultiDimReductionOp, infer source layout from result layout
   // using reduction dims. Acc operand is expected to have the same layout as
   // the result.
   if (auto reduction = dyn_cast<vector::MultiDimReductionOp>(op)) {
-    if (!resLayout)
+    LLVM_DEBUG(DBGS() << "  -> MultiDimReductionOp, operand idx=" << idx
+                      << "\n");
+    if (!resLayout) {
+      LLVM_DEBUG(DBGS() << "     no resLayout, returning null\n");
       return xegpu::DistributeLayoutAttr();
+    }
     if (idx == 0) {
       SmallVector<int64_t> reductionDims(reduction.getReductionDims());
-      return xegpu::inferMultiReductionSourceLayout(resLayout, reductionDims);
+      LLVM_DEBUG({
+        DBGS() << "     reductionDims=[";
+        llvm::interleaveComma(reductionDims, llvm::dbgs());
+        llvm::dbgs() << "]\n";
+      });
+      auto inferred =
+          xegpu::inferMultiReductionSourceLayout(resLayout, reductionDims);
+      LLVM_DEBUG(DBGS() << "     inferred source layout=" << inferred << "\n");
+      return inferred;
     }
-    if (idx == 1)
+    if (idx == 1) {
+      LLVM_DEBUG(DBGS() << "     acc operand, using resLayout\n");
       return resLayout;
+    }
   }
 
   if (auto reduction = dyn_cast<vector::ReductionOp>(op)) {
-    if (!resLayout)
+    LLVM_DEBUG(DBGS() << "  -> ReductionOp\n");
+    if (!resLayout) {
+      LLVM_DEBUG(DBGS() << "     no resLayout, returning null\n");
       return xegpu::DistributeLayoutAttr();
-    return xegpu::inferReductionSourceLayout(resLayout);
+    }
+    auto inferred = xegpu::inferReductionSourceLayout(resLayout);
+    LLVM_DEBUG(DBGS() << "     inferred=" << inferred << "\n");
+    return inferred;
   }
 
   // For vector::BitCastOp, infer source layout from result layout using
   // element type bitwidths.
   if (auto bitcast = dyn_cast<vector::BitCastOp>(op)) {
-    if (!resLayout)
+    LLVM_DEBUG(DBGS() << "  -> BitCastOp\n");
+    if (!resLayout) {
+      LLVM_DEBUG(DBGS() << "     no resLayout, returning null\n");
       return xegpu::DistributeLayoutAttr();
+    }
     int resElemBitWidth =
         bitcast.getResultVectorType().getElementType().getIntOrFloatBitWidth();
     int srcElemBitWidth =
         bitcast.getSourceVectorType().getElementType().getIntOrFloatBitWidth();
-    return xegpu::inferBitCastSourceLayout(resLayout, resElemBitWidth,
-                                           srcElemBitWidth);
+    LLVM_DEBUG(DBGS() << "     resBitWidth=" << resElemBitWidth
+                      << ", srcBitWidth=" << srcElemBitWidth << "\n");
+    auto inferred = xegpu::inferBitCastSourceLayout(resLayout, resElemBitWidth,
+                                                    srcElemBitWidth);
+    LLVM_DEBUG(DBGS() << "     inferred=" << inferred << "\n");
+    return inferred;
   }
 
   // For vector::ShapeCastOp, infer source layout from result layout using
   // shapes.
   if (auto shapeCast = dyn_cast<vector::ShapeCastOp>(op)) {
-    if (!resLayout)
+    LLVM_DEBUG({
+      DBGS() << "  -> ShapeCastOp: resShape=[";
+      llvm::interleaveComma(shapeCast.getResultVectorType().getShape(),
+                            llvm::dbgs());
+      llvm::dbgs() << "], srcShape=[";
+      llvm::interleaveComma(shapeCast.getSourceVectorType().getShape(),
+                            llvm::dbgs());
+      llvm::dbgs() << "]\n";
+    });
+    if (!resLayout) {
+      LLVM_DEBUG(DBGS() << "     no resLayout, returning null\n");
       return xegpu::DistributeLayoutAttr();
-    return xegpu::inferShapeCastSourceLayout(
+    }
+    auto inferred = xegpu::inferShapeCastSourceLayout(
         resLayout, shapeCast.getResultVectorType().getShape(),
         shapeCast.getSourceVectorType().getShape());
+    LLVM_DEBUG(DBGS() << "     inferred=" << inferred << "\n");
+    return inferred;
   }
 
   // For vector::InsertStridedSliceOp, infer source layout from result layout.
   // Dest vector must have the same layout as the result.
   if (auto insertSlice = dyn_cast<vector::InsertStridedSliceOp>(op)) {
-    if (!resLayout)
+    LLVM_DEBUG(DBGS() << "  -> InsertStridedSliceOp, operand idx=" << idx
+                      << "\n");
+    if (!resLayout) {
+      LLVM_DEBUG(DBGS() << "     no resLayout, returning null\n");
       return xegpu::DistributeLayoutAttr();
-    if (idx == 0)
-      return xegpu::inferInsertStridedSliceSourceLayout(
+    }
+    if (idx == 0) {
+      auto inferred = xegpu::inferInsertStridedSliceSourceLayout(
           resLayout, insertSlice.getDestVectorType().getShape(),
           insertSlice.getSourceVectorType().getShape());
-    if (idx == 1)
+      LLVM_DEBUG(DBGS() << "     inferred source layout=" << inferred << "\n");
+      return inferred;
+    }
+    if (idx == 1) {
+      LLVM_DEBUG(DBGS() << "     dest operand, using resLayout\n");
       return resLayout;
+    }
   }
 
   // For vector::TransposeOp, infer source layout from result layout using
   // permutation.
   if (auto transpose = dyn_cast<vector::TransposeOp>(op)) {
-    if (!resLayout)
+    LLVM_DEBUG({
+      DBGS() << "  -> TransposeOp, perm=[";
+      llvm::interleaveComma(transpose.getPermutation(), llvm::dbgs());
+      llvm::dbgs() << "]\n";
+    });
+    if (!resLayout) {
+      LLVM_DEBUG(DBGS() << "     no resLayout, returning null\n");
       return xegpu::DistributeLayoutAttr();
-    return xegpu::inferTransposeSourceLayout(resLayout,
-                                             transpose.getPermutation());
+    }
+    auto inferred = xegpu::inferTransposeSourceLayout(
+        resLayout, transpose.getPermutation());
+    LLVM_DEBUG(DBGS() << "     inferred=" << inferred << "\n");
+    return inferred;
   }
 
   // For elementwise operations, all operands must have the same layout as the
   // result.
   if (OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() == 1) {
+    LLVM_DEBUG(DBGS() << "  -> elementwise op, using resLayout="
+                      << (resLayout ? resLayout : nullptr) << "\n");
     if (!resLayout)
       return xegpu::DistributeLayoutAttr();
     return resLayout;
   }
-  // TODO: Handle more cases as needed here.
+  return xegpu::DistributeLayoutAttr();
+}
+
+xegpu::DistributeLayoutAttr xegpu::getConsumerLayoutAt(OpOperand &operand) {
+  Operation *op = operand.getOwner();
+  xegpu::DistributeLayoutAttr resLayout;
+  if (op->getNumResults() == 1)
+    resLayout = xegpu::getDistributeLayoutAttr(op->getResult(0));
+  auto inferredOperandLayout = inferSourceLayoutFromResult(operand, resLayout);
+  if (inferredOperandLayout)
+    return inferredOperandLayout;
   // By default, assume no layout conflict and return the current layout of
   // the operand.
-  return xegpu::getDistributeLayoutAttr(operand.get());
+  auto fallback = xegpu::getDistributeLayoutAttr(operand);
+  LLVM_DEBUG(DBGS() << "  -> fallback (unhandled op " << op->getName()
+                    << "), returning operand layout="
+                    << (fallback ? fallback : nullptr) << "\n");
+  return fallback;
 }
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index 243581b4ce522..a762458105e47 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -23,10 +23,14 @@
 #include "mlir/Interfaces/LoopLikeInterface.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "llvm/Support/Casting.h"
+#include "llvm/Support/Debug.h"
 #include "llvm/Support/FormatVariadic.h"
 #include <cstdint>
 #include <numeric>
 
+#define DEBUG_TYPE "xegpu-utils"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+
 using namespace mlir;
 
 /// convert ArrayRef<ValueRange> into SmallVector<Value>
@@ -145,19 +149,31 @@ std::string xegpu::getTemporaryLayoutName(const OpResult result) {
 }
 
 xegpu::DistributeLayoutAttr xegpu::getDistributeLayoutAttr(const Value value) {
-  if (!value)
+  LLVM_DEBUG(DBGS() << "getDistributeLayoutAttr(Value): type="
+                    << value.getType() << "\n");
+  if (!value) {
+    LLVM_DEBUG(DBGS() << "  -> null value, returning nullptr\n");
     return nullptr;
+  }
 
   if (auto tdescTy =
-          dyn_cast_if_present<xegpu::TensorDescType>(value.getType()))
-    return tdescTy.getLayoutAttr();
+          dyn_cast_if_present<xegpu::TensorDescType>(value.getType())) {
+    auto layout = tdescTy.getLayoutAttr();
+    LLVM_DEBUG(DBGS() << "  -> TensorDescType, layout="
+                      << (layout ? layout : nullptr) << "\n");
+    return layout;
+  }
 
   if (auto result = dyn_cast<OpResult>(value)) {
     Operation *defOp = result.getDefiningOp();
     assert(defOp && "result must have a defining op");
+    LLVM_DEBUG(DBGS() << "  OpResult #" << result.getResultNumber() << " from "
+                      << defOp->getName() << "\n");
 
     if (auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(defOp)) {
       auto layout = anchorOp.getAnchorLayout();
+      LLVM_DEBUG(DBGS() << "  -> AnchorLayoutInterface, layout="
+                        << (layout ? layout : nullptr) << "\n");
       return layout;
     }
 
@@ -165,59 +181,100 @@ xegpu::DistributeLayoutAttr xegpu::getDistributeLayoutAttr(const Value value) {
     if (defOp->hasAttr(layoutName)) {
       auto layout =
           defOp->getAttrOfType<xegpu::DistributeLayoutAttr>(layoutName);
+      LLVM_DEBUG(DBGS() << "  -> temporary attr '" << layoutName
+                        << "', layout=" << layout << "\n");
       return layout;
     }
+    LLVM_DEBUG(DBGS() << "  -> OpResult: no layout found (checked '"
+                      << layoutName << "')\n");
   }
 
   if (auto arg = dyn_cast<BlockArgument>(value)) {
     auto *parentOp = arg.getOwner()->getParentOp();
+    LLVM_DEBUG(DBGS() << "  BlockArgument #" << arg.getArgNumber() << " of "
+                      << (parentOp ? parentOp->getName().getStringRef()
+                                   : StringRef("(null)"))
+                      << "\n");
     if (auto loop = dyn_cast_if_present<LoopLikeOpInterface>(parentOp)) {
       OpOperand *tiedInit = loop.getTiedLoopInit(arg);
-      if (tiedInit)
+      if (tiedInit) {
+        LLVM_DEBUG(DBGS() << "  -> LoopLikeOp, recursing into tiedInit "
+                          << "operand #" << tiedInit->getOperandNumber()
+                          << "\n");
         return getDistributeLayoutAttr(tiedInit->get());
+      }
+      LLVM_DEBUG(DBGS() << "  -> LoopLikeOp, no tiedInit\n");
     }
   }
 
+  LLVM_DEBUG(DBGS() << "  -> returning nullptr\n");
   return nullptr;
 }
 xegpu::DistributeLayoutAttr
 xegpu::getDistributeLayoutAttr(const OpOperand &opr) {
   Operation *op = opr.getOwner();
   unsigned idx = const_cast<OpOperand &>(opr).getOperandNumber();
+  LLVM_DEBUG(DBGS() << "getDistributeLayoutAttr(OpOperand): operand #" << idx
+                    << " of " << op->getName()
+                    << ", type=" << opr.get().getType() << "\n");
 
   if (auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(op)) {
     if (auto dpasOp = dyn_cast<xegpu::DpasOp>(op)) {
       if (idx == 0) {
-        return dpasOp.getLayoutAAttr();
+        auto layout = dpasOp.getLayoutAAttr();
+        LLVM_DEBUG(DBGS() << "  -> DpasOp layoutA="
+                          << (layout ? layout : nullptr) << "\n");
+        return layout;
       } else if (idx == 1) {
-        return dpasOp.getLayoutBAttr();
+        auto layout = dpasOp.getLayoutBAttr();
+        LLVM_DEBUG(DBGS() << "  -> DpasOp layoutB="
+                          << (layout ? layout : nullptr) << "\n");
+        return layout;
       } else if (idx == 2) {
-        return dpasOp.getLayoutCdAttr();
+        auto layout = dpasOp.getLayoutCdAttr();
+        LLVM_DEBUG(DBGS() << "  -> DpasOp layoutCd="
+                          << (layout ? layout : nullptr) << "\n");
+        return layout;
       }
     }
     if (auto convertOp = dyn_cast<xegpu::ConvertLayoutOp>(op)) {
-      return convertOp.getInputLayoutAttr();
+      auto layout = convertOp.getInputLayoutAttr();
+      LLVM_DEBUG(DBGS() << "  -> ConvertLayoutOp inputLayout="
+                        << (layout ? layout : nullptr) << "\n");
+      return layout;
     }
     auto layout = anchorOp.getAnchorLayout();
 
-    if (idx == 0)
+    if (idx == 0) {
+      LLVM_DEBUG(DBGS() << "  -> AnchorLayoutInterface idx=0, layout="
+                        << (layout ? layout : nullptr) << "\n");
       return layout;
+    }
 
     // For store operations (StoreScatterOp, StoreNdOp, StoreMatrixOp),
     // the layout is valid for the first two operands: value and memref/tdesc.
     // For other operations, the layout applies to the first operand only.
     if (isa<xegpu::StoreScatterOp, xegpu::StoreNdOp, xegpu::StoreMatrixOp>(
             op) &&
-        (idx < 2))
+        (idx < 2)) {
+      LLVM_DEBUG(DBGS() << "  -> Store op idx=" << idx
+                        << ", layout=" << (layout ? layout : nullptr) << "\n");
       return layout;
+    }
+    LLVM_DEBUG(DBGS() << "  -> AnchorLayoutInterface idx=" << idx
+                      << " not covered, falling through\n");
   }
 
   std::string layoutName = xegpu::getTemporaryLayoutName(opr);
   if (op->hasAttr(layoutName)) {
     auto layout = op->getAttrOfType<xegpu::DistributeLayoutAttr>(layoutName);
+    LLVM_DEBUG(DBGS() << "  -> temporary attr '" << layoutName
+                      << "', layout=" << layout << "\n");
     return layout;
   }
 
+  LLVM_DEBUG(DBGS() << "  -> returning nullptr (checked '" << layoutName
+                    << "')\n");
   return nullptr;
 }
 

>From f77e110d9dc81257b2deeb9cccb20e10bea3739b Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Fri, 3 Apr 2026 05:31:42 +0000
Subject: [PATCH 02/21] pass while

---
 .../XeGPU/Transforms/XeGPULayoutImpl.cpp      | 253 +++++++-----------
 .../XeGPU/Transforms/XeGPUPropagateLayout.cpp |   2 +-
 2 files changed, 103 insertions(+), 152 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
index 06cd0eaa0059e..47148870eeaae 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
@@ -147,13 +147,8 @@ static xegpu::DistributeLayoutAttr getLayoutFromUsePoints(Value result) {
       // under debug mode, we want to check all the use points to make sure
       // there is no conflict, so we do not break here. In release mode, we can
       // break at the first use
-#ifndef NDEBUG
-      assert(!layout || layout == tmpLayout);
-      layout = tmpLayout;
-#else
-      layout = tmpLayout;
-      break;
-#endif
+      if (!layout)
+        layout = tmpLayout;
     }
   }
   return layout;
@@ -215,127 +210,118 @@ static void propagateRegionResultsToYieldOperands(
                     << " operands), parent="
                     << yieldOp->getParentOp()->getName() << "\n");
 
-  if (func::FuncOp func = dyn_cast<func::FuncOp>(yieldOp->getParentOp())) {
+  if (isa<func::FuncOp>(yieldOp->getParentOp())) {
     LLVM_DEBUG(DBGS() << "  skipping (parent is FuncOp)\n");
     return;
   }
-  llvm::SmallVector<mlir::RegionSuccessor> successors;
-  llvm::SmallVector<mlir::Attribute> operands(yieldOp->getNumOperands(),
-                                              nullptr);
-  yieldOp.getSuccessorRegions(operands, successors);
 
-  auto regionBranchOp = cast<RegionBranchOpInterface>(yieldOp->getParentOp());
+  auto regionBranchOp =
+      dyn_cast<RegionBranchOpInterface>(yieldOp->getParentOp());
+  if (!regionBranchOp) {
+    LLVM_DEBUG(DBGS() << "  skipping (parent is not RegionBranchOp)\n");
+    return;
+  }
 
-  LLVM_DEBUG(DBGS() << "  found " << successors.size() << " successors\n");
-  for (mlir::RegionSuccessor &successor : successors) {
-    // debug print out successorr
-    LLVM_DEBUG({
-      DBGS() << "  successor: ";
-      if (successor.isParent()) {
-        DBGS() << "(parent operation)";
-      } else {
-        DBGS() << "region with " << successor.getSuccessor()->getNumArguments()
-               << " arguments";
-      }
-      DBGS() << "\n";
-    });
-    // find out the successor which is the parent region of yieldOp
-    // if (successor.getSuccessor() != yieldOp->getParentRegion()) {
-    //   LLVM_DEBUG(DBGS() << "  skipping successor (not parent region)\n");
-    //   continue;
-    // }
-    if (!successor.isParent())
-      continue;
-    // propagate the layout from region result to yield operands
-    ValueRange successorInputs = regionBranchOp.getSuccessorInputs(successor);
-    LLVM_DEBUG(DBGS() << "  propagating " << successorInputs.size()
-                      << " region results to yield operands\n");
-    for (unsigned i = 0; i < successorInputs.size(); ++i) {
-      Value regionResult = successorInputs[i];
-
-      // debug print regionResult
-      LLVM_DEBUG({
-        DBGS() << " before propagateRegionResultsToYieldOperands, Region IR:";
-        DBGS() << "    region result #" << i
-               << ": type=" << regionResult.getType();
-        llvm::dbgs() << regionResult;
-        llvm::dbgs() << "\n";
-      });
-      // find all the use of region result, and propagate the layout to the
-      // corresponding yield operand for all use of region result, get its
-      // layout from temporary operand layout if any of these use have it
-      xegpu::DistributeLayoutAttr layout = getLayoutFromUsePoints(regionResult);
-
-      // auto layout = xegpu::getDistributeLayoutAttr(regionResult);
-      if (layout == nullptr) {
-        LLVM_DEBUG(DBGS() << "    region result #" << i
-                          << ": skipped (no layout)\n");
-        continue;
-      }
-      assert(
-          layout &&
-          "region result layout must be defined before propagating to yield");
+  // Gather layouts for each result of the parent region op from external
+  // use points.
+  unsigned numResults = regionBranchOp->getNumResults();
+  LLVM_DEBUG(DBGS() << "  parent op has " << numResults << " results\n");
+
+  SmallVector<xegpu::DistributeLayoutAttr> resultLayouts(numResults, nullptr);
+  for (unsigned i = 0; i < numResults; ++i) {
+    OpResult result = regionBranchOp->getResult(i);
+    resultLayouts[i] = getLayoutFromUsePoints(result);
+    if (resultLayouts[i]) {
+      LLVM_DEBUG(DBGS() << "  result #" << i << ": type=" << result.getType()
+                        << ", layout=" << resultLayouts[i] << "\n");
+      xegpu::setTemporaryLayout(result, resultLayouts[i]);
+    } else {
+      LLVM_DEBUG(DBGS() << "  result #" << i
+                        << ": skipped (no layout from use points)\n");
+    }
+  }
 
-      if (auto opResult = dyn_cast<OpResult>(regionResult))
-        xegpu::setTemporaryLayout(opResult, layout);
-      xegpu::setTemporaryLayout(yieldOp->getOpOperand(i), layout);
+  // Use getSuccessorOperands to find which operands of the terminator
+  // flow to a successor. This handles index offsets automatically (e.g.,
+  // scf.condition's predicate at operand #0 is excluded).
+  // Pick the first successor to determine the operand range.
+  SmallVector<RegionSuccessor> successors;
+  SmallVector<Attribute> operandAttrs(yieldOp->getNumOperands(), nullptr);
+  yieldOp.getSuccessorRegions(operandAttrs, successors);
+  assert(!successors.empty() && "terminator must have at least one successor");
 
-      LLVM_DEBUG({
-        DBGS() << " after propagateRegionResultsToYieldOperands, Region IR:";
-        regionResult.print(llvm::dbgs());
-        if (Operation *defOp = regionResult.getDefiningOp())
-          defOp->print(llvm::dbgs(), OpPrintingFlags().printGenericOpForm());
-        llvm::dbgs() << "\n";
-      });
-    }
+  OperandRange succOps = yieldOp.getSuccessorOperands(successors.front());
+  unsigned beginIdx = succOps.getBeginOperandIndex();
+  unsigned count = std::min(static_cast<unsigned>(succOps.size()), numResults);
+
+  LLVM_DEBUG(DBGS() << "  " << count << " successor operands starting at index "
+                    << beginIdx << "\n");
+
+  for (unsigned i = 0; i < count; ++i) {
+    if (!resultLayouts[i])
+      continue;
+    LLVM_DEBUG(DBGS() << "    -> setting layout on operand #" << (beginIdx + i)
+                      << "\n");
+    xegpu::setTemporaryLayout(yieldOp->getOpOperand(beginIdx + i),
+                              resultLayouts[i]);
   }
+
+  LLVM_DEBUG({
+    DBGS() << " after propagateRegionResultsToYieldOperands:\n";
+    yieldOp->print(llvm::dbgs(), OpPrintingFlags().printGenericOpForm());
+    llvm::dbgs() << "\n";
+  });
 }
 
 static void propagateRegionArgsToInits(mlir::RegionBranchOpInterface regionOp) {
   LLVM_DEBUG(DBGS() << "propagateRegionArgsToInits: " << regionOp->getName()
                     << " (" << regionOp->getNumOperands() << " operands, "
                     << regionOp->getNumRegions() << " regions)\n");
-  DBGS() << " before propagateRegionArgsToInits, Region IR:";
-  regionOp.print(llvm::dbgs());
-  DBGS() << " complex debug Region IR:";
-  regionOp.print(llvm::dbgs(), OpPrintingFlags().printGenericOpForm());
-  // Get entry successors (regions that can be entered initially)
-  SmallVector<RegionSuccessor> successors;
-  regionOp.getEntrySuccessorRegions(/*operands=*/ArrayRef<Attribute>(),
-                                    successors);
-
-  LLVM_DEBUG(DBGS() << "  found " << successors.size()
-                    << " entry successors\n");
-  // For each possible entry region, get the operands forwarded to it
-  for (RegionSuccessor &successor : successors) {
-    OperandRange initOperands = regionOp.getEntrySuccessorOperands(successor);
-    unsigned beginIdx = initOperands.getBeginOperandIndex();
-    unsigned numArgs = successor.getSuccessor()->getNumArguments();
-    LLVM_DEBUG(DBGS() << "  successor region: " << numArgs
-                      << " args, initOperands beginIdx=" << beginIdx
-                      << ", count=" << initOperands.size() << "\n");
-    // initOperands are the initialization arguments for this successor
-    // iterate the region arguments
-    for (unsigned i = 0; i < numArgs; ++i) {
-      Value regionArg =
-          successor.getSuccessor()->getArgument(i); // region argument
-      auto layout = xegpu::getDistributeLayoutAttr(regionArg);
-      if (layout == nullptr) {
-        LLVM_DEBUG(DBGS() << "    region argument #" << i
-                          << ": skipped (no layout)\n");
+  LLVM_DEBUG({
+    DBGS() << " before propagateRegionArgsToInits, Region IR:\n";
+    regionOp.print(llvm::dbgs(), OpPrintingFlags().printGenericOpForm());
+    llvm::dbgs() << "\n";
+  });
+
+  // Iterate all regions of the region op. For each block argument that has a
+  // layout (determined from its use points), trace back to find the
+  // corresponding init operand of the regionOp and set the layout on it.
+  // This works generically for scf.for, scf.while, and other
+  // RegionBranchOpInterface ops.
+  for (Region &region : regionOp->getRegions()) {
+    RegionSuccessor regionSuccessor(&region);
+    for (auto [argIdx, regionArg] : llvm::enumerate(region.getArguments())) {
+      auto layout = getLayoutFromUsePoints(regionArg);
+      if (!layout) {
+        LLVM_DEBUG(DBGS() << "  region #" << region.getRegionNumber()
+                          << " arg #" << argIdx << ": skipped (no layout)\n");
         continue;
       }
-      assert(
-          layout &&
-          "region argument layout must be defined before propagating to init");
-      LLVM_DEBUG(DBGS() << "    regionArg #" << i << ": type="
-                        << regionArg.getType() << ", layout=" << layout
-                        << " -> init operand #" << (beginIdx + i) << "\n");
-      xegpu::setTemporaryLayout(regionOp->getOpOperand(beginIdx + i), layout);
+      LLVM_DEBUG(DBGS() << "  region #" << region.getRegionNumber() << " arg #"
+                        << argIdx << ": type=" << regionArg.getType()
+                        << ", layout=" << layout << "\n");
+
+      // Find all predecessor values that flow into this block argument.
+      SmallVector<Value> predValues;
+      regionOp.getPredecessorValues(regionSuccessor, argIdx, predValues);
+      for (Value predVal : predValues) {
+        // Match predecessor value to an operand of the regionOp.
+        for (OpOperand &operand : regionOp->getOpOperands()) {
+          if (operand.get() == predVal) {
+            LLVM_DEBUG(DBGS() << "    -> setting layout on init operand #"
+                              << operand.getOperandNumber() << "\n");
+            xegpu::setTemporaryLayout(operand, layout);
+          }
+        }
+      }
     }
   }
-  DBGS() << " after propagateRegionArgsToInits, Region IR:";
-  regionOp.print(llvm::dbgs());
+
+  LLVM_DEBUG({
+    DBGS() << " after propagateRegionArgsToInits, Region IR:\n";
+    regionOp.print(llvm::dbgs(), OpPrintingFlags().printGenericOpForm());
+    llvm::dbgs() << "\n";
+  });
 }
 
 bool xegpu::recoverTemporaryLayouts(Operation *rootOp) {
@@ -345,16 +331,6 @@ bool xegpu::recoverTemporaryLayouts(Operation *rootOp) {
     LLVM_DEBUG(DBGS() << "Processing func: " << funcName << "\n");
     walkRegionBackward(body, [&](Operation *op) {
       LLVM_DEBUG(DBGS() << "Visiting op: " << op->getName());
-      if (op->getNumResults() > 0) {
-        LLVM_DEBUG(llvm::dbgs() << " [results: " << op->getNumResults());
-        for (OpResult res : op->getResults()) {
-          auto layout = xegpu::getDistributeLayoutAttr(res);
-          LLVM_DEBUG(llvm::dbgs() << " r#" << res.getResultNumber() << "="
-                                  << (layout ? layout : nullptr));
-        }
-        LLVM_DEBUG(llvm::dbgs() << "]");
-      }
-      LLVM_DEBUG(llvm::dbgs() << "\n");
       if (auto regionOp = dyn_cast<mlir::RegionBranchOpInterface>(op)) {
         // hit the region op after visiting inside region
         LLVM_DEBUG(DBGS() << "  -> dispatching as RegionBranchOp\n");
@@ -1415,16 +1391,16 @@ xegpu::setupDpasLayout(xegpu::LayoutKind layoutKind, VectorType aTy,
 xegpu::DistributeLayoutAttr
 xegpu::inferSourceLayoutFromResult(OpOperand &operand,
                                    xegpu::DistributeLayoutAttr resLayout) {
+  if (!resLayout) {
+    LLVM_DEBUG(DBGS() << "no resLayout, returning null\n");
+    return xegpu::DistributeLayoutAttr();
+  }
   Operation *op = operand.getOwner();
   unsigned idx = operand.getOperandNumber();
 
   // For vector::BroadcastOp, infer the source layout from the result layout.
   if (auto broadcast = dyn_cast<vector::BroadcastOp>(op)) {
     LLVM_DEBUG(DBGS() << "  -> BroadcastOp\n");
-    if (!resLayout) {
-      LLVM_DEBUG(DBGS() << "     no resLayout, returning null\n");
-      return xegpu::DistributeLayoutAttr();
-    }
     auto srcTy = dyn_cast<VectorType>(broadcast.getSourceType());
     if (!srcTy) {
       LLVM_DEBUG(DBGS() << "     source is not VectorType, returning null\n");
@@ -1443,10 +1419,6 @@ xegpu::inferSourceLayoutFromResult(OpOperand &operand,
   if (auto reduction = dyn_cast<vector::MultiDimReductionOp>(op)) {
     LLVM_DEBUG(DBGS() << "  -> MultiDimReductionOp, operand idx=" << idx
                       << "\n");
-    if (!resLayout) {
-      LLVM_DEBUG(DBGS() << "     no resLayout, returning null\n");
-      return xegpu::DistributeLayoutAttr();
-    }
     if (idx == 0) {
       SmallVector<int64_t> reductionDims(reduction.getReductionDims());
       LLVM_DEBUG({
@@ -1467,10 +1439,6 @@ xegpu::inferSourceLayoutFromResult(OpOperand &operand,
 
   if (auto reduction = dyn_cast<vector::ReductionOp>(op)) {
     LLVM_DEBUG(DBGS() << "  -> ReductionOp\n");
-    if (!resLayout) {
-      LLVM_DEBUG(DBGS() << "     no resLayout, returning null\n");
-      return xegpu::DistributeLayoutAttr();
-    }
     auto inferred = xegpu::inferReductionSourceLayout(resLayout);
     LLVM_DEBUG(DBGS() << "     inferred=" << inferred << "\n");
     return inferred;
@@ -1480,10 +1448,6 @@ xegpu::inferSourceLayoutFromResult(OpOperand &operand,
   // element type bitwidths.
   if (auto bitcast = dyn_cast<vector::BitCastOp>(op)) {
     LLVM_DEBUG(DBGS() << "  -> BitCastOp\n");
-    if (!resLayout) {
-      LLVM_DEBUG(DBGS() << "     no resLayout, returning null\n");
-      return xegpu::DistributeLayoutAttr();
-    }
     int resElemBitWidth =
         bitcast.getResultVectorType().getElementType().getIntOrFloatBitWidth();
     int srcElemBitWidth =
@@ -1508,10 +1472,6 @@ xegpu::inferSourceLayoutFromResult(OpOperand &operand,
                             llvm::dbgs());
       llvm::dbgs() << "]\n";
     });
-    if (!resLayout) {
-      LLVM_DEBUG(DBGS() << "     no resLayout, returning null\n");
-      return xegpu::DistributeLayoutAttr();
-    }
     auto inferred = xegpu::inferShapeCastSourceLayout(
         resLayout, shapeCast.getResultVectorType().getShape(),
         shapeCast.getSourceVectorType().getShape());
@@ -1524,10 +1484,6 @@ xegpu::inferSourceLayoutFromResult(OpOperand &operand,
   if (auto insertSlice = dyn_cast<vector::InsertStridedSliceOp>(op)) {
     LLVM_DEBUG(DBGS() << "  -> InsertStridedSliceOp, operand idx=" << idx
                       << "\n");
-    if (!resLayout) {
-      LLVM_DEBUG(DBGS() << "     no resLayout, returning null\n");
-      return xegpu::DistributeLayoutAttr();
-    }
     if (idx == 0) {
       auto inferred = xegpu::inferInsertStridedSliceSourceLayout(
           resLayout, insertSlice.getDestVectorType().getShape(),
@@ -1549,10 +1505,6 @@ xegpu::inferSourceLayoutFromResult(OpOperand &operand,
       llvm::interleaveComma(transpose.getPermutation(), llvm::dbgs());
       llvm::dbgs() << "]\n";
     });
-    if (!resLayout) {
-      LLVM_DEBUG(DBGS() << "     no resLayout, returning null\n");
-      return xegpu::DistributeLayoutAttr();
-    }
     auto inferred = xegpu::inferTransposeSourceLayout(
         resLayout, transpose.getPermutation());
     LLVM_DEBUG(DBGS() << "     inferred=" << inferred << "\n");
@@ -1564,8 +1516,7 @@ xegpu::inferSourceLayoutFromResult(OpOperand &operand,
   if (OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() == 1) {
     LLVM_DEBUG(DBGS() << "  -> elementwise op, using resLayout="
                       << (resLayout ? resLayout : nullptr) << "\n");
-    if (!resLayout)
-      return xegpu::DistributeLayoutAttr();
+
     return resLayout;
   }
   return xegpu::DistributeLayoutAttr();
@@ -1581,7 +1532,7 @@ xegpu::DistributeLayoutAttr xegpu::getConsumerLayoutAt(OpOperand &operand) {
     return inferredOperandLayout;
   // By default, assume no layout conflict and return the current layout of
   // the operand.
-  auto fallback = xegpu::getDistributeLayoutAttr(operand);
+  auto fallback = xegpu::getDistributeLayoutAttr(operand.get());
   LLVM_DEBUG(DBGS() << "  -> fallback (unhandled op " << op->getName()
                     << "), returning operand layout="
                     << (fallback ? fallback : nullptr) << "\n");
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 4c30dacae8850..f0ff771f4cbc4 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -1338,7 +1338,7 @@ LogicalResult ResolveLayoutConflicts::run() {
     // as anchor op for the reduction op's layout.
     if (isa<vector::MultiDimReductionOp>(op) || isa<vector::ReductionOp>(op)) {
       for (OpResult result : op->getResults()) {
-        if (result.getType().isIntOrFloat()) {
+        if (result.getType().isIntOrFloat() || result.use_empty()) {
           auto res = assignResultLayout(result);
           if (failed(res)) {
             DBGS() << "Failed to resolve vector consumer for multi-reduction "

>From 27cc56acf41eb3380d3195fca5a9215b4414a413 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Fri, 3 Apr 2026 18:00:50 +0000
Subject: [PATCH 03/21] adding support for DistributeLayoutAttr in TensorDesc
 instead of just LayoutAttr

---
 .../mlir/Dialect/XeGPU/IR/XeGPUTypes.td       |  6 +++---
 .../mlir/Dialect/XeGPU/Utils/XeGPUUtils.h     |  5 +++--
 .../XeGPU/Transforms/XeGPUBlocking.cpp        | 13 ++++++------
 .../XeGPU/Transforms/XeGPULayoutImpl.cpp      | 21 +++++++++++++------
 .../Transforms/XeGPUPeepHoleOptimizer.cpp     | 11 +++++++---
 .../Transforms/XeGPUSubgroupDistribute.cpp    |  9 ++++----
 .../Transforms/XeGPUWgToSgDistribute.cpp      |  2 +-
 mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp   |  4 ++--
 .../lib/Dialect/XeGPU/TestXeGPUTransforms.cpp | 13 ++++++------
 9 files changed, 49 insertions(+), 35 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index 7e142b20c0894..b13f5a9f2c9d9 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -82,7 +82,7 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
     static-dim-list ::= decimal-literal `x` decimal-literal
     attr-list = (, encoding-attr)? (, layout-attr)?
     enconding-attr = (, memory_space = value)? (, arr_len = value)? (, boundary_check = value)? (, scattered = value)?
-    layout-attr = (, layout `<`sg_layout = value, sg_data = value, inst_data = value, lane_layout = value, lane_data = value, order = value`>`)?
+    layout-attr = DistributeLayoutAttr
     ```
 
     Examples:
@@ -158,8 +158,8 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
       return llvm::dyn_cast_if_present<T>(getEncoding());
     }
 
-    LayoutAttr getLayoutAttr() const {
-      return llvm::dyn_cast_if_present<LayoutAttr>(getLayout());
+    DistributeLayoutAttr getLayoutAttr() const {
+      return llvm::dyn_cast_if_present<DistributeLayoutAttr>(getLayout());
     }
 
     xegpu::MemorySpace getMemorySpace() const {
diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
index 0aa2cd45088f3..1b594f17e15ec 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
@@ -219,10 +219,11 @@ void setTemporaryLayout(const T &operandOrResult,
 /// Helper function to check if the layout is packed. Layout is packed if it is
 /// 2D and lane_data[0] != 1 (data packed from col dimension).
 /// TODO: Move to target info.
-bool requirePacked(const LayoutAttr layout);
+bool requirePacked(const DistributeLayoutAttr layout);
 
 /// Helper function to check if the layout requires a transpose effect.
-bool requireTranspose(const LayoutAttr layout, const uArch::uArch *uArch);
+bool requireTranspose(const DistributeLayoutAttr layout,
+                      const uArch::uArch *uArch);
 
 // Check if dst shape is an expansion of src shape by inserting unit dimensions.
 bool matchUnitDimExpansion(ArrayRef<int64_t> src, ArrayRef<int64_t> dst,
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
index 1ee0bc6ad9507..ef6a494b76638 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
@@ -270,12 +270,11 @@ void XeGPUBlockingPass::runOnOperation() {
   }
 
   auto getTileShapeAndCount = [](llvm::ArrayRef<int64_t> shape,
-                                 xegpu::LayoutAttr layout) {
+                                 xegpu::DistributeLayoutAttr layout) {
     int count = 1;
     SmallVector<int64_t> tileShape(shape);
-    if (layout && layout.getInstData()) {
-      DenseI32ArrayAttr instData = layout.getInstData();
-      tileShape = llvm::to_vector_of<int64_t>(instData.asArrayRef());
+    if (layout && !layout.getEffectiveInstDataAsInt().empty()) {
+      tileShape = layout.getEffectiveInstDataAsInt();
       count = computeProduct(shape) / computeProduct(tileShape);
     }
     return std::make_pair(tileShape, count);
@@ -308,7 +307,7 @@ void XeGPUBlockingPass::runOnOperation() {
         Type elemTy = type.getElementType();
         ArrayRef<int64_t> shape = type.getShape();
 
-        xegpu::LayoutAttr layout = type.getLayoutAttr();
+        xegpu::DistributeLayoutAttr layout = type.getLayoutAttr();
         if (layout && layout.isForWorkgroup())
           return failure();
 
@@ -348,9 +347,9 @@ void XeGPUBlockingPass::runOnOperation() {
 
         if (chunkSize > 1) {
           int64_t blockedChunkSize = chunkSize;
-          auto instData = tdescTy.getLayoutAttr().getInstData();
+          auto instData = tdescTy.getLayoutAttr().getEffectiveInstDataAsInt();
           if (!instData.empty())
-            blockedChunkSize = instData.asArrayRef().back();
+            blockedChunkSize = instData.back();
 
           // To create a new attribute with a different chunk_size:
           auto newEncoding = xegpu::ScatterTensorDescAttr::get(
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
index 47148870eeaae..535239e869af1 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
@@ -141,7 +141,8 @@ static xegpu::DistributeLayoutAttr getLayoutFromUsePoints(Value result) {
     if (auto tmpLayout = xegpu::getDistributeLayoutAttr(use)) {
       // debug print the use and op, and the tmpLayout
       LLVM_DEBUG({
-        DBGS() << "      use: " << use.getOwner()->getName() << use.getOwner();
+        DBGS() << "getLayoutFromUsePoints  use: " << use.getOwner()->getName()
+               << use.getOwner();
         llvm::dbgs() << ", tmpLayout=" << tmpLayout << "\n";
       });
       // under debug mode, we want to check all the use points to make sure
@@ -175,10 +176,16 @@ static void propagateResultsToRegularOperands(Operation *op) {
   // its layout is not stored as an attribute but encoded in the type itself.
   // For vector type, we attach the layout as an attribute to op.
   if (auto tensorDescTy = dyn_cast<xegpu::TensorDescType>(resultType)) {
-    auto typeWithLayout = xegpu::TensorDescType::get(
-        tensorDescTy.getContext(), tensorDescTy.getShape(),
-        tensorDescTy.getElementType(), tensorDescTy.getEncoding(), resLayout);
-    result.setType(typeWithLayout);
+    auto layout = tensorDescTy.getLayoutAttr();
+    // TODO: remove the layout check. The TensorDescType's layout is treated as
+    // a temporary layout, which needs to be set by layout recovery. Until
+    // then, keep an existing layout so that some legacy test cases pass.
+    if (!layout) {
+      auto typeWithLayout = xegpu::TensorDescType::get(
+          tensorDescTy.getContext(), tensorDescTy.getShape(),
+          tensorDescTy.getElementType(), tensorDescTy.getEncoding(), resLayout);
+      result.setType(typeWithLayout);
+    }
   }
 
   for (OpOperand &opr : op->getOpOperands()) {
@@ -226,6 +233,8 @@ static void propagateRegionResultsToYieldOperands(
   // use points.
   unsigned numResults = regionBranchOp->getNumResults();
   LLVM_DEBUG(DBGS() << "  parent op has " << numResults << " results\n");
+  if (numResults == 0)
+    return;
 
   SmallVector<xegpu::DistributeLayoutAttr> resultLayouts(numResults, nullptr);
   for (unsigned i = 0; i < numResults; ++i) {
@@ -303,7 +312,7 @@ static void propagateRegionArgsToInits(mlir::RegionBranchOpInterface regionOp) {
 
       // Find all predecessor values that flow into this block argument.
       SmallVector<Value> predValues;
-      regionOp.getPredecessorValues(regionSuccessor, argIdx, predValues);
+      regionOp.getPredecessorValues(regionSuccessor, argIdx - 1, predValues);
       for (Value predVal : predValues) {
         // Match predecessor value to an operand of the regionOp.
         for (OpOperand &operand : regionOp->getOpOperands()) {
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPeepHoleOptimizer.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPeepHoleOptimizer.cpp
index 0ece695aed512..9288ba9a0cb56 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPeepHoleOptimizer.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPeepHoleOptimizer.cpp
@@ -145,10 +145,15 @@ static xegpu::TensorDescType tryOptimize(xegpu::TensorDescType tdescType,
     return tdescType;
 
   SmallVector<int64_t> supportedShape = {supportedHeight, supportedWidth};
+  auto ctx = tdescType.getContext();
+  auto origLayout = tdescType.getLayoutAttr();
+  SmallVector<int32_t> laneLayoutI32(
+      origLayout.getEffectiveLaneLayoutAsInt().begin(),
+      origLayout.getEffectiveLaneLayoutAsInt().end());
   xegpu::LayoutAttr newLayout = xegpu::LayoutAttr::get(
-      tdescType.getContext(), tdescType.getLayoutAttr().getLaneLayout(),
-      DenseI32ArrayAttr::get(tdescType.getContext(), {1, 1}),
-      tdescType.getLayoutAttr().getOrder());
+      ctx, /*lane_layout=*/DenseI32ArrayAttr::get(ctx, laneLayoutI32),
+      /*lane_data=*/DenseI32ArrayAttr::get(ctx, {1, 1}),
+      /*order=*/origLayout.getOrder());
   // Array length can not be larger than 1 for transpose case.
   return xegpu::TensorDescType::get(supportedShape, newElemTy, arrayLen,
                                     tdescType.getBoundaryCheck(),
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index ecdf253d68182..d8ce24ddd5cb0 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -256,7 +256,7 @@ struct CreateNdDescDistribution final : public gpu::WarpDistributionPattern {
     auto descOp = operand->get().getDefiningOp<xegpu::CreateNdDescOp>();
     unsigned operandIdx = operand->getOperandNumber();
 
-    xegpu::LayoutAttr layout = descOp.getType().getLayoutAttr();
+    xegpu::DistributeLayoutAttr layout = descOp.getType().getLayoutAttr();
     if (!layout)
       return rewriter.notifyMatchFailure(
           descOp, "the tensor descriptor lacks layout attribute");
@@ -342,7 +342,7 @@ struct StoreNdDistribution final : public gpu::WarpDistributionPattern {
     SmallVector<Type> offsetTypes = llvm::map_to_vector(
         offsetsAsValues, [](Value v) { return v.getType(); });
     xegpu::TensorDescType tensorDescTy = storeOp.getTensorDescType();
-    xegpu::LayoutAttr layout = tensorDescTy.getLayoutAttr();
+    xegpu::DistributeLayoutAttr layout = tensorDescTy.getLayoutAttr();
     if (!layout)
       return rewriter.notifyMatchFailure(
           storeOp, "the source tensor descriptor lacks layout attribute");
@@ -474,7 +474,7 @@ struct LoadNdDistribution final : public gpu::WarpDistributionPattern {
         offsetsAsValues, [](Value v) { return v.getType(); });
 
     xegpu::TensorDescType tensorDescTy = loadOp.getTensorDescType();
-    xegpu::LayoutAttr layout = tensorDescTy.getLayoutAttr();
+    xegpu::DistributeLayoutAttr layout = tensorDescTy.getLayoutAttr();
     if (!layout)
       return rewriter.notifyMatchFailure(
           loadOp, "the source tensor descriptor lacks layout attribute");
@@ -709,7 +709,8 @@ struct PrefetchNdDistribution final : public gpu::WarpDistributionPattern {
     SmallVector<Type> offsetTypes = llvm::map_to_vector(
         offsetsAsValues, [](Value v) { return v.getType(); });
 
-    xegpu::LayoutAttr layout = prefetchOp.getTensorDescType().getLayoutAttr();
+    xegpu::DistributeLayoutAttr layout =
+        prefetchOp.getTensorDescType().getLayoutAttr();
     if (!layout)
       return rewriter.notifyMatchFailure(
           prefetchOp, "the source tensor descriptor lacks layout attribute");
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 0aead9172858f..e47224bbe755c 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -1647,7 +1647,7 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
   converter.addConversion(
       [&](xegpu::TensorDescType type,
           SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
-        xegpu::LayoutAttr layout = type.getLayoutAttr();
+        xegpu::DistributeLayoutAttr layout = type.getLayoutAttr();
         // Only convert WG-level tensor descs. SG-level or layout-less types
         // are already legal and should pass through unchanged.
         if (!layout || !layout.isForWorkgroup())
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index a762458105e47..55cf47e38dfd0 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -936,7 +936,7 @@ template int
 xegpu::getLargestDivisor<unsigned>(unsigned dim, ArrayRef<unsigned> candidates,
                                    ArrayRef<unsigned> candidateMultiples);
 
-bool xegpu::requirePacked(const xegpu::LayoutAttr layout) {
+bool xegpu::requirePacked(const xegpu::DistributeLayoutAttr layout) {
   if (!layout)
     return false;
   auto laneData = layout.getEffectiveLaneDataAsInt();
@@ -945,7 +945,7 @@ bool xegpu::requirePacked(const xegpu::LayoutAttr layout) {
   return laneData[0] != 1;
 }
 
-bool xegpu::requireTranspose(const xegpu::LayoutAttr layout,
+bool xegpu::requireTranspose(const xegpu::DistributeLayoutAttr layout,
                              const xegpu::uArch::uArch *uArch) {
   // Return false for unsupported targets.
   // TODO: Add more support or move to target info.
diff --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
index 0d10ab7c74da6..4760016bdcea4 100644
--- a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
+++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
@@ -106,10 +106,9 @@ struct TestXeGPUUnrollingPatterns
         }
 
         if (auto layout = tdescTy.getLayoutAttr()) {
-          auto inst_data = layout.getInstData();
-          if (inst_data && layout.isForSubgroup())
-            return SmallVector<int64_t>(inst_data.asArrayRef().begin(),
-                                        inst_data.asArrayRef().end());
+          auto inst_data = layout.getEffectiveInstDataAsInt();
+          if (!inst_data.empty() && layout.isForSubgroup())
+            return SmallVector<int64_t>(inst_data.begin(), inst_data.end());
         }
       }
 
@@ -138,9 +137,9 @@ struct TestXeGPUUnrollingPatterns
 
               if (chunkSize > 1) {
                 int64_t blockedChunkSize = chunkSize;
-                auto instData = layout.getInstData();
+                auto instData = layout.getEffectiveInstDataAsInt();
                 if (!instData.empty())
-                  blockedChunkSize = instData.asArrayRef().back();
+                  blockedChunkSize = instData.back();
 
                 // To create a new attribute with a different chunk_size:
                 auto newEncoding = xegpu::ScatterTensorDescAttr::get(
@@ -150,7 +149,7 @@ struct TestXeGPUUnrollingPatterns
               }
             }
             if (layout) {
-              if (layout.getLaneLayout() == nullptr)
+              if (layout.getEffectiveLaneLayoutAsInt().empty())
                 layout = xegpu::LayoutAttr();
               else
                 layout = layout.dropInstData();

>From 0690c6cc01e121b137c1056e62d22ff207a82777 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Fri, 3 Apr 2026 20:02:43 +0000
Subject: [PATCH 04/21] fix bugs

---
 mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp    |  2 +-
 .../XeGPU/Transforms/XeGPULayoutImpl.cpp      |  6 +++
 .../Transforms/XeGPUPeepHoleOptimizer.cpp     | 19 ++++++++--
 .../Transforms/XeGPUSubgroupDistribute.cpp    | 38 ++++++++++++++++++-
 .../XeGPU/sg-to-wi-experimental-unit.mlir     | 19 +---------
 mlir/test/Dialect/XeGPU/xegpu-blocking.mlir   |  4 +-
 6 files changed, 64 insertions(+), 24 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 950371e17255f..64c56b5adf5d7 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -1318,7 +1318,7 @@ mlir::Type TensorDescType::parse(AsmParser &parser) {
     mlir::Attribute attr;
     ParseResult res = parser.parseAttribute(attr);
     if (mlir::succeeded(res)) {
-      if (mlir::isa<LayoutAttr>(attr)) {
+      if (mlir::isa<DistributeLayoutAttr>(attr)) {
         layout = attr;
         continue;
       }
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
index 535239e869af1..33c9086566d3c 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
@@ -365,6 +365,12 @@ bool xegpu::recoverTemporaryLayouts(Operation *rootOp) {
   });
 
   LLVM_DEBUG(DBGS() << "=== recoverTemporaryLayouts END ===\n");
+  // print the root op after
+  LLVM_DEBUG({
+    DBGS() << "After recoverTemporaryLayouts, IR:\n";
+    rootOp->print(llvm::dbgs(), OpPrintingFlags().printGenericOpForm());
+    llvm::dbgs() << "\n";
+  });
   return true;
 }
 
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPeepHoleOptimizer.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPeepHoleOptimizer.cpp
index 9288ba9a0cb56..c43eaba5b3ee6 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPeepHoleOptimizer.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPeepHoleOptimizer.cpp
@@ -28,6 +28,7 @@
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Debug.h"
 #include <optional>
 
 namespace mlir {
@@ -147,13 +148,25 @@ static xegpu::TensorDescType tryOptimize(xegpu::TensorDescType tdescType,
   SmallVector<int64_t> supportedShape = {supportedHeight, supportedWidth};
   auto ctx = tdescType.getContext();
   auto origLayout = tdescType.getLayoutAttr();
-  SmallVector<int32_t> laneLayoutI32(
-      origLayout.getEffectiveLaneLayoutAsInt().begin(),
-      origLayout.getEffectiveLaneLayoutAsInt().end());
+  auto laneLayoutI64 = origLayout.getEffectiveLaneLayoutAsInt();
+  SmallVector<int32_t> laneLayoutI32(laneLayoutI64.begin(),
+                                     laneLayoutI64.end());
+  LLVM_DEBUG({
+    DBGS() << "tryOptimize: origLayout=" << origLayout << "\n";
+    DBGS() << "  laneLayoutI32=[";
+    llvm::interleaveComma(laneLayoutI32, llvm::dbgs());
+    llvm::dbgs() << "], laneData=[1, 1]";
+    if (origLayout.getOrder())
+      llvm::dbgs() << ", order=" << origLayout.getOrder();
+    llvm::dbgs() << "\n";
+    DBGS() << "  supportedShape=[" << supportedHeight << ", " << supportedWidth
+           << "], newElemTy=" << newElemTy << ", arrayLen=" << arrayLen << "\n";
+  });
   xegpu::LayoutAttr newLayout = xegpu::LayoutAttr::get(
       ctx, /*lane_layout=*/DenseI32ArrayAttr::get(ctx, laneLayoutI32),
       /*lane_data=*/DenseI32ArrayAttr::get(ctx, {1, 1}),
       /*order=*/origLayout.getOrder());
+  LLVM_DEBUG(DBGS() << "  newLayout=" << newLayout << "\n");
   // Array length can not be larger than 1 for transpose case.
   return xegpu::TensorDescType::get(supportedShape, newElemTy, arrayLen,
                                     tdescType.getBoundaryCheck(),
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index d8ce24ddd5cb0..27cf788933f18 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -800,10 +800,17 @@ struct StoreDistribution final : public gpu::WarpDistributionPattern {
   using gpu::WarpDistributionPattern::WarpDistributionPattern;
   LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
+    LLVM_DEBUG(DBGS() << "StoreDistribution: attempting to match\n");
     Operation *lastNode = warpOp.getTerminator()->getPrevNode();
     auto storeScatterOp = dyn_cast_or_null<xegpu::StoreScatterOp>(lastNode);
-    if (!storeScatterOp)
+    if (!storeScatterOp) {
+      LLVM_DEBUG(
+          DBGS()
+          << "StoreDistribution: last node is not StoreScatterOp, skipping\n");
       return failure();
+    }
+    LLVM_DEBUG(DBGS() << "StoreDistribution: matched StoreScatterOp: "
+                      << *storeScatterOp << "\n");
     auto offsets = storeScatterOp.getOffsets();
     if (!offsets || !isa<VectorType>(offsets.getType()))
       return rewriter.notifyMatchFailure(
@@ -811,10 +818,15 @@ struct StoreDistribution final : public gpu::WarpDistributionPattern {
     VectorType offsetsTy = cast<VectorType>(offsets.getType());
     VectorType maskTy = cast<VectorType>(storeScatterOp.getMask().getType());
     VectorType storeVecTy = cast<VectorType>(storeScatterOp.getValueType());
+    LLVM_DEBUG(DBGS() << "StoreDistribution: offsetsTy=" << offsetsTy
+                      << ", maskTy=" << maskTy << ", storeVecTy=" << storeVecTy
+                      << "\n");
 
     // Add handling for leading unit dimensions support
     int chunkSize = storeScatterOp.getChunkSize().value_or(1);
     int effectiveVecRank = (chunkSize == 1) ? 1 : 2;
+    LLVM_DEBUG(DBGS() << "StoreDistribution: chunkSize=" << chunkSize
+                      << ", effectiveVecRank=" << effectiveVecRank << "\n");
 
     // Check that all leading dimensions are unit dimensions
     for (int i = 0; i < storeVecTy.getRank() - effectiveVecRank; i++) {
@@ -831,6 +843,24 @@ struct StoreDistribution final : public gpu::WarpDistributionPattern {
         xegpu::getTemporaryLayout(storeScatterOp->getOpOperand(2));
     auto layoutMask =
         xegpu::getTemporaryLayout(storeScatterOp->getOpOperand(3));
+    LLVM_DEBUG({
+      DBGS() << "StoreDistribution: layoutPayload=";
+      if (layoutPayload)
+        DBGS() << layoutPayload;
+      else
+        DBGS() << "(null)";
+      DBGS() << ", layoutOffsets=";
+      if (layoutOffsets)
+        DBGS() << layoutOffsets;
+      else
+        DBGS() << "(null)";
+      DBGS() << ", layoutMask=";
+      if (layoutMask)
+        DBGS() << layoutMask;
+      else
+        DBGS() << "(null)";
+      DBGS() << "\n";
+    });
 
     FailureOr<VectorType> distStoreVecByWarpOpOrFailure =
         getDistVecTypeBasedOnLaneLayout(layoutPayload, storeVecTy);
@@ -849,6 +879,9 @@ struct StoreDistribution final : public gpu::WarpDistributionPattern {
     VectorType distPayloadTy = distStoreVecByWarpOpOrFailure.value();
     VectorType distOffsetsTy = distOffsetsByWarpOpOrFailure.value();
     VectorType distMaskTy = distMaskByWarpOpOrFailure.value();
+    LLVM_DEBUG(DBGS() << "StoreDistribution: distPayloadTy=" << distPayloadTy
+                      << ", distOffsetsTy=" << distOffsetsTy
+                      << ", distMaskTy=" << distMaskTy << "\n");
 
     SmallVector<size_t> newRetIndices;
     SmallVector<Value> operands = storeScatterOp->getOperands();
@@ -885,7 +918,10 @@ struct StoreDistribution final : public gpu::WarpDistributionPattern {
         rewriter, newWarpOp.getLoc(), TypeRange{}, newStoreScatterOpOperands,
         storeScatterOp->getAttrs());
     xegpu::removeLayoutAttrs(newOp);
+    LLVM_DEBUG(DBGS() << "StoreDistribution: created new op: " << newOp
+                      << "\n");
     rewriter.eraseOp(storeScatterOp);
+    LLVM_DEBUG(DBGS() << "StoreDistribution: done\n");
     return success();
   }
 };
diff --git a/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir b/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
index 842c2375dd31d..0d1bfd5480aa2 100644
--- a/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
+++ b/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
@@ -473,22 +473,6 @@ gpu.func @vector_multi_reduction_dim0_distributed_dim1_reduction(%laneid: index)
   gpu.return
 }
 
-// CHECK-LABEL: gpu.func @vector_transpose
-// CHECK:         %[[SRC:.*]] = "some_op"()
-// CHECK:         %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[SRC]] : vector<16x2xf32> to vector<1x2xf32>
-// CHECK-NEXT:    %[[T:.*]] = vector.transpose %[[CAST]], [1, 0] : vector<1x2xf32> to vector<2x1xf32>
-// CHECK-NEXT:    gpu.return
-gpu.func @vector_transpose() {
-  %cst = "some_op"()
-    {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1], order = [0, 1]>}
-    : () -> (vector<16x2xf32>)
-  %transpose = vector.transpose %cst, [1, 0]
-    {
-      layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
-    }
-    : vector<16x2xf32> to vector<2x16xf32>
-  gpu.return
-}
 
 // CHECK-LABEL: gpu.func @vector_bitcast
 // CHECK:         %[[SRC:.*]] = "some_op"()
@@ -1092,7 +1076,8 @@ gpu.module @xevm_module {
 gpu.func @vector_broadcast_2d_to_2d_noop(%laneid: index) {
   %0 = "some_op"() {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : () -> vector<16x1xf16>
   %1 = vector.broadcast %0 {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16x1xf16> to vector<16x16xf16>
-  "some_use"(%1) : (vector<16x16xf16>) -> ()
+  %2 = xegpu.convert_layout %1 <{input_layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, target_layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}> : vector<16x16xf16>
+  "some_use"(%2) : (vector<16x16xf16>) -> ()
   gpu.return
 }
 }
diff --git a/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir b/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
index 9ca424374335f..61b8046bd04e5 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
@@ -257,7 +257,7 @@ gpu.module @test_kernel {
 
 // -----
 #l = #xegpu.layout<inst_data = [16, 16]>
-#r = #xegpu.layout<inst_data = [16]>
+#r = #xegpu.slice<#xegpu.layout<inst_data = [16, 16]>, dims = [0]>
 gpu.module @test_kernel  {
   gpu.func @reduce_dim_0(%a: memref<16x512xf32>, %b: memref<512xf32>)  kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
     %acc = arith.constant {layout_result_0 = #r} dense<0.0> : vector<64xf32>
@@ -277,7 +277,7 @@ gpu.module @test_kernel  {
 
 // -----
 #l = #xegpu.layout<inst_data = [16, 16]>
-#r = #xegpu.layout<inst_data = [16]>
+#r = #xegpu.slice<#xegpu.layout<inst_data = [16, 16]>, dims = [1]>
 gpu.module @test_kernel   {
   gpu.func @reduce_dim_1(%a: memref<512x32xf32>, %b: memref<512xf32>)  kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
     %c1 = arith.constant 1 : index

>From ac36ceaccbd9bff10bf933ffef9b0b0d1e557cdc Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Fri, 3 Apr 2026 20:24:35 +0000
Subject: [PATCH 05/21] separate recover temporary layout out to another PR

---
 .../XeGPU/Transforms/XeGPULayoutImpl.h        |   5 +-
 .../XeGPU/Transforms/XeGPULayoutImpl.cpp      | 457 +++---------------
 .../XeGPU/Transforms/XeGPUPropagateLayout.cpp |   2 +-
 .../XeGPU/sg-to-wi-experimental-unit.mlir     |  19 +-
 4 files changed, 73 insertions(+), 410 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h b/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h
index 5f46eab7b74c7..9cf9a8705209b 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h
@@ -183,13 +183,10 @@ setupDpasLayout(LayoutKind layoutKind, VectorType aTy, VectorType bTy,
                 VectorType cdTy, DistributeLayoutAttr consumerLayout, int numSg,
                 const uArch::uArch *uArch);
 
-DistributeLayoutAttr
-inferSourceLayoutFromResult(OpOperand &operand, DistributeLayoutAttr resLayout);
-
 /// Gets the expected layout for a given consumer operand. This will check if
 /// the owning operation of the consumer operand is one of the special layout
 /// users and determine the expected layout accordingly.
-DistributeLayoutAttr getConsumerLayoutAt(OpOperand &operand);
+xegpu::DistributeLayoutAttr getConsumerLayoutAt(OpOperand &operand);
 
 } // namespace xegpu
 
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
index 33c9086566d3c..55cd6ec04970c 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
@@ -18,22 +18,16 @@
 #include "mlir/Dialect/LLVMIR/XeVMDialect.h"
 #include "mlir/Dialect/SCF/Transforms/Patterns.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
-#include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/ValueRange.h"
-#include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/LoopLikeInterface.h"
 #include "mlir/Transforms/DialectConversion.h"
-#include "llvm/Support/Debug.h"
 #include "llvm/Support/FormatVariadic.h"
 #include <cstdint>
 #include <numeric>
 
-#define DEBUG_TYPE "xegpu-layout-recovery"
-#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
-
 using namespace mlir;
 
 void xegpu::recoverTemporaryLayoutsDeprecated(Operation *op) {
@@ -86,321 +80,32 @@ xegpu::dropInstDataOnAttrs(ArrayRef<NamedAttribute> attrs) {
   return out;
 }
 
-// Prerequisite for Layout Recovery
-// It relies on the following invariant:
-// 1. there is no layout conflict between different uses of the same definition.
-// 2. each definition has a well-defined layout requirement at its use point.
-//     - Every definition must have at least one use that appears after it in
-//     topological order.
-//     - If a definition has no such use (e.g., a loop result or region output),
-//     an explicit convert_layout operation is inserted to create a use.
-//     - Only the result of convert_layout is permitted to have no subsequent
-//     use.
-
-// The recovery proceeds by scanning the operation in reverse topological order
-// as follows:
-//    For regular operations: First the result layouts are propagated from uses.
-//      Then the result layouts are propagated to operands.
-//
-//    For region operations (e.g., loops):
-//       - When backward propagation reaches a region op, it sets the layout of
-//       the region op’s results according to use points like regular ops.
-//       - Then, the result layouts (such as a loop output) are propagated to
-//       their corresponding operands in the yield.
-//       - When backward propagation reaches the first operation inside the
-//       region, the pass examines the region op’s initialization list,
-//       propagating from region arguments to the corresponding initialization
-//       operands.
-//       - This ensures that layouts are consistently propagated
-//       across region boundaries while preserving a single well-defined use for
-//       each definition at the region-op level.
-
-// the inner function for recoverTemporaryLayouts is a recursive function
-// the input rootOp is the function operation, which is also a region op.
-// it recursivley process the region op in reverse topological order.
-
-static void walkRegionBackward(Region &region,
-                               llvm::function_ref<void(Operation *)> visit) {
-  // blocks: back -> front
-  for (Block &block : llvm::reverse(region)) {
-    // ops: back -> front, early-inc so visit() may erase current op safely
-    for (Operation &op : llvm::reverse(block)) {
-      // make sure we first visit inside the region op (so yield op first)
-      // and then move to region op itself
-      for (Region &nested : llvm::reverse(op.getRegions()))
-        walkRegionBackward(nested, visit);
-
-      visit(&op);
-    }
-  }
-}
-
-static xegpu::DistributeLayoutAttr getLayoutFromUsePoints(Value result) {
-  xegpu::DistributeLayoutAttr layout = nullptr;
-  for (OpOperand &use : result.getUses()) {
-    if (auto tmpLayout = xegpu::getDistributeLayoutAttr(use)) {
-      // debug print the use and op, and the tmpLayout
-      LLVM_DEBUG({
-        DBGS() << "getLayoutFromUsePoints  use: " << use.getOwner()->getName()
-               << use.getOwner();
-        llvm::dbgs() << ", tmpLayout=" << tmpLayout << "\n";
-      });
-      // under debug mode, we want to check all the use points to make sure
-      // there is no conflict, so we do not break here. In release mode, we can
-      // break at the first use
-      if (!layout)
-        layout = tmpLayout;
-    }
-  }
-  return layout;
-}
-
-// For regular operations: First the result layouts are propagated from uses.
-// Then the result layouts are propagated to uses (operands).
-static void propagateResultsToRegularOperands(Operation *op) {
-  LLVM_DEBUG(DBGS() << "propagateResultsToRegularOperands: " << op->getName()
-                    << " (" << op->getNumOperands() << " operands, "
-                    << op->getNumResults() << " results)\n");
-
-  if (op->getNumResults() == 0) {
-    LLVM_DEBUG(DBGS() << "  skipping (no results)\n");
-    return;
-  }
-
-  Value result = op->getResult(0);
-  xegpu::DistributeLayoutAttr resLayout =
-      getLayoutFromUsePoints(op->getResult(0));
-  Type resultType = result.getType();
-
-  // recover layout for tensor Descriptor type, which is a special case since
-  // its layout is not stored as an attribute but encoded in the type itself.
-  // For vector type, we attach the layout as an attribute to op.
-  if (auto tensorDescTy = dyn_cast<xegpu::TensorDescType>(resultType)) {
-    auto layout = tensorDescTy.getLayoutAttr();
-    // TODO: remove the layout check. The tensorDescType's layout is treated as
-    // temporary layout, which needs to be set by layout recovery.
-    // allow it now to pass some legacy test case
-    if (!layout) {
-      auto typeWithLayout = xegpu::TensorDescType::get(
-          tensorDescTy.getContext(), tensorDescTy.getShape(),
-          tensorDescTy.getElementType(), tensorDescTy.getEncoding(), resLayout);
-      result.setType(typeWithLayout);
-    }
-  }
-
-  for (OpOperand &opr : op->getOpOperands()) {
-    // Layouts are needed for vector type only.
-    xegpu::DistributeLayoutAttr operandLayout =
-        xegpu::inferSourceLayoutFromResult(opr, resLayout);
-    if (!isa<VectorType>(opr.get().getType())) {
-      LLVM_DEBUG(DBGS() << "  operand #" << opr.getOperandNumber()
-                        << ": skipped (non-vector type: " << opr.get().getType()
-                        << ")\n");
-      continue;
-    }
-
-    xegpu::setTemporaryLayout(opr, operandLayout);
-    // debug print op
-    LLVM_DEBUG(DBGS() << "after propagateResultsToRegularOperands  op: "
-                      << op->getName() << op << "  operand #"
-                      << opr.getOperandNumber()
-                      << ": type=" << opr.get().getType());
-    llvm::dbgs() << ", temp Layout=" << xegpu::getTemporaryLayout(opr);
-    llvm::dbgs() << "\n";
-  }
-}
-
-static void propagateRegionResultsToYieldOperands(
-    mlir::RegionBranchTerminatorOpInterface yieldOp) {
-  LLVM_DEBUG(DBGS() << "propagateRegionResultsToYieldOperands: "
-                    << yieldOp->getName() << " (" << yieldOp->getNumOperands()
-                    << " operands), parent="
-                    << yieldOp->getParentOp()->getName() << "\n");
-
-  if (isa<func::FuncOp>(yieldOp->getParentOp())) {
-    LLVM_DEBUG(DBGS() << "  skipping (parent is FuncOp)\n");
-    return;
-  }
-
-  auto regionBranchOp =
-      dyn_cast<RegionBranchOpInterface>(yieldOp->getParentOp());
-  if (!regionBranchOp) {
-    LLVM_DEBUG(DBGS() << "  skipping (parent is not RegionBranchOp)\n");
-    return;
-  }
-
-  // Gather layouts for each result of the parent region op from external
-  // use points.
-  unsigned numResults = regionBranchOp->getNumResults();
-  LLVM_DEBUG(DBGS() << "  parent op has " << numResults << " results\n");
-  if (numResults == 0)
-    return;
-
-  SmallVector<xegpu::DistributeLayoutAttr> resultLayouts(numResults, nullptr);
-  for (unsigned i = 0; i < numResults; ++i) {
-    OpResult result = regionBranchOp->getResult(i);
-    resultLayouts[i] = getLayoutFromUsePoints(result);
-    if (resultLayouts[i]) {
-      LLVM_DEBUG(DBGS() << "  result #" << i << ": type=" << result.getType()
-                        << ", layout=" << resultLayouts[i] << "\n");
-      xegpu::setTemporaryLayout(result, resultLayouts[i]);
-    } else {
-      LLVM_DEBUG(DBGS() << "  result #" << i
-                        << ": skipped (no layout from use points)\n");
-    }
-  }
-
-  // Use getSuccessorOperands to find which operands of the terminator
-  // flow to a successor. This handles index offsets automatically (e.g.,
-  // scf.condition's predicate at operand #0 is excluded).
-  // Pick the first successor to determine the operand range.
-  SmallVector<RegionSuccessor> successors;
-  SmallVector<Attribute> operandAttrs(yieldOp->getNumOperands(), nullptr);
-  yieldOp.getSuccessorRegions(operandAttrs, successors);
-  assert(!successors.empty() && "terminator must have at least one successor");
-
-  OperandRange succOps = yieldOp.getSuccessorOperands(successors.front());
-  unsigned beginIdx = succOps.getBeginOperandIndex();
-  unsigned count = std::min(static_cast<unsigned>(succOps.size()), numResults);
-
-  LLVM_DEBUG(DBGS() << "  " << count << " successor operands starting at index "
-                    << beginIdx << "\n");
-
-  for (unsigned i = 0; i < count; ++i) {
-    if (!resultLayouts[i])
-      continue;
-    LLVM_DEBUG(DBGS() << "    -> setting layout on operand #" << (beginIdx + i)
-                      << "\n");
-    xegpu::setTemporaryLayout(yieldOp->getOpOperand(beginIdx + i),
-                              resultLayouts[i]);
-  }
-
-  LLVM_DEBUG({
-    DBGS() << " after propagateRegionResultsToYieldOperands:\n";
-    yieldOp->print(llvm::dbgs(), OpPrintingFlags().printGenericOpForm());
-    llvm::dbgs() << "\n";
-  });
-}
-
-static void propagateRegionArgsToInits(mlir::RegionBranchOpInterface regionOp) {
-  LLVM_DEBUG(DBGS() << "propagateRegionArgsToInits: " << regionOp->getName()
-                    << " (" << regionOp->getNumOperands() << " operands, "
-                    << regionOp->getNumRegions() << " regions)\n");
-  LLVM_DEBUG({
-    DBGS() << " before propagateRegionArgsToInits, Region IR:\n";
-    regionOp.print(llvm::dbgs(), OpPrintingFlags().printGenericOpForm());
-    llvm::dbgs() << "\n";
-  });
-
-  // Iterate all regions of the region op. For each block argument that has a
-  // layout (determined from its use points), trace back to find the
-  // corresponding init operand of the regionOp and set the layout on it.
-  // This works generically for scf.for, scf.while, and other
-  // RegionBranchOpInterface ops.
-  for (Region &region : regionOp->getRegions()) {
-    RegionSuccessor regionSuccessor(&region);
-    for (auto [argIdx, regionArg] : llvm::enumerate(region.getArguments())) {
-      auto layout = getLayoutFromUsePoints(regionArg);
+// Attach layout attributes to all vector-type operands of operations within
+// the given operation's region. Emits a warning (and skips the operand) if a
+// vector operand has no layout attribute.
+bool xegpu::recoverTemporaryLayouts(Operation *rootOp) {
+  auto result = rootOp->walk([&](Operation *op) {
+    for (OpOperand &operand : op->getOpOperands()) {
+      // Layouts are needed for vector type only.
+      if (!isa<VectorType>(operand.get().getType()))
+        continue;
+      // Skip block arguments since they don't have defining ops to attach
+      // layout attributes to.
+      if (isa<BlockArgument>(operand.get()))
+        continue;
+      auto layout = xegpu::getDistributeLayoutAttr(operand.get());
       if (!layout) {
-        LLVM_DEBUG(DBGS() << "  region #" << region.getRegionNumber()
-                          << " arg #" << argIdx << ": skipped (no layout)\n");
+        op->emitWarning("Could not find layout attribute for operand ")
+            << operand.getOperandNumber() << " of operation " << op->getName();
         continue;
       }
-      LLVM_DEBUG(DBGS() << "  region #" << region.getRegionNumber() << " arg #"
-                        << argIdx << ": type=" << regionArg.getType()
-                        << ", layout=" << layout << "\n");
-
-      // Find all predecessor values that flow into this block argument.
-      SmallVector<Value> predValues;
-      regionOp.getPredecessorValues(regionSuccessor, argIdx - 1, predValues);
-      for (Value predVal : predValues) {
-        // Match predecessor value to an operand of the regionOp.
-        for (OpOperand &operand : regionOp->getOpOperands()) {
-          if (operand.get() == predVal) {
-            LLVM_DEBUG(DBGS() << "    -> setting layout on init operand #"
-                              << operand.getOperandNumber() << "\n");
-            xegpu::setTemporaryLayout(operand, layout);
-          }
-        }
-      }
+      xegpu::setTemporaryLayout(operand, layout);
     }
-  }
-
-  LLVM_DEBUG({
-    DBGS() << " after propagateRegionArgsToInits, Region IR:\n";
-    regionOp.print(llvm::dbgs(), OpPrintingFlags().printGenericOpForm());
-    llvm::dbgs() << "\n";
-  });
-}
-
-bool xegpu::recoverTemporaryLayouts(Operation *rootOp) {
-  LLVM_DEBUG(DBGS() << "=== recoverTemporaryLayouts START ===\n");
-
-  auto processFunc = [&](Region &body, StringRef funcName) {
-    LLVM_DEBUG(DBGS() << "Processing func: " << funcName << "\n");
-    walkRegionBackward(body, [&](Operation *op) {
-      LLVM_DEBUG(DBGS() << "Visiting op: " << op->getName());
-      if (auto regionOp = dyn_cast<mlir::RegionBranchOpInterface>(op)) {
-        // hit the region op after visiting inside region
-        LLVM_DEBUG(DBGS() << "  -> dispatching as RegionBranchOp\n");
-        propagateRegionArgsToInits(regionOp);
-      } else if (auto yieldOp =
-                     dyn_cast<mlir::RegionBranchTerminatorOpInterface>(op)) {
-        // yield op inside region op
-        LLVM_DEBUG(DBGS() << "  -> dispatching as YieldOp\n");
-        propagateRegionResultsToYieldOperands(yieldOp);
-      } else if (!dyn_cast<xegpu::AnchorLayoutInterface>(op)) {
-        // if the op is regular op, calling propagateResultsToRegularOperands
-        LLVM_DEBUG(DBGS() << "  -> dispatching as regular op\n");
-        propagateResultsToRegularOperands(op);
-      }
-    });
-  };
-
-  rootOp->walk([&](func::FuncOp func) {
-    processFunc(func.getBody(), func.getSymName());
-  });
-  rootOp->walk([&](gpu::GPUFuncOp func) {
-    processFunc(func.getBody(), func.getName());
-  });
-
-  LLVM_DEBUG(DBGS() << "=== recoverTemporaryLayouts END ===\n");
-  // print the root op after
-  LLVM_DEBUG({
-    DBGS() << "After recoverTemporaryLayouts, IR:\n";
-    rootOp->print(llvm::dbgs(), OpPrintingFlags().printGenericOpForm());
-    llvm::dbgs() << "\n";
+    return WalkResult::advance();
   });
-  return true;
+  return !result.wasInterrupted();
 }
 
-// // Attach layout attributes to all vector-type operands of operations within
-// // the given operation's region. Reports an error if any vector operand lacks
-// // a layout attribute.
-// bool xegpu::recoverTemporaryLayouts(Operation *rootOp) {
-//   auto result = rootOp->walk([&](Operation *op) {
-//     for (OpOperand &operand : op->getOpOperands()) {
-//       // Layouts are needed for vector type only.
-//       if (!isa<VectorType>(operand.get().getType()))
-//         continue;
-//       // Skip block arguments since they don't have defining ops to attach
-//       // layout attributes to.
-//       if (isa<BlockArgument>(operand.get()))
-//         continue;
-//       auto layout = xegpu::getDistributeLayoutAttr(operand.get());
-//       if (!layout) {
-//         op->emitWarning("Could not find layout attribute for operand ")
-//             << operand.getOperandNumber() << " of operation " <<
-//             op->getName();
-//         xegpu::setTemporaryLayout(operand, layout);
-//         continue;
-//       }
-//     }
-//     return WalkResult::advance();
-//   });
-//   return !result.wasInterrupted();
-// }
-
 template <typename T, typename>
 void xegpu::removeLayoutAttr(const T &operandOrResult) {
   Operation *owner = operandOrResult.getOwner();
@@ -1403,153 +1108,99 @@ xegpu::setupDpasLayout(xegpu::LayoutKind layoutKind, VectorType aTy,
   return std::nullopt;
 }
 
-xegpu::DistributeLayoutAttr
-xegpu::inferSourceLayoutFromResult(OpOperand &operand,
-                                   xegpu::DistributeLayoutAttr resLayout) {
-  if (!resLayout) {
-    LLVM_DEBUG(DBGS() << "no resLayout, returning null\n");
-    return xegpu::DistributeLayoutAttr();
-  }
+xegpu::DistributeLayoutAttr xegpu::getConsumerLayoutAt(OpOperand &operand) {
   Operation *op = operand.getOwner();
   unsigned idx = operand.getOperandNumber();
+  xegpu::DistributeLayoutAttr resLayout;
+  if (op->getNumResults() == 1)
+    resLayout = xegpu::getDistributeLayoutAttr(op->getResult(0));
 
   // For vector::BroadcastOp, infer the source layout from the result layout.
   if (auto broadcast = dyn_cast<vector::BroadcastOp>(op)) {
-    LLVM_DEBUG(DBGS() << "  -> BroadcastOp\n");
+    if (!resLayout)
+      return xegpu::DistributeLayoutAttr();
     auto srcTy = dyn_cast<VectorType>(broadcast.getSourceType());
-    if (!srcTy) {
-      LLVM_DEBUG(DBGS() << "     source is not VectorType, returning null\n");
+    if (!srcTy)
       return xegpu::DistributeLayoutAttr();
-    }
-    auto inferred = xegpu::inferBroadcastSourceLayout(
+    return xegpu::inferBroadcastSourceLayout(
         resLayout, broadcast.getResultVectorType().getShape(),
         srcTy.getShape());
-    LLVM_DEBUG(DBGS() << "     inferred=" << inferred << "\n");
-    return inferred;
   }
 
   // For vector::MultiDimReductionOp, infer source layout from result layout
   // using reduction dims. Acc operand is expected to have the same layout as
   // the result.
   if (auto reduction = dyn_cast<vector::MultiDimReductionOp>(op)) {
-    LLVM_DEBUG(DBGS() << "  -> MultiDimReductionOp, operand idx=" << idx
-                      << "\n");
+    if (!resLayout)
+      return xegpu::DistributeLayoutAttr();
     if (idx == 0) {
       SmallVector<int64_t> reductionDims(reduction.getReductionDims());
-      LLVM_DEBUG({
-        DBGS() << "     reductionDims=[";
-        llvm::interleaveComma(reductionDims, llvm::dbgs());
-        llvm::dbgs() << "]\n";
-      });
-      auto inferred =
-          xegpu::inferMultiReductionSourceLayout(resLayout, reductionDims);
-      LLVM_DEBUG(DBGS() << "     inferred source layout=" << inferred << "\n");
-      return inferred;
+      return xegpu::inferMultiReductionSourceLayout(resLayout, reductionDims);
     }
-    if (idx == 1) {
-      LLVM_DEBUG(DBGS() << "     acc operand, using resLayout\n");
+    if (idx == 1)
       return resLayout;
-    }
   }
 
   if (auto reduction = dyn_cast<vector::ReductionOp>(op)) {
-    LLVM_DEBUG(DBGS() << "  -> ReductionOp\n");
-    auto inferred = xegpu::inferReductionSourceLayout(resLayout);
-    LLVM_DEBUG(DBGS() << "     inferred=" << inferred << "\n");
-    return inferred;
+    if (!resLayout)
+      return xegpu::DistributeLayoutAttr();
+    return xegpu::inferReductionSourceLayout(resLayout);
   }
 
   // For vector::BitCastOp, infer source layout from result layout using
   // element type bitwidths.
   if (auto bitcast = dyn_cast<vector::BitCastOp>(op)) {
-    LLVM_DEBUG(DBGS() << "  -> BitCastOp\n");
+    if (!resLayout)
+      return xegpu::DistributeLayoutAttr();
     int resElemBitWidth =
         bitcast.getResultVectorType().getElementType().getIntOrFloatBitWidth();
     int srcElemBitWidth =
         bitcast.getSourceVectorType().getElementType().getIntOrFloatBitWidth();
-    LLVM_DEBUG(DBGS() << "     resBitWidth=" << resElemBitWidth
-                      << ", srcBitWidth=" << srcElemBitWidth << "\n");
-    auto inferred = xegpu::inferBitCastSourceLayout(resLayout, resElemBitWidth,
-                                                    srcElemBitWidth);
-    LLVM_DEBUG(DBGS() << "     inferred=" << inferred << "\n");
-    return inferred;
+    return xegpu::inferBitCastSourceLayout(resLayout, resElemBitWidth,
+                                           srcElemBitWidth);
   }
 
   // For vector::ShapeCastOp, infer source layout from result layout using
   // shapes.
   if (auto shapeCast = dyn_cast<vector::ShapeCastOp>(op)) {
-    LLVM_DEBUG({
-      DBGS() << "  -> ShapeCastOp: resShape=[";
-      llvm::interleaveComma(shapeCast.getResultVectorType().getShape(),
-                            llvm::dbgs());
-      llvm::dbgs() << "], srcShape=[";
-      llvm::interleaveComma(shapeCast.getSourceVectorType().getShape(),
-                            llvm::dbgs());
-      llvm::dbgs() << "]\n";
-    });
-    auto inferred = xegpu::inferShapeCastSourceLayout(
+    if (!resLayout)
+      return xegpu::DistributeLayoutAttr();
+    return xegpu::inferShapeCastSourceLayout(
         resLayout, shapeCast.getResultVectorType().getShape(),
         shapeCast.getSourceVectorType().getShape());
-    LLVM_DEBUG(DBGS() << "     inferred=" << inferred << "\n");
-    return inferred;
   }
 
   // For vector::InsertStridedSliceOp, infer source layout from result layout.
   // Dest vector must have the same layout as the result.
   if (auto insertSlice = dyn_cast<vector::InsertStridedSliceOp>(op)) {
-    LLVM_DEBUG(DBGS() << "  -> InsertStridedSliceOp, operand idx=" << idx
-                      << "\n");
-    if (idx == 0) {
-      auto inferred = xegpu::inferInsertStridedSliceSourceLayout(
+    if (!resLayout)
+      return xegpu::DistributeLayoutAttr();
+    if (idx == 0)
+      return xegpu::inferInsertStridedSliceSourceLayout(
           resLayout, insertSlice.getDestVectorType().getShape(),
           insertSlice.getSourceVectorType().getShape());
-      LLVM_DEBUG(DBGS() << "     inferred source layout=" << inferred << "\n");
-      return inferred;
-    }
-    if (idx == 1) {
-      LLVM_DEBUG(DBGS() << "     dest operand, using resLayout\n");
+    if (idx == 1)
       return resLayout;
-    }
   }
 
   // For vector::TransposeOp, infer source layout from result layout using
   // permutation.
   if (auto transpose = dyn_cast<vector::TransposeOp>(op)) {
-    LLVM_DEBUG({
-      DBGS() << "  -> TransposeOp, perm=[";
-      llvm::interleaveComma(transpose.getPermutation(), llvm::dbgs());
-      llvm::dbgs() << "]\n";
-    });
-    auto inferred = xegpu::inferTransposeSourceLayout(
-        resLayout, transpose.getPermutation());
-    LLVM_DEBUG(DBGS() << "     inferred=" << inferred << "\n");
-    return inferred;
+    if (!resLayout)
+      return xegpu::DistributeLayoutAttr();
+    return xegpu::inferTransposeSourceLayout(resLayout,
+                                             transpose.getPermutation());
   }
 
   // For elementwise operations, all operands must have the same layout as the
   // result.
   if (OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() == 1) {
-    LLVM_DEBUG(DBGS() << "  -> elementwise op, using resLayout="
-                      << (resLayout ? resLayout : nullptr) << "\n");
-
+    if (!resLayout)
+      return xegpu::DistributeLayoutAttr();
     return resLayout;
   }
-  return xegpu::DistributeLayoutAttr();
-}
-
-xegpu::DistributeLayoutAttr xegpu::getConsumerLayoutAt(OpOperand &operand) {
-  Operation *op = operand.getOwner();
-  xegpu::DistributeLayoutAttr resLayout;
-  if (op->getNumResults() == 1)
-    resLayout = xegpu::getDistributeLayoutAttr(op->getResult(0));
-  auto inferredOperandLayout = inferSourceLayoutFromResult(operand, resLayout);
-  if (inferredOperandLayout)
-    return inferredOperandLayout;
+  // TODO: Handle more cases as needed here.
   // By default, assume no layout conflict and return the current layout of
   // the operand.
-  auto fallback = xegpu::getDistributeLayoutAttr(operand.get());
-  LLVM_DEBUG(DBGS() << "  -> fallback (unhandled op " << op->getName()
-                    << "), returning operand layout="
-                    << (fallback ? fallback : nullptr) << "\n");
-  return fallback;
+  return xegpu::getDistributeLayoutAttr(operand.get());
 }
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index f0ff771f4cbc4..4c30dacae8850 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -1338,7 +1338,7 @@ LogicalResult ResolveLayoutConflicts::run() {
     // as anchor op for the reduction op's layout.
     if (isa<vector::MultiDimReductionOp>(op) || isa<vector::ReductionOp>(op)) {
       for (OpResult result : op->getResults()) {
-        if (result.getType().isIntOrFloat() || result.use_empty()) {
+        if (result.getType().isIntOrFloat()) {
           auto res = assignResultLayout(result);
           if (failed(res)) {
             DBGS() << "Failed to resolve vector consumer for multi-reduction "
diff --git a/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir b/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
index 0d1bfd5480aa2..842c2375dd31d 100644
--- a/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
+++ b/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
@@ -473,6 +473,22 @@ gpu.func @vector_multi_reduction_dim0_distributed_dim1_reduction(%laneid: index)
   gpu.return
 }
 
+// CHECK-LABEL: gpu.func @vector_transpose
+// CHECK:         %[[SRC:.*]] = "some_op"()
+// CHECK:         %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[SRC]] : vector<16x2xf32> to vector<1x2xf32>
+// CHECK-NEXT:    %[[T:.*]] = vector.transpose %[[CAST]], [1, 0] : vector<1x2xf32> to vector<2x1xf32>
+// CHECK-NEXT:    gpu.return
+gpu.func @vector_transpose() {
+  %cst = "some_op"()
+    {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1], order = [0, 1]>}
+    : () -> (vector<16x2xf32>)
+  %transpose = vector.transpose %cst, [1, 0]
+    {
+      layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+    }
+    : vector<16x2xf32> to vector<2x16xf32>
+  gpu.return
+}
 
 // CHECK-LABEL: gpu.func @vector_bitcast
 // CHECK:         %[[SRC:.*]] = "some_op"()
@@ -1076,8 +1092,7 @@ gpu.module @xevm_module {
 gpu.func @vector_broadcast_2d_to_2d_noop(%laneid: index) {
   %0 = "some_op"() {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : () -> vector<16x1xf16>
   %1 = vector.broadcast %0 {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16x1xf16> to vector<16x16xf16>
-  %2 = xegpu.convert_layout %1 <{input_layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, target_layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}> : vector<16x16xf16>
-  "some_use"(%2) : (vector<16x16xf16>) -> ()
+  "some_use"(%1) : (vector<16x16xf16>) -> ()
   gpu.return
 }
 }

>From 1328c5ff1981598a4ed9ff102f1ac17360cbd6c4 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Fri, 3 Apr 2026 20:37:55 +0000
Subject: [PATCH 06/21] cleanup

---
 .../Transforms/XeGPUPeepHoleOptimizer.cpp     | 15 +---
 .../Transforms/XeGPUSubgroupDistribute.cpp    | 38 +---------
 mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp   | 74 +++----------------
 3 files changed, 12 insertions(+), 115 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPeepHoleOptimizer.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPeepHoleOptimizer.cpp
index c43eaba5b3ee6..c488bca363da6 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPeepHoleOptimizer.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPeepHoleOptimizer.cpp
@@ -28,7 +28,6 @@
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
-#include "llvm/Support/Debug.h"
 #include <optional>
 
 namespace mlir {
@@ -151,22 +150,12 @@ static xegpu::TensorDescType tryOptimize(xegpu::TensorDescType tdescType,
   auto laneLayoutI64 = origLayout.getEffectiveLaneLayoutAsInt();
   SmallVector<int32_t> laneLayoutI32(laneLayoutI64.begin(),
                                      laneLayoutI64.end());
-  LLVM_DEBUG({
-    DBGS() << "tryOptimize: origLayout=" << origLayout << "\n";
-    DBGS() << "  laneLayoutI32=[";
-    llvm::interleaveComma(laneLayoutI32, llvm::dbgs());
-    llvm::dbgs() << "], laneData=[1, 1]";
-    if (origLayout.getOrder())
-      llvm::dbgs() << ", order=" << origLayout.getOrder();
-    llvm::dbgs() << "\n";
-    DBGS() << "  supportedShape=[" << supportedHeight << ", " << supportedWidth
-           << "], newElemTy=" << newElemTy << ", arrayLen=" << arrayLen << "\n";
-  });
+
   xegpu::LayoutAttr newLayout = xegpu::LayoutAttr::get(
       ctx, /*lane_layout=*/DenseI32ArrayAttr::get(ctx, laneLayoutI32),
       /*lane_data=*/DenseI32ArrayAttr::get(ctx, {1, 1}),
       /*order=*/origLayout.getOrder());
-  LLVM_DEBUG(DBGS() << "  newLayout=" << newLayout << "\n");
+
   // Array length can not be larger than 1 for transpose case.
   return xegpu::TensorDescType::get(supportedShape, newElemTy, arrayLen,
                                     tdescType.getBoundaryCheck(),
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 27cf788933f18..d8ce24ddd5cb0 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -800,17 +800,10 @@ struct StoreDistribution final : public gpu::WarpDistributionPattern {
   using gpu::WarpDistributionPattern::WarpDistributionPattern;
   LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
-    LLVM_DEBUG(DBGS() << "StoreDistribution: attempting to match\n");
     Operation *lastNode = warpOp.getTerminator()->getPrevNode();
     auto storeScatterOp = dyn_cast_or_null<xegpu::StoreScatterOp>(lastNode);
-    if (!storeScatterOp) {
-      LLVM_DEBUG(
-          DBGS()
-          << "StoreDistribution: last node is not StoreScatterOp, skipping\n");
+    if (!storeScatterOp)
       return failure();
-    }
-    LLVM_DEBUG(DBGS() << "StoreDistribution: matched StoreScatterOp: "
-                      << *storeScatterOp << "\n");
     auto offsets = storeScatterOp.getOffsets();
     if (!offsets || !isa<VectorType>(offsets.getType()))
       return rewriter.notifyMatchFailure(
@@ -818,15 +811,10 @@ struct StoreDistribution final : public gpu::WarpDistributionPattern {
     VectorType offsetsTy = cast<VectorType>(offsets.getType());
     VectorType maskTy = cast<VectorType>(storeScatterOp.getMask().getType());
     VectorType storeVecTy = cast<VectorType>(storeScatterOp.getValueType());
-    LLVM_DEBUG(DBGS() << "StoreDistribution: offsetsTy=" << offsetsTy
-                      << ", maskTy=" << maskTy << ", storeVecTy=" << storeVecTy
-                      << "\n");
 
     // Add handling for leading unit dimensions support
     int chunkSize = storeScatterOp.getChunkSize().value_or(1);
     int effectiveVecRank = (chunkSize == 1) ? 1 : 2;
-    LLVM_DEBUG(DBGS() << "StoreDistribution: chunkSize=" << chunkSize
-                      << ", effectiveVecRank=" << effectiveVecRank << "\n");
 
     // Check that all leading dimensions are unit dimensions
     for (int i = 0; i < storeVecTy.getRank() - effectiveVecRank; i++) {
@@ -843,24 +831,6 @@ struct StoreDistribution final : public gpu::WarpDistributionPattern {
         xegpu::getTemporaryLayout(storeScatterOp->getOpOperand(2));
     auto layoutMask =
         xegpu::getTemporaryLayout(storeScatterOp->getOpOperand(3));
-    LLVM_DEBUG({
-      DBGS() << "StoreDistribution: layoutPayload=";
-      if (layoutPayload)
-        DBGS() << layoutPayload;
-      else
-        DBGS() << "(null)";
-      DBGS() << ", layoutOffsets=";
-      if (layoutOffsets)
-        DBGS() << layoutOffsets;
-      else
-        DBGS() << "(null)";
-      DBGS() << ", layoutMask=";
-      if (layoutMask)
-        DBGS() << layoutMask;
-      else
-        DBGS() << "(null)";
-      DBGS() << "\n";
-    });
 
     FailureOr<VectorType> distStoreVecByWarpOpOrFailure =
         getDistVecTypeBasedOnLaneLayout(layoutPayload, storeVecTy);
@@ -879,9 +849,6 @@ struct StoreDistribution final : public gpu::WarpDistributionPattern {
     VectorType distPayloadTy = distStoreVecByWarpOpOrFailure.value();
     VectorType distOffsetsTy = distOffsetsByWarpOpOrFailure.value();
     VectorType distMaskTy = distMaskByWarpOpOrFailure.value();
-    LLVM_DEBUG(DBGS() << "StoreDistribution: distPayloadTy=" << distPayloadTy
-                      << ", distOffsetsTy=" << distOffsetsTy
-                      << ", distMaskTy=" << distMaskTy << "\n");
 
     SmallVector<size_t> newRetIndices;
     SmallVector<Value> operands = storeScatterOp->getOperands();
@@ -918,10 +885,7 @@ struct StoreDistribution final : public gpu::WarpDistributionPattern {
         rewriter, newWarpOp.getLoc(), TypeRange{}, newStoreScatterOpOperands,
         storeScatterOp->getAttrs());
     xegpu::removeLayoutAttrs(newOp);
-    LLVM_DEBUG(DBGS() << "StoreDistribution: created new op: " << newOp
-                      << "\n");
     rewriter.eraseOp(storeScatterOp);
-    LLVM_DEBUG(DBGS() << "StoreDistribution: done\n");
     return success();
   }
 };
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index 55cf47e38dfd0..bcac517937754 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -23,14 +23,10 @@
 #include "mlir/Interfaces/LoopLikeInterface.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "llvm/Support/Casting.h"
-#include "llvm/Support/Debug.h"
 #include "llvm/Support/FormatVariadic.h"
 #include <cstdint>
 #include <numeric>
 
-#define DEBUG_TYPE "xegpu-utils"
-#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
-
 using namespace mlir;
 
 /// convert ArrayRef<ValueRange> into SmallVector<Value>
@@ -149,31 +145,19 @@ std::string xegpu::getTemporaryLayoutName(const OpResult result) {
 }
 
 xegpu::DistributeLayoutAttr xegpu::getDistributeLayoutAttr(const Value value) {
-  LLVM_DEBUG(DBGS() << "getDistributeLayoutAttr(Value): type="
-                    << value.getType() << "\n");
-  if (!value) {
-    LLVM_DEBUG(DBGS() << "  -> null value, returning nullptr\n");
+  if (!value)
     return nullptr;
-  }
 
   if (auto tdescTy =
-          dyn_cast_if_present<xegpu::TensorDescType>(value.getType())) {
-    auto layout = tdescTy.getLayoutAttr();
-    LLVM_DEBUG(DBGS() << "  -> TensorDescType, layout="
-                      << (layout ? layout : nullptr) << "\n");
-    return layout;
-  }
+          dyn_cast_if_present<xegpu::TensorDescType>(value.getType()))
+    return tdescTy.getLayoutAttr();
 
   if (auto result = dyn_cast<OpResult>(value)) {
     Operation *defOp = result.getDefiningOp();
     assert(defOp && "result must have a defining op");
-    LLVM_DEBUG(DBGS() << "  OpResult #" << result.getResultNumber() << " from "
-                      << defOp->getName() << "\n");
 
     if (auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(defOp)) {
       auto layout = anchorOp.getAnchorLayout();
-      LLVM_DEBUG(DBGS() << "  -> AnchorLayoutInterface, layout="
-                        << (layout ? layout : nullptr) << "\n");
       return layout;
     }
 
@@ -181,100 +165,60 @@ xegpu::DistributeLayoutAttr xegpu::getDistributeLayoutAttr(const Value value) {
     if (defOp->hasAttr(layoutName)) {
       auto layout =
           defOp->getAttrOfType<xegpu::DistributeLayoutAttr>(layoutName);
-      LLVM_DEBUG(DBGS() << "  -> temporary attr '" << layoutName
-                        << "', layout=" << layout << "\n");
       return layout;
     }
-    LLVM_DEBUG(DBGS() << "  -> OpResult: no layout found (checked '"
-                      << layoutName << "')\n");
   }
 
   if (auto arg = dyn_cast<BlockArgument>(value)) {
     auto *parentOp = arg.getOwner()->getParentOp();
-    LLVM_DEBUG(DBGS() << "  BlockArgument #" << arg.getArgNumber() << " of "
-                      << (parentOp ? parentOp->getName().getStringRef()
-                                   : StringRef("(null)"))
-                      << "\n");
     if (auto loop = dyn_cast_if_present<LoopLikeOpInterface>(parentOp)) {
       OpOperand *tiedInit = loop.getTiedLoopInit(arg);
       if (tiedInit) {
-        LLVM_DEBUG(DBGS() << "  -> LoopLikeOp, recursing into tiedInit "
-                          << "operand #" << tiedInit->getOperandNumber()
-                          << "\n");
         return getDistributeLayoutAttr(tiedInit->get());
       }
-      LLVM_DEBUG(DBGS() << "  -> LoopLikeOp, no tiedInit\n");
     }
   }
 
-  LLVM_DEBUG(DBGS() << "  -> returning nullptr\n");
   return nullptr;
 }
 xegpu::DistributeLayoutAttr
 xegpu::getDistributeLayoutAttr(const OpOperand &opr) {
   Operation *op = opr.getOwner();
   unsigned idx = const_cast<OpOperand &>(opr).getOperandNumber();
-  LLVM_DEBUG(DBGS() << "getDistributeLayoutAttr(OpOperand): operand #" << idx
-                    << " of " << op->getName()
-                    << ", type=" << opr.get().getType() << "\n");
 
   if (auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(op)) {
     if (auto dpasOp = dyn_cast<xegpu::DpasOp>(op)) {
       if (idx == 0) {
-        auto layout = dpasOp.getLayoutAAttr();
-        LLVM_DEBUG(DBGS() << "  -> DpasOp layoutA="
-                          << (layout ? layout : nullptr) << "\n");
-        return layout;
+        return dpasOp.getLayoutAAttr();
       } else if (idx == 1) {
-        auto layout = dpasOp.getLayoutBAttr();
-        LLVM_DEBUG(DBGS() << "  -> DpasOp layoutB="
-                          << (layout ? layout : nullptr) << "\n");
-        return layout;
+        return dpasOp.getLayoutBAttr();
       } else if (idx == 2) {
-        auto layout = dpasOp.getLayoutCdAttr();
-        LLVM_DEBUG(DBGS() << "  -> DpasOp layoutCd="
-                          << (layout ? layout : nullptr) << "\n");
-        return layout;
+        return dpasOp.getLayoutCdAttr();
       }
     }
     if (auto convertOp = dyn_cast<xegpu::ConvertLayoutOp>(op)) {
-      auto layout = convertOp.getInputLayoutAttr();
-      LLVM_DEBUG(DBGS() << "  -> ConvertLayoutOp inputLayout="
-                        << (layout ? layout : nullptr) << "\n");
-      return layout;
+      return convertOp.getInputLayoutAttr();
     }
     auto layout = anchorOp.getAnchorLayout();
 
-    if (idx == 0) {
-      LLVM_DEBUG(DBGS() << "  -> AnchorLayoutInterface idx=0, layout="
-                        << (layout ? layout : nullptr) << "\n");
+    if (idx == 0)
       return layout;
-    }
 
     // For store operations (StoreScatterOp, StoreNdOp, StoreMatrixOp),
     // the layout is valid for the first two operands: value and memref/tdesc.
     // For other operations, the layout applies to the first operand only.
     if (isa<xegpu::StoreScatterOp, xegpu::StoreNdOp, xegpu::StoreMatrixOp>(
             op) &&
-        (idx < 2)) {
-      LLVM_DEBUG(DBGS() << "  -> Store op idx=" << idx
-                        << ", layout=" << (layout ? layout : nullptr) << "\n");
+        (idx < 2))
       return layout;
-    }
-    LLVM_DEBUG(DBGS() << "  -> AnchorLayoutInterface idx=" << idx
-                      << " not covered, falling through\n");
   }
 
   std::string layoutName = xegpu::getTemporaryLayoutName(opr);
   if (op->hasAttr(layoutName)) {
     auto layout = op->getAttrOfType<xegpu::DistributeLayoutAttr>(layoutName);
-    LLVM_DEBUG(DBGS() << "  -> temporary attr '" << layoutName
-                      << "', layout=" << layout << "\n");
     return layout;
   }
 
-  LLVM_DEBUG(DBGS() << "  -> returning nullptr (checked '" << layoutName
-                    << "')\n");
   return nullptr;
 }
 

>From 2617a0258e5435f141292f85c5633a8574bdaebd Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Fri, 3 Apr 2026 20:41:17 +0000
Subject: [PATCH 07/21] cleanup

---
 mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index bcac517937754..f0508a30621f2 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -173,9 +173,8 @@ xegpu::DistributeLayoutAttr xegpu::getDistributeLayoutAttr(const Value value) {
     auto *parentOp = arg.getOwner()->getParentOp();
     if (auto loop = dyn_cast_if_present<LoopLikeOpInterface>(parentOp)) {
       OpOperand *tiedInit = loop.getTiedLoopInit(arg);
-      if (tiedInit) {
+      if (tiedInit)
         return getDistributeLayoutAttr(tiedInit->get());
-      }
     }
   }
 

>From 182711fa0be85a820dfd3d3e769f27071762014d Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Fri, 3 Apr 2026 20:51:43 +0000
Subject: [PATCH 08/21] change needed for recoverTemporaryLayout, only sg
 distribution tests fails

---
 .../XeGPU/Transforms/XeGPULayoutImpl.h        |   5 +-
 .../XeGPU/Transforms/XeGPULayoutImpl.cpp      | 457 +++++++++++++++---
 .../XeGPU/Transforms/XeGPUPropagateLayout.cpp |   2 +-
 .../XeGPU/sg-to-wi-experimental-unit.mlir     |   3 +-
 4 files changed, 410 insertions(+), 57 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h b/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h
index 9cf9a8705209b..5f46eab7b74c7 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h
@@ -183,10 +183,13 @@ setupDpasLayout(LayoutKind layoutKind, VectorType aTy, VectorType bTy,
                 VectorType cdTy, DistributeLayoutAttr consumerLayout, int numSg,
                 const uArch::uArch *uArch);
 
+DistributeLayoutAttr
+inferSourceLayoutFromResult(OpOperand &operand, DistributeLayoutAttr resLayout);
+
 /// Gets the expected layout for a given consumer operand. This will check if
 /// the owning operation of the consumer operand is one of the special layout
 /// users and determine the expected layout accordingly.
-xegpu::DistributeLayoutAttr getConsumerLayoutAt(OpOperand &operand);
+DistributeLayoutAttr getConsumerLayoutAt(OpOperand &operand);
 
 } // namespace xegpu
 
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
index 55cd6ec04970c..33c9086566d3c 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
@@ -18,16 +18,22 @@
 #include "mlir/Dialect/LLVMIR/XeVMDialect.h"
 #include "mlir/Dialect/SCF/Transforms/Patterns.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/ValueRange.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/LoopLikeInterface.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "llvm/Support/Debug.h"
 #include "llvm/Support/FormatVariadic.h"
 #include <cstdint>
 #include <numeric>
 
+#define DEBUG_TYPE "xegpu-layout-recovery"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+
 using namespace mlir;
 
 void xegpu::recoverTemporaryLayoutsDeprecated(Operation *op) {
@@ -80,32 +86,321 @@ xegpu::dropInstDataOnAttrs(ArrayRef<NamedAttribute> attrs) {
   return out;
 }
 
-// Attach layout attributes to all vector-type operands of operations within
-// the given operation's region. Reports an error if any vector operand lacks
-// a layout attribute.
-bool xegpu::recoverTemporaryLayouts(Operation *rootOp) {
-  auto result = rootOp->walk([&](Operation *op) {
-    for (OpOperand &operand : op->getOpOperands()) {
-      // Layouts are needed for vector type only.
-      if (!isa<VectorType>(operand.get().getType()))
-        continue;
-      // Skip block arguments since they don't have defining ops to attach
-      // layout attributes to.
-      if (isa<BlockArgument>(operand.get()))
-        continue;
-      auto layout = xegpu::getDistributeLayoutAttr(operand.get());
+// Prerequisite for Layout Recovery
+// It relies on the following invariant:
+// 1. there is no layout conflict between different uses of the same definition.
+// 2. each definition has a well-defined layout requirement at its use point.
+//     - Every definition must have at least one use that appears after it in
+//     topological order.
+//     - If a definition has no such use (e.g., a loop result or region output),
+//     an explicit convert_layout operation is inserted to create a use.
+//     - Only the result of convert_layout is permitted to have no subsequent
+//     use.
+
+// The recovery proceeds by scanning the operation in reverse topological order
+// as follows:
+//    For regular operations: First the result layouts are propagated from uses.
+//      Then the result layouts are propagated to operands.
+//
+//    For region operations (e.g., loops):
+//       - When backward propagation reaches a region op, it sets the layout of
+//       the region op’s results according to use points like regular ops.
+//       - Then, the result layouts (such as a loop output) are propagated to
+//       their corresponding operands in the yield.
+//       - When backward propagation reaches the first operation inside the
+//       region, the pass examines the region op’s initialization list,
+//       propagating from region arguments to the corresponding initialization
+//       operands.
+//       - This ensures that layouts are consistently propagated
+//       across region boundaries while preserving a single well-defined use for
+//       each definition at the region-op level.
+
+// The inner helper of recoverTemporaryLayouts is a recursive function.
+// The input rootOp is the function operation, which is also a region op.
+// It recursively processes the region ops in reverse topological order.
+
+static void walkRegionBackward(Region &region,
+                               llvm::function_ref<void(Operation *)> visit) {
+  // Blocks are visited last-to-first.
+  for (Block &block : llvm::reverse(region)) {
+    // ops: back -> front (plain reverse range: visit() should not erase ops)
+    for (Operation &op : llvm::reverse(block)) {
+      // Recurse into nested regions first: ops inside a region op (starting
+      // with its terminator) are visited before the region op itself.
+      for (Region &nested : llvm::reverse(op.getRegions()))
+        walkRegionBackward(nested, visit);
+
+      visit(&op);
+    }
+  }
+}
+
+static xegpu::DistributeLayoutAttr getLayoutFromUsePoints(Value result) {
+  xegpu::DistributeLayoutAttr layout = nullptr;
+  for (OpOperand &use : result.getUses()) {
+    if (auto tmpLayout = xegpu::getDistributeLayoutAttr(use)) {
+      // Log the using op and the layout found on this use.
+      LLVM_DEBUG({
+        DBGS() << "getLayoutFromUsePoints  use: " << use.getOwner()->getName()
+               << use.getOwner();
+        llvm::dbgs() << ", tmpLayout=" << tmpLayout << "\n";
+      });
+      // Every use is scanned (no early break) so each one shows up in the
+      // debug log, but the first non-null layout wins and later ones are not
+      // compared against it. TODO: diagnose conflicting use-point layouts.
+      if (!layout)
+        layout = tmpLayout;
+    }
+  }
+  return layout;
+}
+
+// Regular ops: take the layout of result #0 from its use points, then infer a
+// layout for each vector operand from it and attach it as a temporary layout.
+static void propagateResultsToRegularOperands(Operation *op) {
+  LLVM_DEBUG(DBGS() << "propagateResultsToRegularOperands: " << op->getName()
+                    << " (" << op->getNumOperands() << " operands, "
+                    << op->getNumResults() << " results)\n");
+
+  if (op->getNumResults() == 0) {
+    LLVM_DEBUG(DBGS() << "  skipping (no results)\n");
+    return;
+  }
+
+  Value result = op->getResult(0);
+  xegpu::DistributeLayoutAttr resLayout =
+      getLayoutFromUsePoints(op->getResult(0));
+  Type resultType = result.getType();
+
+  // recover layout for tensor Descriptor type, which is a special case since
+  // its layout is not stored as an attribute but encoded in the type itself.
+  // For vector type, we attach the layout as an attribute to op.
+  if (auto tensorDescTy = dyn_cast<xegpu::TensorDescType>(resultType)) {
+    auto layout = tensorDescTy.getLayoutAttr();
+    // TODO: remove the layout check. The tensorDescType's layout is treated as
+    // temporary layout, which needs to be set by layout recovery.
+    // allow it now to pass some legacy test case
+    if (!layout) {
+      auto typeWithLayout = xegpu::TensorDescType::get(
+          tensorDescTy.getContext(), tensorDescTy.getShape(),
+          tensorDescTy.getElementType(), tensorDescTy.getEncoding(), resLayout);
+      result.setType(typeWithLayout);
+    }
+  }
+
+  for (OpOperand &opr : op->getOpOperands()) {
+    // Layouts are needed for vector type only.
+    xegpu::DistributeLayoutAttr operandLayout =
+        xegpu::inferSourceLayoutFromResult(opr, resLayout);
+    if (!isa<VectorType>(opr.get().getType())) {
+      LLVM_DEBUG(DBGS() << "  operand #" << opr.getOperandNumber()
+                        << ": skipped (non-vector type: " << opr.get().getType()
+                        << ")\n");
+      continue;
+    }
+
+    xegpu::setTemporaryLayout(opr, operandLayout);
+    // Emit the whole message under LLVM_DEBUG; nothing is printed otherwise.
+    LLVM_DEBUG({
+      DBGS() << "after propagateResultsToRegularOperands  op: "
+             << op->getName() << op << "  operand #" << opr.getOperandNumber()
+             << ": type=" << opr.get().getType()
+             << ", temp Layout=" << xegpu::getTemporaryLayout(opr) << "\n";
+    });
+  }
+}
+
+static void propagateRegionResultsToYieldOperands(
+    mlir::RegionBranchTerminatorOpInterface yieldOp) {
+  LLVM_DEBUG(DBGS() << "propagateRegionResultsToYieldOperands: "
+                    << yieldOp->getName() << " (" << yieldOp->getNumOperands()
+                    << " operands), parent="
+                    << yieldOp->getParentOp()->getName() << "\n");
+
+  if (isa<func::FuncOp>(yieldOp->getParentOp())) {
+    LLVM_DEBUG(DBGS() << "  skipping (parent is FuncOp)\n");
+    return;
+  }
+
+  auto regionBranchOp =
+      dyn_cast<RegionBranchOpInterface>(yieldOp->getParentOp());
+  if (!regionBranchOp) {
+    LLVM_DEBUG(DBGS() << "  skipping (parent is not RegionBranchOp)\n");
+    return;
+  }
+
+  // Gather layouts for each result of the parent region op from external
+  // use points.
+  unsigned numResults = regionBranchOp->getNumResults();
+  LLVM_DEBUG(DBGS() << "  parent op has " << numResults << " results\n");
+  if (numResults == 0)
+    return;
+
+  SmallVector<xegpu::DistributeLayoutAttr> resultLayouts(numResults, nullptr);
+  for (unsigned i = 0; i < numResults; ++i) {
+    OpResult result = regionBranchOp->getResult(i);
+    resultLayouts[i] = getLayoutFromUsePoints(result);
+    if (resultLayouts[i]) {
+      LLVM_DEBUG(DBGS() << "  result #" << i << ": type=" << result.getType()
+                        << ", layout=" << resultLayouts[i] << "\n");
+      xegpu::setTemporaryLayout(result, resultLayouts[i]);
+    } else {
+      LLVM_DEBUG(DBGS() << "  result #" << i
+                        << ": skipped (no layout from use points)\n");
+    }
+  }
+
+  // Use getSuccessorOperands to find which operands of the terminator
+  // flow to a successor. This handles index offsets automatically (e.g.,
+  // scf.condition's predicate at operand #0 is excluded).
+  // Pick the first successor to determine the operand range.
+  SmallVector<RegionSuccessor> successors;
+  SmallVector<Attribute> operandAttrs(yieldOp->getNumOperands(), nullptr);
+  yieldOp.getSuccessorRegions(operandAttrs, successors);
+  assert(!successors.empty() && "terminator must have at least one successor");
+
+  OperandRange succOps = yieldOp.getSuccessorOperands(successors.front());
+  unsigned beginIdx = succOps.getBeginOperandIndex();
+  unsigned count = std::min(static_cast<unsigned>(succOps.size()), numResults);
+
+  LLVM_DEBUG(DBGS() << "  " << count << " successor operands starting at index "
+                    << beginIdx << "\n");
+
+  for (unsigned i = 0; i < count; ++i) {
+    if (!resultLayouts[i])
+      continue;
+    LLVM_DEBUG(DBGS() << "    -> setting layout on operand #" << (beginIdx + i)
+                      << "\n");
+    xegpu::setTemporaryLayout(yieldOp->getOpOperand(beginIdx + i),
+                              resultLayouts[i]);
+  }
+
+  LLVM_DEBUG({
+    DBGS() << " after propagateRegionResultsToYieldOperands:\n";
+    yieldOp->print(llvm::dbgs(), OpPrintingFlags().printGenericOpForm());
+    llvm::dbgs() << "\n";
+  });
+}
+
+static void propagateRegionArgsToInits(mlir::RegionBranchOpInterface regionOp) {
+  LLVM_DEBUG(DBGS() << "propagateRegionArgsToInits: " << regionOp->getName()
+                    << " (" << regionOp->getNumOperands() << " operands, "
+                    << regionOp->getNumRegions() << " regions)\n");
+  LLVM_DEBUG({
+    DBGS() << " before propagateRegionArgsToInits, Region IR:\n";
+    regionOp.print(llvm::dbgs(), OpPrintingFlags().printGenericOpForm());
+    llvm::dbgs() << "\n";
+  });
+
+  // Iterate all regions of the region op. For each block argument that has a
+  // layout (determined from its use points), trace back to the predecessor
+  // values and set the layout on every regionOp operand that matches one.
+  // NOTE(review): the `argIdx - 1` below assumes one leading non-forwarded
+  // block argument (e.g. scf.for's induction variable); confirm for scf.while.
+  for (Region &region : regionOp->getRegions()) {
+    RegionSuccessor regionSuccessor(&region);
+    for (auto [argIdx, regionArg] : llvm::enumerate(region.getArguments())) {
+      auto layout = getLayoutFromUsePoints(regionArg);
       if (!layout) {
-        op->emitWarning("Could not find layout attribute for operand ")
-            << operand.getOperandNumber() << " of operation " << op->getName();
+        LLVM_DEBUG(DBGS() << "  region #" << region.getRegionNumber()
+                          << " arg #" << argIdx << ": skipped (no layout)\n");
         continue;
       }
-      xegpu::setTemporaryLayout(operand, layout);
+      LLVM_DEBUG(DBGS() << "  region #" << region.getRegionNumber() << " arg #"
+                        << argIdx << ": type=" << regionArg.getType()
+                        << ", layout=" << layout << "\n");
+
+      // Find all predecessor values that flow into this block argument.
+      SmallVector<Value> predValues;
+      regionOp.getPredecessorValues(regionSuccessor, argIdx - 1, predValues);
+      for (Value predVal : predValues) {
+        // Match predecessor value to an operand of the regionOp.
+        for (OpOperand &operand : regionOp->getOpOperands()) {
+          if (operand.get() == predVal) {
+            LLVM_DEBUG(DBGS() << "    -> setting layout on init operand #"
+                              << operand.getOperandNumber() << "\n");
+            xegpu::setTemporaryLayout(operand, layout);
+          }
+        }
+      }
     }
-    return WalkResult::advance();
+  }
+
+  LLVM_DEBUG({
+    DBGS() << " after propagateRegionArgsToInits, Region IR:\n";
+    regionOp.print(llvm::dbgs(), OpPrintingFlags().printGenericOpForm());
+    llvm::dbgs() << "\n";
+  });
+}
+
+bool xegpu::recoverTemporaryLayouts(Operation *rootOp) {
+  LLVM_DEBUG(DBGS() << "=== recoverTemporaryLayouts START ===\n");
+
+  auto processFunc = [&](Region &body, StringRef funcName) {
+    LLVM_DEBUG(DBGS() << "Processing func: " << funcName << "\n");
+    walkRegionBackward(body, [&](Operation *op) {
+      LLVM_DEBUG(DBGS() << "Visiting op: " << op->getName());
+      if (auto regionOp = dyn_cast<mlir::RegionBranchOpInterface>(op)) {
+        // hit the region op after visiting inside region
+        LLVM_DEBUG(DBGS() << "  -> dispatching as RegionBranchOp\n");
+        propagateRegionArgsToInits(regionOp);
+      } else if (auto yieldOp =
+                     dyn_cast<mlir::RegionBranchTerminatorOpInterface>(op)) {
+        // yield op inside region op
+        LLVM_DEBUG(DBGS() << "  -> dispatching as YieldOp\n");
+        propagateRegionResultsToYieldOperands(yieldOp);
+      } else if (!dyn_cast<xegpu::AnchorLayoutInterface>(op)) {
+        // if the op is regular op, calling propagateResultsToRegularOperands
+        LLVM_DEBUG(DBGS() << "  -> dispatching as regular op\n");
+        propagateResultsToRegularOperands(op);
+      }
+    });
+  };
+
+  rootOp->walk([&](func::FuncOp func) {
+    processFunc(func.getBody(), func.getSymName());
+  });
+  rootOp->walk([&](gpu::GPUFuncOp func) {
+    processFunc(func.getBody(), func.getName());
+  });
+
+  LLVM_DEBUG(DBGS() << "=== recoverTemporaryLayouts END ===\n");
+  // print the root op after
+  LLVM_DEBUG({
+    DBGS() << "After recoverTemporaryLayouts, IR:\n";
+    rootOp->print(llvm::dbgs(), OpPrintingFlags().printGenericOpForm());
+    llvm::dbgs() << "\n";
   });
-  return !result.wasInterrupted();
+  return true;
 }
 
+// // Attach layout attributes to all vector-type operands of operations within
+// // the given operation's region. Reports an error if any vector operand lacks
+// // a layout attribute.
+// bool xegpu::recoverTemporaryLayouts(Operation *rootOp) {
+//   auto result = rootOp->walk([&](Operation *op) {
+//     for (OpOperand &operand : op->getOpOperands()) {
+//       // Layouts are needed for vector type only.
+//       if (!isa<VectorType>(operand.get().getType()))
+//         continue;
+//       // Skip block arguments since they don't have defining ops to attach
+//       // layout attributes to.
+//       if (isa<BlockArgument>(operand.get()))
+//         continue;
+//       auto layout = xegpu::getDistributeLayoutAttr(operand.get());
+//       if (!layout) {
+//         op->emitWarning("Could not find layout attribute for operand ")
+//             << operand.getOperandNumber() << " of operation " <<
+//             op->getName();
+//         xegpu::setTemporaryLayout(operand, layout);
+//         continue;
+//       }
+//     }
+//     return WalkResult::advance();
+//   });
+//   return !result.wasInterrupted();
+// }
+
 template <typename T, typename>
 void xegpu::removeLayoutAttr(const T &operandOrResult) {
   Operation *owner = operandOrResult.getOwner();
@@ -1108,99 +1403,153 @@ xegpu::setupDpasLayout(xegpu::LayoutKind layoutKind, VectorType aTy,
   return std::nullopt;
 }
 
-xegpu::DistributeLayoutAttr xegpu::getConsumerLayoutAt(OpOperand &operand) {
+xegpu::DistributeLayoutAttr
+xegpu::inferSourceLayoutFromResult(OpOperand &operand,
+                                   xegpu::DistributeLayoutAttr resLayout) {
+  if (!resLayout) {
+    LLVM_DEBUG(DBGS() << "no resLayout, returning null\n");
+    return xegpu::DistributeLayoutAttr();
+  }
   Operation *op = operand.getOwner();
   unsigned idx = operand.getOperandNumber();
-  xegpu::DistributeLayoutAttr resLayout;
-  if (op->getNumResults() == 1)
-    resLayout = xegpu::getDistributeLayoutAttr(op->getResult(0));
 
   // For vector::BroadcastOp, infer the source layout from the result layout.
   if (auto broadcast = dyn_cast<vector::BroadcastOp>(op)) {
-    if (!resLayout)
-      return xegpu::DistributeLayoutAttr();
+    LLVM_DEBUG(DBGS() << "  -> BroadcastOp\n");
     auto srcTy = dyn_cast<VectorType>(broadcast.getSourceType());
-    if (!srcTy)
+    if (!srcTy) {
+      LLVM_DEBUG(DBGS() << "     source is not VectorType, returning null\n");
       return xegpu::DistributeLayoutAttr();
-    return xegpu::inferBroadcastSourceLayout(
+    }
+    auto inferred = xegpu::inferBroadcastSourceLayout(
         resLayout, broadcast.getResultVectorType().getShape(),
         srcTy.getShape());
+    LLVM_DEBUG(DBGS() << "     inferred=" << inferred << "\n");
+    return inferred;
   }
 
   // For vector::MultiDimReductionOp, infer source layout from result layout
   // using reduction dims. Acc operand is expected to have the same layout as
   // the result.
   if (auto reduction = dyn_cast<vector::MultiDimReductionOp>(op)) {
-    if (!resLayout)
-      return xegpu::DistributeLayoutAttr();
+    LLVM_DEBUG(DBGS() << "  -> MultiDimReductionOp, operand idx=" << idx
+                      << "\n");
     if (idx == 0) {
       SmallVector<int64_t> reductionDims(reduction.getReductionDims());
-      return xegpu::inferMultiReductionSourceLayout(resLayout, reductionDims);
+      LLVM_DEBUG({
+        DBGS() << "     reductionDims=[";
+        llvm::interleaveComma(reductionDims, llvm::dbgs());
+        llvm::dbgs() << "]\n";
+      });
+      auto inferred =
+          xegpu::inferMultiReductionSourceLayout(resLayout, reductionDims);
+      LLVM_DEBUG(DBGS() << "     inferred source layout=" << inferred << "\n");
+      return inferred;
     }
-    if (idx == 1)
+    if (idx == 1) {
+      LLVM_DEBUG(DBGS() << "     acc operand, using resLayout\n");
       return resLayout;
+    }
   }
 
   if (auto reduction = dyn_cast<vector::ReductionOp>(op)) {
-    if (!resLayout)
-      return xegpu::DistributeLayoutAttr();
-    return xegpu::inferReductionSourceLayout(resLayout);
+    LLVM_DEBUG(DBGS() << "  -> ReductionOp\n");
+    auto inferred = xegpu::inferReductionSourceLayout(resLayout);
+    LLVM_DEBUG(DBGS() << "     inferred=" << inferred << "\n");
+    return inferred;
   }
 
   // For vector::BitCastOp, infer source layout from result layout using
   // element type bitwidths.
   if (auto bitcast = dyn_cast<vector::BitCastOp>(op)) {
-    if (!resLayout)
-      return xegpu::DistributeLayoutAttr();
+    LLVM_DEBUG(DBGS() << "  -> BitCastOp\n");
     int resElemBitWidth =
         bitcast.getResultVectorType().getElementType().getIntOrFloatBitWidth();
     int srcElemBitWidth =
         bitcast.getSourceVectorType().getElementType().getIntOrFloatBitWidth();
-    return xegpu::inferBitCastSourceLayout(resLayout, resElemBitWidth,
-                                           srcElemBitWidth);
+    LLVM_DEBUG(DBGS() << "     resBitWidth=" << resElemBitWidth
+                      << ", srcBitWidth=" << srcElemBitWidth << "\n");
+    auto inferred = xegpu::inferBitCastSourceLayout(resLayout, resElemBitWidth,
+                                                    srcElemBitWidth);
+    LLVM_DEBUG(DBGS() << "     inferred=" << inferred << "\n");
+    return inferred;
   }
 
   // For vector::ShapeCastOp, infer source layout from result layout using
   // shapes.
   if (auto shapeCast = dyn_cast<vector::ShapeCastOp>(op)) {
-    if (!resLayout)
-      return xegpu::DistributeLayoutAttr();
-    return xegpu::inferShapeCastSourceLayout(
+    LLVM_DEBUG({
+      DBGS() << "  -> ShapeCastOp: resShape=[";
+      llvm::interleaveComma(shapeCast.getResultVectorType().getShape(),
+                            llvm::dbgs());
+      llvm::dbgs() << "], srcShape=[";
+      llvm::interleaveComma(shapeCast.getSourceVectorType().getShape(),
+                            llvm::dbgs());
+      llvm::dbgs() << "]\n";
+    });
+    auto inferred = xegpu::inferShapeCastSourceLayout(
         resLayout, shapeCast.getResultVectorType().getShape(),
         shapeCast.getSourceVectorType().getShape());
+    LLVM_DEBUG(DBGS() << "     inferred=" << inferred << "\n");
+    return inferred;
   }
 
   // For vector::InsertStridedSliceOp, infer source layout from result layout.
   // Dest vector must have the same layout as the result.
   if (auto insertSlice = dyn_cast<vector::InsertStridedSliceOp>(op)) {
-    if (!resLayout)
-      return xegpu::DistributeLayoutAttr();
-    if (idx == 0)
-      return xegpu::inferInsertStridedSliceSourceLayout(
+    LLVM_DEBUG(DBGS() << "  -> InsertStridedSliceOp, operand idx=" << idx
+                      << "\n");
+    if (idx == 0) {
+      auto inferred = xegpu::inferInsertStridedSliceSourceLayout(
           resLayout, insertSlice.getDestVectorType().getShape(),
           insertSlice.getSourceVectorType().getShape());
-    if (idx == 1)
+      LLVM_DEBUG(DBGS() << "     inferred source layout=" << inferred << "\n");
+      return inferred;
+    }
+    if (idx == 1) {
+      LLVM_DEBUG(DBGS() << "     dest operand, using resLayout\n");
       return resLayout;
+    }
   }
 
   // For vector::TransposeOp, infer source layout from result layout using
   // permutation.
   if (auto transpose = dyn_cast<vector::TransposeOp>(op)) {
-    if (!resLayout)
-      return xegpu::DistributeLayoutAttr();
-    return xegpu::inferTransposeSourceLayout(resLayout,
-                                             transpose.getPermutation());
+    LLVM_DEBUG({
+      DBGS() << "  -> TransposeOp, perm=[";
+      llvm::interleaveComma(transpose.getPermutation(), llvm::dbgs());
+      llvm::dbgs() << "]\n";
+    });
+    auto inferred = xegpu::inferTransposeSourceLayout(
+        resLayout, transpose.getPermutation());
+    LLVM_DEBUG(DBGS() << "     inferred=" << inferred << "\n");
+    return inferred;
   }
 
   // For elementwise operations, all operands must have the same layout as the
   // result.
   if (OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() == 1) {
-    if (!resLayout)
-      return xegpu::DistributeLayoutAttr();
+    LLVM_DEBUG(DBGS() << "  -> elementwise op, using resLayout="
+                      << (resLayout ? resLayout : nullptr) << "\n");
+
     return resLayout;
   }
-  // TODO: Handle more cases as needed here.
+  return xegpu::DistributeLayoutAttr();
+}
+
+xegpu::DistributeLayoutAttr xegpu::getConsumerLayoutAt(OpOperand &operand) {
+  Operation *op = operand.getOwner();
+  xegpu::DistributeLayoutAttr resLayout;
+  if (op->getNumResults() == 1)
+    resLayout = xegpu::getDistributeLayoutAttr(op->getResult(0));
+  auto inferredOperandLayout = inferSourceLayoutFromResult(operand, resLayout);
+  if (inferredOperandLayout)
+    return inferredOperandLayout;
   // By default, assume no layout conflict and return the current layout of
   // the operand.
-  return xegpu::getDistributeLayoutAttr(operand.get());
+  auto fallback = xegpu::getDistributeLayoutAttr(operand.get());
+  LLVM_DEBUG(DBGS() << "  -> fallback (unhandled op " << op->getName()
+                    << "), returning operand layout="
+                    << (fallback ? fallback : nullptr) << "\n");
+  return fallback;
 }
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 4c30dacae8850..f0ff771f4cbc4 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -1338,7 +1338,7 @@ LogicalResult ResolveLayoutConflicts::run() {
     // as anchor op for the reduction op's layout.
     if (isa<vector::MultiDimReductionOp>(op) || isa<vector::ReductionOp>(op)) {
       for (OpResult result : op->getResults()) {
-        if (result.getType().isIntOrFloat()) {
+        if (result.getType().isIntOrFloat() || result.use_empty()) {
           auto res = assignResultLayout(result);
           if (failed(res)) {
             DBGS() << "Failed to resolve vector consumer for multi-reduction "
diff --git a/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir b/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
index 842c2375dd31d..ae0347d507e9c 100644
--- a/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
+++ b/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
@@ -1092,7 +1092,8 @@ gpu.module @xevm_module {
 gpu.func @vector_broadcast_2d_to_2d_noop(%laneid: index) {
   %0 = "some_op"() {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : () -> vector<16x1xf16>
   %1 = vector.broadcast %0 {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16x1xf16> to vector<16x16xf16>
-  "some_use"(%1) : (vector<16x16xf16>) -> ()
+  %2 = xegpu.convert_layout %1 <{input_layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, target_layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}> : vector<16x16xf16>
+  "some_use"(%2) : (vector<16x16xf16>) -> ()
   gpu.return
 }
 }

>From 6d4955dd6cc75162045db403e7822c07b5c4976e Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Wed, 8 Apr 2026 05:29:22 +0000
Subject: [PATCH 09/21] almost there - only sg-to-wi-experimental-unit.mlir
 fails

---
 .../XeGPU/Transforms/XeGPULayoutImpl.h        |  6 ++
 mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp    |  2 +-
 .../XeGPU/Transforms/XeGPULayoutImpl.cpp      | 18 ++++-
 .../XeGPU/Transforms/XeGPUPropagateLayout.cpp | 70 +++++++++++--------
 .../XeGPUSgToWiDistributeExperimental.cpp     | 10 +++
 .../Transforms/XeGPUSubgroupDistribute.cpp    | 50 +++++++++++--
 .../Transforms/XeGPUWgToSgDistribute.cpp      |  3 -
 mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp   | 25 +++++--
 .../XeGPU/sg-to-wi-experimental-unit.mlir     |  5 ++
 .../Dialect/XeGPU/sg-to-wi-experimental.mlir  | 30 +++-----
 .../XeGPU/subgroup-distribute-unit.mlir       | 69 +++++-------------
 .../Dialect/XeGPU/subgroup-distribute.mlir    | 19 ++---
 .../XeGPU/xegpu-wg-to-sg-unify-ops.mlir       |  2 +-
 13 files changed, 178 insertions(+), 131 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h b/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h
index 5f46eab7b74c7..c36201c2f0d9e 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h
@@ -112,6 +112,12 @@ inferInsertStridedSliceSourceLayout(DistributeLayoutAttr resLayout,
                                     ArrayRef<int64_t> resShape,
                                     ArrayRef<int64_t> srcShape);
 
+/// Infers the layout attribute for mask and offset operand for Chunked load
+/// and store, given the anchor layout attribute for the value being load/store.
+DistributeLayoutAttr
+inferMaskOffsetLayoutForScatterIO(DistributeLayoutAttr payloadLayout,
+                                  int chunkSize);
+
 /// Sets up layout for Multi-Reduction operations by creating a SliceAttr for
 /// the result.
 ///
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 64c56b5adf5d7..eaa43c02946d8 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -585,7 +585,7 @@ DistributeLayoutAttr LayoutAttr::dropDims(SmallVector<int64_t> dimGroup) {
     int64_t offset = llvm::count_if(dimGroup, [&](int64_t s) { return s < d; });
     newOrder.push_back(d - offset);
   }
-  if (sgLayout.empty() && laneLayout.empty())
+  if ((sgLayout.empty() && laneLayout.empty()) || newOrder.size() == 1)
     newOrder.clear();
 
   auto toAttr = [&](ArrayRef<int64_t> v) -> DenseI32ArrayAttr {
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
index 29af3d3f1d95a..d5c5e05549745 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
@@ -167,9 +167,8 @@ static void propagateResultsToRegularOperands(Operation *op) {
     return;
   }
 
-  Value result = op->getResult(0);
-  xegpu::DistributeLayoutAttr resLayout =
-      getLayoutFromUsePoints(op->getResult(0));
+  OpResult result = op->getResult(0);
+  xegpu::DistributeLayoutAttr resLayout = getLayoutFromUsePoints(result);
   Type resultType = result.getType();
 
   // recover layout for tensor Descriptor type, which is a special case since
@@ -188,6 +187,8 @@ static void propagateResultsToRegularOperands(Operation *op) {
     }
   }
 
+  xegpu::setTemporaryLayout(result, resLayout);
+
   for (OpOperand &opr : op->getOpOperands()) {
     // Layouts are needed for vector type only.
     xegpu::DistributeLayoutAttr operandLayout =
@@ -676,6 +677,17 @@ xegpu::inferShapeCastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
   return nullptr;
 }
 
+/// Infers the layout attribute for mask and offset operand for Chunked load
+/// and store, given the anchor layout attribute for the value being load/store.
+xegpu::DistributeLayoutAttr xegpu::inferMaskOffsetLayoutForScatterIO(
+    xegpu::DistributeLayoutAttr payloadLayout, int chunkSize) {
+  auto rank = payloadLayout.getRank();
+  if (chunkSize > 1)
+    return payloadLayout.dropDims(
+        llvm::to_vector(llvm::seq<int64_t>(rank - 1, rank)));
+  return payloadLayout;
+}
+
 /// Sets up layout for reduction operations by creating a SliceAttr for the
 /// result.
 ///
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index f0ff771f4cbc4..8abb1fc4bae52 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -1027,7 +1027,7 @@ void LayoutInfoPropagation::visitLoadGatherOp(
   const uArch *uArch = getUArch(getChipStr(load).value_or(""));
   if (!uArch)
     return;
-  auto subgroupSize = uArch->getSubgroupSize();
+  // auto subgroupSize = uArch->getSubgroupSize();
   VectorType resVecTy = load.getValueType();
   int chunkSize = load.getChunkSize().value_or(1);
 
@@ -1049,20 +1049,24 @@ void LayoutInfoPropagation::visitLoadGatherOp(
     load.setLayoutAttr(requiredAnchorLayoutAttr);
   }
 
-  auto maskLayoutAttr = requiredAnchorLayoutAttr;
-  // Special handling mask layout for chunked ops: Enforce the default xegpu 1D
-  // layout for mask.
-  if (chunkSize > 1) {
-    if (layoutKind == xegpu::LayoutKind::InstData)
-      maskLayoutAttr =
-          xegpu::LayoutAttr::get(load->getContext(), {subgroupSize});
-    else if (layoutKind == xegpu::LayoutKind::Lane)
-      maskLayoutAttr =
-          xegpu::LayoutAttr::get(load->getContext(), {subgroupSize}, {1});
-    else
-      assert(false &&
-             "chunked StoreScatterOp should not be used at workgroup level");
-  }
+  assert((chunkSize <= 1) || (layoutKind != xegpu::LayoutKind::Subgroup));
+  auto maskLayoutAttr = xegpu::inferMaskOffsetLayoutForScatterIO(
+      requiredAnchorLayoutAttr, chunkSize);
+
+  // // Special handling mask layout for chunked ops: Enforce the default xegpu
+  // 1D
+  // // layout for mask.
+  // if (chunkSize > 1) {
+  //   if (layoutKind == xegpu::LayoutKind::InstData)
+  //     maskLayoutAttr =
+  //         xegpu::LayoutAttr::get(load->getContext(), {subgroupSize});
+  //   else if (layoutKind == xegpu::LayoutKind::Lane)
+  //     maskLayoutAttr =
+  //         xegpu::LayoutAttr::get(load->getContext(), {subgroupSize}, {1});
+  //   else
+  //     assert(false &&
+  //            "chunked StoreScatterOp should not be used at workgroup level");
+  // }
 
   LayoutInfo maskLayoutInfo = LayoutInfo(maskLayoutAttr);
   auto loadLayoutInfo = LayoutInfo(requiredAnchorLayoutAttr);
@@ -1105,7 +1109,7 @@ void LayoutInfoPropagation::visitStoreScatterOp(
   const uArch *uArch = getUArch(getChipStr(storeScatter).value_or(""));
   if (!uArch)
     return;
-  auto subgroupSize = uArch->getSubgroupSize();
+  // auto subgroupSize = uArch->getSubgroupSize();
   VectorType srcVecTy = storeScatter.getValueType();
   int chunkSize = storeScatter.getChunkSize().value_or(1);
 
@@ -1122,23 +1126,27 @@ void LayoutInfoPropagation::visitStoreScatterOp(
   }
 
   LayoutInfo srcLayoutInfo = LayoutInfo(requiredAnchorLayoutAttr);
-  auto maskLayoutAttr = requiredAnchorLayoutAttr;
-  // Special handling mask layout for chunked ops: Enforce the default xegpu 1D
-  // layout for mask.
-  if (chunkSize > 1) {
-    if (layoutKind == xegpu::LayoutKind::InstData)
-      maskLayoutAttr =
-          xegpu::LayoutAttr::get(storeScatter->getContext(), {subgroupSize});
-    else if (layoutKind == xegpu::LayoutKind::Lane)
-      maskLayoutAttr = xegpu::LayoutAttr::get(storeScatter->getContext(),
-                                              {subgroupSize}, {1});
-    else
-      assert(false &&
-             "chunked StoreScatterOp should not be used at workgroup level");
-  }
-
+  assert((chunkSize <= 1) || (layoutKind != xegpu::LayoutKind::Subgroup));
+  auto maskLayoutAttr = xegpu::inferMaskOffsetLayoutForScatterIO(
+      requiredAnchorLayoutAttr, chunkSize);
   LayoutInfo maskLayoutInfo = LayoutInfo(maskLayoutAttr);
 
+  // auto maskLayoutAttr = requiredAnchorLayoutAttr;
+  // // Special handling mask layout for chunked ops:
+  // // Enforce the default xegpu 1D
+  // // layout for mask.
+  // if (chunkSize > 1) {
+  //   if (layoutKind == xegpu::LayoutKind::InstData)
+  //     maskLayoutAttr =
+  //         xegpu::LayoutAttr::get(storeScatter->getContext(), {subgroupSize});
+  //   else if (layoutKind == xegpu::LayoutKind::Lane)
+  //     maskLayoutAttr = xegpu::LayoutAttr::get(storeScatter->getContext(),
+  //                                             {subgroupSize}, {1});
+  //   else
+  //     assert(false &&
+  //            "chunked StoreScatterOp should not be used at workgroup level");
+  // }
+
   // Propagate the payload operand layout
   propagateIfChanged(operands[0], operands[0]->meet(srcLayoutInfo));
   // Propagate the destination (if tdesc) operand layout
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
index e3227c7f5b149..c029ee1d8ae0d 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
@@ -1610,6 +1610,16 @@ void XeGPUSgToWiDistributeExperimentalPass::runOnOperation() {
       }
     });
   }
+  // Remove discardable layout attributes from all ops, not only SCF ops.
+  getOperation()->walk([](Operation *op) {
+    SmallVector<StringAttr> attrsToRemove;
+    for (auto namedAttr : op->getDiscardableAttrs()) {
+      if (isa<xegpu::DistributeLayoutAttr>(namedAttr.getValue()))
+        attrsToRemove.push_back(namedAttr.getName());
+    }
+    for (auto attrName : attrsToRemove)
+      op->removeDiscardableAttr(attrName);
+  });
 }
 
 void xegpu::populateXeGPUSgToWiDistributeTypeConversions(
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index d8ce24ddd5cb0..012d7aefafb06 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -825,12 +825,10 @@ struct StoreDistribution final : public gpu::WarpDistributionPattern {
       }
     }
 
-    auto layoutPayload =
-        xegpu::getTemporaryLayout(storeScatterOp->getOpOperand(0));
+    auto layoutPayload = storeScatterOp.getLayoutAttr();
     auto layoutOffsets =
-        xegpu::getTemporaryLayout(storeScatterOp->getOpOperand(2));
-    auto layoutMask =
-        xegpu::getTemporaryLayout(storeScatterOp->getOpOperand(3));
+        xegpu::inferMaskOffsetLayoutForScatterIO(layoutPayload, chunkSize);
+    auto layoutMask = layoutOffsets;
 
     FailureOr<VectorType> distStoreVecByWarpOpOrFailure =
         getDistVecTypeBasedOnLaneLayout(layoutPayload, storeVecTy);
@@ -1132,9 +1130,36 @@ struct LoadDistribution final : public gpu::WarpDistributionPattern {
       }
     }
 
+    auto layoutPayload = loadGatherOp.getLayoutAttr();
     auto layoutOffsets =
-        xegpu::getTemporaryLayout(loadGatherOp->getOpOperand(1));
-    auto layoutMask = xegpu::getTemporaryLayout(loadGatherOp->getOpOperand(2));
+        xegpu::inferMaskOffsetLayoutForScatterIO(layoutPayload, chunkSize);
+    auto layoutMask = layoutOffsets;
+
+    // print the layouts for debug
+    LLVM_DEBUG({
+      llvm::dbgs() << "In LoadDistribution pattern:\n";
+      llvm::dbgs() << "Payload layout: ";
+      if (layoutPayload)
+        llvm::dbgs() << layoutPayload;
+      else
+        llvm::dbgs() << "none";
+      llvm::dbgs() << "\nOffsets layout: ";
+      if (layoutOffsets)
+        llvm::dbgs() << layoutOffsets;
+      else
+        llvm::dbgs() << "none";
+      llvm::dbgs() << "\nMask layout: ";
+      if (layoutMask)
+        llvm::dbgs() << layoutMask;
+      else
+        llvm::dbgs() << "none";
+      llvm::dbgs() << "\n";
+    });
+
+    // auto layoutOffsets =
+    //     xegpu::getTemporaryLayout(loadGatherOp->getOpOperand(1));
+    // auto layoutMask =
+    // xegpu::getTemporaryLayout(loadGatherOp->getOpOperand(2));
 
     FailureOr<VectorType> distOffsetsByWarpOpOrFailure =
         getDistVecTypeBasedOnLaneLayout(layoutOffsets, offsetsTy);
@@ -2281,4 +2306,15 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
       op->erase();
     return WalkResult::advance();
   });
+
+  // Remove discardable layout attributes from all ops, not only SCF ops.
+  getOperation()->walk([](Operation *op) {
+    SmallVector<StringAttr> attrsToRemove;
+    for (auto namedAttr : op->getDiscardableAttrs()) {
+      if (isa<xegpu::DistributeLayoutAttr>(namedAttr.getValue()))
+        attrsToRemove.push_back(namedAttr.getName());
+    }
+    for (auto attrName : attrsToRemove)
+      op->removeDiscardableAttr(attrName);
+  });
 }
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index a095c19d66c15..02f88828f667f 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -1786,9 +1786,6 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
 
   // Remove layout attributes from SCF ops
   getOperation()->walk([](Operation *op) {
-    if (!isa<RegionBranchOpInterface, RegionBranchTerminatorOpInterface>(op))
-      return;
-
     SmallVector<StringAttr> attrsToRemove;
     for (auto namedAttr : op->getDiscardableAttrs()) {
       if (isa<xegpu::DistributeLayoutAttr>(namedAttr.getValue()))
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index f0508a30621f2..0dab06d206ceb 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/SCF/Transforms/Patterns.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h"
 #include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/Operation.h"
@@ -203,13 +204,27 @@ xegpu::getDistributeLayoutAttr(const OpOperand &opr) {
     if (idx == 0)
       return layout;
 
-    // For store operations (StoreScatterOp, StoreNdOp, StoreMatrixOp),
+    // For StoreNdOp and StoreMatrixOp,
     // the layout is valid for the first two operands: value and memref/tdesc.
-    // For other operations, the layout applies to the first operand only.
-    if (isa<xegpu::StoreScatterOp, xegpu::StoreNdOp, xegpu::StoreMatrixOp>(
-            op) &&
-        (idx < 2))
+    if (isa<xegpu::StoreNdOp, xegpu::StoreMatrixOp>(op) && (idx < 2))
+      return layout;
+
+    if (isa<xegpu::StoreScatterOp>(op)) {
+      xegpu::StoreScatterOp store(op);
+      int chunkSize = store.getChunkSize().value_or(1);
+      if (layout && idx >= 2 && chunkSize > 1)
+        return layout.dropDims(llvm::to_vector(
+            llvm::seq<int64_t>(layout.getRank() - 1, layout.getRank())));
+      return layout;
+    }
+    if (isa<xegpu::LoadGatherOp>(op)) {
+      xegpu::LoadGatherOp load(op);
+      int chunkSize = load.getChunkSize().value_or(1);
+      if (layout && idx >= 1 && chunkSize > 1)
+        return layout.dropDims(llvm::to_vector(
+            llvm::seq<int64_t>(layout.getRank() - 1, layout.getRank())));
       return layout;
+    }
   }
 
   std::string layoutName = xegpu::getTemporaryLayoutName(opr);
diff --git a/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir b/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
index c7b5542f91cbd..d53d0463db4cb 100644
--- a/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
+++ b/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
@@ -487,6 +487,11 @@ gpu.func @vector_transpose() {
       layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
     }
     : vector<16x2xf32> to vector<2x16xf32>
+  %transpose2 = xegpu.convert_layout %transpose
+    <{
+      input_layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+      target_layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+    }> : vector<2x16xf32>
   gpu.return
 }
 
diff --git a/mlir/test/Dialect/XeGPU/sg-to-wi-experimental.mlir b/mlir/test/Dialect/XeGPU/sg-to-wi-experimental.mlir
index 9febd79c7adc3..723c70a09931e 100644
--- a/mlir/test/Dialect/XeGPU/sg-to-wi-experimental.mlir
+++ b/mlir/test/Dialect/XeGPU/sg-to-wi-experimental.mlir
@@ -235,28 +235,24 @@ gpu.func @gemm_with_postop(%arg0: memref<1024x1024xbf16>, %arg1: memref<1024x102
 gpu.module @xevm_module{
   gpu.func @load_dpas_postop_store(%arg0: memref<8x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<8x16xf32>) {
     %c0 = arith.constant 0 : index
-    %cst = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} dense<0.0> : vector<8x16xf32>
+    %cst = arith.constant dense<0.0> : vector<8x16xf32>
     %0 = xegpu.create_nd_tdesc %arg0 : memref<8x16xf16>
       -> !xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
     %1 = xegpu.load_nd %0[%c0, %c0]
-      {layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
-       layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} :
+      {layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} :
       !xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<8x16xf16>
     %2 = xegpu.create_nd_tdesc %arg1: memref<16x16xf16>
       -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>>
     %3 = xegpu.load_nd %2[%c0, %c0]
-      {layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>,
-       layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>}
+      {layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>}
       : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>>
       -> vector<16x16xf16>
     %4 = xegpu.dpas %1, %3, %cst
       {layout_a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
        layout_b = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>,
-       layout_cd = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
-       layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+       layout_cd = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
       : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
     %5 = math.exp %4
-      {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
       : vector<8x16xf32>
     %6 = xegpu.create_nd_tdesc %arg2 : memref<8x16xf32> ->
       !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
@@ -279,7 +275,7 @@ gpu.module @xevm_module{
 // CHECK-NEXT:        scf.yield %[[LD_CAST]] : vector<1x8xf16>
 // CHECK-NEXT:      } else {
 // CHECK-NEXT:        scf.yield %[[CST]] : vector<1x8xf16>
-// CHECK-NEXT:      } {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>}
+// CHECK-NEXT:      }
 // CHECK-NEXT:      %[[IF_CAST:.*]] = vector.shape_cast %[[IF]] : vector<1x8xf16> to vector<8xf16>
 // CHECK-NEXT:      xegpu.store %[[IF_CAST]], %{{.*}}[%[[OFFSET]]], %[[MASK]] <{chunk_size = 8 : i64}>
 // CHECK-SAME:        vector<8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
@@ -289,16 +285,13 @@ gpu.module @xevm_module{
     %offset = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<12> : vector<16xindex>
     %loaded = scf.if %pred -> (vector<16x8xf16>) {
       %3 = xegpu.load %src[%offset], %1 <{chunk_size=8}> {
-        layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>,
-        layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>
+        layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>
       } : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
       scf.yield %3 : vector<16x8xf16>
     } else {
-      %3 = arith.constant {
-        layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>
-      } dense<12.> : vector<16x8xf16>
+      %3 = arith.constant dense<12.> : vector<16x8xf16>
       scf.yield %3 : vector<16x8xf16>
-    } { layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]> }
+    }
     xegpu.store %loaded, %src[%offset], %1 <{chunk_size=8}> {layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>} : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
     gpu.return
   }
@@ -318,12 +311,11 @@ gpu.module @xevm_module{
 gpu.module @xevm_module{
   gpu.func @scatter_ops_scf_non_yield(%src: memref<256xf16>) {
     %pred = llvm.mlir.poison : i1
-    %1 = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<1>: vector<16xi1>
-    %offset = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<12> : vector<16xindex>
+    %1 = arith.constant dense<1>: vector<16xi1>
+    %offset = arith.constant dense<12> : vector<16xindex>
     scf.if %pred  {
       %3 = xegpu.load %src[%offset], %1 <{chunk_size=8}> {
-        layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>,
-        layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>
+        layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>
       } : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
       xegpu.store %3, %src[%offset], %1 <{chunk_size=8}> {layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>} : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
     }
diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir
index 18ebc09caa5aa..27c5bd497b948 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir
@@ -469,10 +469,9 @@ gpu.func @vector_multi_reduction_3d_trivial_reduction(%laneid: index) {
   gpu.return
 }
 
-
 // CHECK-LABEL: gpu.func @scatter_ops_chunksize({{.*}}) {
-// CHECK:       %[[OFFSETS:.*]] = arith.constant {{.*}} dense<12> : vector<16xindex>
-// CHECK:       %[[MASKS:.*]] = arith.constant {{.*}} dense<true> : vector<16xi1>
+// CHECK:       %[[OFFSETS:.*]] = arith.constant dense<12> : vector<16xindex>
+// CHECK:       %[[MASKS:.*]] = arith.constant dense<true> : vector<16xi1>
 // CHECK:       %[[W:.*]]:4 = gpu.warp_execute_on_lane_0(%{{.*}})[16]
 // CHECK-SAME:    -> (vector<1x8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>) {
 // CHECK:         gpu.yield %{{.*}}, %{{.*}}, %[[OFFSETS]], %[[MASKS]] :
@@ -484,34 +483,19 @@ gpu.func @vector_multi_reduction_3d_trivial_reduction(%laneid: index) {
 // CHECK-SAME:    : vector<8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
 gpu.func @scatter_ops_chunksize(%laneid: index, %src: memref<256xf16>) {
   gpu.warp_execute_on_lane_0(%laneid)[16] {
-    %1 = arith.constant
-      {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
-      dense<1>: vector<16xi1>
-    %offset = arith.constant
-      {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
-      dense<12> : vector<16xindex>
-    %3 = xegpu.load %src[%offset], %1 <{chunk_size=8}>
-      {
-        layout_operand_1 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
-        layout_operand_2 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
-        layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>
-      }
+    %1 = arith.constant dense<1>: vector<16xi1>
+    %offset = arith.constant dense<12> : vector<16xindex>
+    %3 = xegpu.load %src[%offset], %1 <{chunk_size=8, layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>}>
       : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
-    xegpu.store %3, %src[%offset], %1 <{chunk_size=8}>
-      {
-        layout_operand_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>,
-        layout_operand_2 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
-        layout_operand_3 = #xegpu.layout<lane_layout = [16], lane_data = [1]>
-      }
+    xegpu.store %3, %src[%offset], %1 <{chunk_size=8, layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>}>
       : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
   }
   gpu.return
 }
 
-
 // CHECK-LABEL: gpu.func @scatter_ops({{.*}}) {
-// CHECK:       %[[OFFSETS:.*]] = arith.constant {{.*}} dense<12> : vector<16xindex>
-// CHECK:       %[[MASKS:.*]] = arith.constant {{.*}} dense<true> : vector<16xi1>
+// CHECK:       %[[OFFSETS:.*]] = arith.constant dense<12> : vector<16xindex>
+// CHECK:       %[[MASKS:.*]] = arith.constant dense<true> : vector<16xi1>
 // CHECK:       %[[W:.*]]:4 = gpu.warp_execute_on_lane_0(%{{.*}})[16]
 // CHECK-SAME:    -> (vector<1xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>) {
 // CHECK:         gpu.yield %{{.*}}, %{{.*}}, %[[OFFSETS]], %[[MASKS]]
@@ -523,23 +507,15 @@ gpu.func @scatter_ops_chunksize(%laneid: index, %src: memref<256xf16>) {
 // CHECK-SAME:    : vector<1xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
 gpu.func @scatter_ops(%src: memref<256xf16>, %laneid: index) {
   gpu.warp_execute_on_lane_0(%laneid)[16] {
-    %1 = arith.constant
-      {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
-      dense<1> : vector<16xi1>
-    %offset = arith.constant
-      {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
-      dense<12> : vector<16xindex>
+    %1 = arith.constant dense<1> : vector<16xi1>
+    %offset = arith.constant dense<12> : vector<16xindex>
     %3 = xegpu.load %src[%offset], %1
     {
-      layout_operand_1 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
-      layout_operand_2 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
-      layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>
+      layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>
     } : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16>
     xegpu.store %3, %src[%offset], %1
     {
-      layout_operand_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
-      layout_operand_2 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
-      layout_operand_3 = #xegpu.layout<lane_layout = [16], lane_data = [1]>
+      layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>
     }
     : vector<16xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
   }
@@ -547,8 +523,8 @@ gpu.func @scatter_ops(%src: memref<256xf16>, %laneid: index) {
 }
 
 // CHECK-LABEL: gpu.func @scatter_ops_with_leading_dims({{.*}}) {
-// CHECK:       %[[OFFSETS:.*]] = arith.constant {{.*}} dense<12> : vector<1x1x16xindex>
-// CHECK:       %[[MASKS:.*]] = arith.constant {{.*}} dense<true> : vector<1x1x16xi1>
+// CHECK:       %[[OFFSETS:.*]] = arith.constant dense<12> : vector<1x1x16xindex>
+// CHECK:       %[[MASKS:.*]] = arith.constant dense<true> : vector<1x1x16xi1>
 // CHECK:       %[[W:.*]]:4 = gpu.warp_execute_on_lane_0(%{{.*}})[16]
 // CHECK-SAME:    -> (vector<1x1x1xf16>, memref<256xf16>, vector<1x1x1xindex>, vector<1x1x1xi1>) {
 // CHECK:         gpu.yield %{{.*}}, %{{.*}}, %[[OFFSETS]], %[[MASKS]]
@@ -563,23 +539,12 @@ gpu.func @scatter_ops(%src: memref<256xf16>, %laneid: index) {
 gpu.func @scatter_ops_with_leading_dims(%src: memref<256xf16>, %laneid: index) {
   gpu.warp_execute_on_lane_0(%laneid)[16] {
     %1 = arith.constant
-      {layout_result_0 = #xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>}
       dense<1> : vector<1x1x16xi1>
     %offset = arith.constant
-      {layout_result_0 = #xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>}
       dense<12> : vector<1x1x16xindex>
-    %3 = xegpu.load %src[%offset], %1
-    {
-      layout_operand_1 = #xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>,
-      layout_operand_2 = #xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>,
-      layout_result_0 = #xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>
-    } : memref<256xf16>, vector<1x1x16xindex>, vector<1x1x16xi1> -> vector<1x1x16xf16>
-    xegpu.store %3, %src[%offset], %1
-    {
-      layout_operand_0 = #xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>,
-      layout_operand_2 = #xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>,
-      layout_operand_3 = #xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>
-    }
+    %3 = xegpu.load %src[%offset], %1 {layout = #xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>} 
+    : memref<256xf16>, vector<1x1x16xindex>, vector<1x1x16xi1> -> vector<1x1x16xf16>
+    xegpu.store %3, %src[%offset], %1 { layout = #xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>}
     : vector<1x1x16xf16>, memref<256xf16>, vector<1x1x16xindex>, vector<1x1x16xi1>
   }
   gpu.return
diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
index c3cdc79d9f70e..285669cae7174 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
@@ -10,7 +10,7 @@
 // CHECK:         %[[T1:.*]] = xegpu.load_nd %[[T0]][%{{.*}}] <{packed}> : !xegpu.tensor_desc<16x16xf16> -> vector<16xf16>
 // CHECK-DAG:     %[[T4:.*]] = xegpu.dpas %[[T3]], %[[T1]] : vector<8xf16>, vector<16xf16> -> vector<8xf32>
 // CHECK:         %[[T5:.*]] = vector.shape_cast %[[T4]] : vector<8xf32> to vector<8x1xf32>
-// CHECK:         %[[T6:.*]] = math.exp %[[T5]] {{{.*}}} : vector<8x1xf32>
+// CHECK:         %[[T6:.*]] = math.exp %[[T5]] : vector<8x1xf32>
 // CHECK-DAG:     %[[T8:.*]] = vector.shape_cast %[[T6]] : vector<8x1xf32> to vector<8xf32>
 // CHECK-DAG:     %[[T7:.*]] = xegpu.create_nd_tdesc %[[ARG2]] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
 // CHECK:         xegpu.store_nd %[[T8]], %[[T7]][{{.*}}] : vector<8xf32>, !xegpu.tensor_desc<8x16xf32>
@@ -423,16 +423,17 @@ gpu.module @xevm_test {
     // CHECK: %[[VEC_RED:.*]] = vector.broadcast %[[VEC_RED_3]] : f32 to vector<1xf32>
     // CHECK: xegpu.store %[[VEC_RED]], %arg1[%[[CST]]], %[[CST_0]] : vector<1xf32>, memref<256xf32>, vector<1xindex>, vector<1xi1>
   gpu.func @vector_reduce_2d(%arg0: memref<4x16xf32>, %arg1: memref<256xf32>) {
-      %cst = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0, 1]>} 1.000000e+00 : f32
+      %cst = arith.constant 1.000000e+00 : f32
       %0 = xegpu.create_nd_tdesc %arg0 : memref<4x16xf32> -> !xegpu.tensor_desc<4x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
       %1 = xegpu.load_nd %0[0, 0] <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}> : !xegpu.tensor_desc<4x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<4x16xf32>
-      %2 = vector.broadcast %cst {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>} : f32 to vector<16xf32>
-      %3 = vector.multi_reduction <add>, %1, %2 {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>} [0] : vector<4x16xf32> to vector<16xf32>
-      %4 = vector.reduction <add>, %3 {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0, 1]>} : vector<16xf32> into f32
-      %5 = vector.broadcast %4 {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>} : f32 to vector<16xf32>
-      %cst_0 = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<0> : vector<16xindex>
-      %cst_1 = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<16xi1>
-      xegpu.store %5, %arg1[%cst_0], %cst_1 <{layout = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>}> : vector<16xf32>, memref<256xf32>, vector<16xindex>, vector<16xi1>
+      %2 = vector.broadcast %cst : f32 to vector<16xf32>
+      %3 = vector.multi_reduction <add>, %1, %2 [0] : vector<4x16xf32> to vector<16xf32>
+      %4 = vector.reduction <add>, %3 : vector<16xf32> into f32
+      %40 = xegpu.convert_layout %4 <{input_layout =  #xegpu.slice<#xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>, dims = [0]>, target_layout =  #xegpu.slice<#xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>, dims = [0]>}>: f32
+      %5 = vector.broadcast %40 : f32 to vector<16xf32>
+      %cst_0 = arith.constant dense<0> : vector<16xindex>
+      %cst_1 = arith.constant dense<true> : vector<16xi1>
+      xegpu.store %5, %arg1[%cst_0], %cst_1 <{layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>}> : vector<16xf32>, memref<256xf32>, vector<16xindex>, vector<16xi1>
     gpu.return
   }
 }
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
index bbdffa0986962..3bc43b780ade2 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -405,7 +405,7 @@ gpu.module @test_distribution {
 
   // CHECK-LABEL: gpu.func @vector_reduce_scalar_cross_sg
   // CHECK-SAME: (%[[ARG0:.*]]: memref<32x32xf32>)
-  // CHECK-DAG: %[[CST:.*]] = arith.constant {{.*}} 0.000000e+00 : f32
+  // CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
   // CHECK-DAG: %[[LOAD:.*]] = xegpu.load_nd %{{.*}} : !xegpu.tensor_desc<8x8xf32> -> vector<8x8xf32>
   // CHECK-DAG: %[[CST_ACC:.*]] = arith.constant 0.000000e+00 : f32
   // CHECK-DAG: %[[LOCAL:.*]] = vector.multi_reduction <add>, %[[LOAD]], %[[CST_ACC]] [0, 1] : vector<8x8xf32> to f32

>From 01602b9a946cdded77cc3d09b2033ecfe495e615 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Wed, 8 Apr 2026 20:51:46 +0000
Subject: [PATCH 10/21] fix tests

---
 mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp        |  22 +--
 .../XeGPU/Transforms/XeGPULayoutImpl.cpp      |  12 +-
 .../XeGPU/Transforms/XeGPUPropagateLayout.cpp |  51 ++++-
 .../XeGPU/sg-to-wi-experimental-unit.mlir     | 175 ++++++++++++++++++
 4 files changed, 237 insertions(+), 23 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 5697097a4c999..d6981052a7a5c 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -1101,17 +1101,17 @@ LogicalResult ConvertLayoutOp::verify() {
     return emitOpError("expected input layout and target layout be WgLayout or "
                        "SgLayout at the same time.");
 
-  Type srcType = getSource().getType();
-  if (llvm::isa<VectorType>(srcType)) {
-    auto shape = llvm::cast<VectorType>(srcType).getShape();
-    if (!XeGPUDialect::isEvenlyDistributable(shape, srcLayout))
-      return emitOpError(
-          "invalid input layout, data cannot be evenly distributed.");
-
-    if (!XeGPUDialect::isEvenlyDistributable(shape, resLayout))
-      return emitOpError(
-          "invalid target layout, data cannot be evenly distributed.");
-  }
+  // Type srcType = getSource().getType();
+  // if (llvm::isa<VectorType>(srcType)) {
+  //   auto shape = llvm::cast<VectorType>(srcType).getShape();
+  //   if (!XeGPUDialect::isEvenlyDistributable(shape, srcLayout))
+  //     return emitOpError(
+  //         "invalid input layout, data cannot be evenly distributed.");
+
+  //   if (!XeGPUDialect::isEvenlyDistributable(shape, resLayout))
+  //     return emitOpError(
+  //         "invalid target layout, data cannot be evenly distributed.");
+  // }
   return mlir::success();
 }
 
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
index d5c5e05549745..47c60eaf7d4e0 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
@@ -1541,12 +1541,14 @@ xegpu::inferSourceLayoutFromResult(OpOperand &operand,
     return inferred;
   }
 
-  // For elementwise operations, all operands must have the same layout as the
-  // result.
-  if (OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() == 1) {
-    LLVM_DEBUG(DBGS() << "  -> elementwise op, using resLayout="
+  if (isa<VectorType>(operand.get().getType()) &&
+      !dyn_cast<xegpu::AnchorLayoutInterface>(op)) {
+    // For a vector operand of any op that is not an AnchorLayoutInterface,
+    // reuse the result layout. Previously restricted to elementwise ops:
+    // if (OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() ==
+    // 1) {
+    LLVM_DEBUG(DBGS() << "  -> other vector or tensorDesc ops using resLayout="
                       << (resLayout ? resLayout : nullptr) << "\n");
-
     return resLayout;
   }
   return xegpu::DistributeLayoutAttr();
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 8abb1fc4bae52..6aeabacf4b30f 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -1538,37 +1538,63 @@ static LogicalResult
 updateControlFlowOps(mlir::OpBuilder &builder,
                      mlir::RegionBranchTerminatorOpInterface terminator,
                      GetLayoutFnTy getLayoutOfValue) {
+  LLVM_DEBUG(DBGS() << "updateControlFlowOps: processing terminator: "
+                    << *terminator << "\n");
   // Only process if the terminator is inside a region branch op.
   auto branchOp = dyn_cast<RegionBranchOpInterface>(terminator->getParentOp());
-  if (!branchOp)
+  if (!branchOp) {
+    LLVM_DEBUG(
+        DBGS() << "  parent is not a RegionBranchOpInterface, skipping\n");
     return success();
+  }
+  LLVM_DEBUG(DBGS() << "  parent branch op: " << *branchOp << "\n");
 
   RegionBranchSuccessorMapping mapping;
   branchOp.getSuccessorOperandInputMapping(mapping,
                                            RegionBranchPoint(terminator));
+  LLVM_DEBUG(DBGS() << "  successor mapping has " << mapping.size()
+                    << " entries\n");
   for (const auto &[successorOperand, successorInputs] : mapping) {
+    LLVM_DEBUG(DBGS() << "  processing successor operand: "
+                      << successorOperand->get()
+                      << " (type: " << successorOperand->get().getType()
+                      << "), num successor inputs: " << successorInputs.size()
+                      << "\n");
     for (Value successorInput : successorInputs) {
       Type inputType = successorInput.getType();
+      LLVM_DEBUG(DBGS() << "    successor input: " << successorInput
+                        << ", type: " << inputType << "\n");
       // We only need to operate on tensor descriptor or vector types.
-      if (!isa<xegpu::TensorDescType, VectorType>(inputType))
+      if (!isa<xegpu::TensorDescType, VectorType>(inputType)) {
+        LLVM_DEBUG(
+            DBGS() << "    skipping: not a TensorDescType or VectorType\n");
         continue;
+      }
       xegpu::DistributeLayoutAttr successorInputLayout =
           getLayoutOfValue(successorInput);
       xegpu::DistributeLayoutAttr successorOperandLayout =
           getLayoutOfValue(successorOperand->get());
 
+      LLVM_DEBUG(DBGS() << "    successor input layout: ");
+      LLVM_DEBUG(if (successorInputLayout) llvm::dbgs() << successorInputLayout;
+                 else llvm::dbgs() << "<<NULL>>"; llvm::dbgs() << "\n");
+      LLVM_DEBUG(DBGS() << "    successor operand layout: ");
+      LLVM_DEBUG(if (successorOperandLayout) llvm::dbgs()
+                     << successorOperandLayout;
+                 else llvm::dbgs() << "<<NULL>>"; llvm::dbgs() << "\n");
+
       // If either of the layouts is not assigned, we cannot proceed.
       if (!successorOperandLayout) {
-        LLVM_DEBUG(DBGS() << "No layout assigned for forwarded operand in "
-                             "branch terminator: "
+        LLVM_DEBUG(DBGS() << "    FAILURE: No layout assigned for forwarded "
+                             "operand in branch terminator: "
                           << successorOperand->get() << "\n");
         return failure();
       }
       // We expect the layouts to match.
       if (successorInputLayout &&
           successorInputLayout != successorOperandLayout) {
-        LLVM_DEBUG(DBGS() << "Conflicting layouts for region argument and "
-                             "operand forwarded as the argument: "
+        LLVM_DEBUG(DBGS() << "    FAILURE: Conflicting layouts for region "
+                             "argument and operand forwarded as the argument: "
                           << successorInputLayout << " vs "
                           << successorOperandLayout << "\n");
         return failure();
@@ -1578,15 +1604,26 @@ updateControlFlowOps(mlir::OpBuilder &builder,
         auto newTdescTy = xegpu::TensorDescType::get(
             tdescTy.getContext(), tdescTy.getShape(), tdescTy.getElementType(),
             tdescTy.getEncoding(), successorOperandLayout);
+        LLVM_DEBUG(DBGS() << "    updating tensor desc type: " << tdescTy
+                          << " -> " << newTdescTy << "\n");
         successorInput.setType(newTdescTy);
         continue;
       }
       // If the type is a vector type and this region argument is an OpResult,
       // set the layout attribute on the OpResult.
-      if (auto result = dyn_cast<OpResult>(successorInput))
+      if (auto result = dyn_cast<OpResult>(successorInput)) {
+        LLVM_DEBUG(DBGS() << "    setting layout on OpResult #"
+                          << result.getResultNumber() << " of "
+                          << *result.getOwner() << " to "
+                          << successorOperandLayout << "\n");
         xegpu::setDistributeLayoutAttr(result, successorOperandLayout);
+      } else {
+        LLVM_DEBUG(DBGS() << "    successor input is a BlockArgument, "
+                             "not setting layout attribute\n");
+      }
     }
   }
+  LLVM_DEBUG(DBGS() << "  updateControlFlowOps: success\n");
   return success();
 }
 
diff --git a/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir b/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
index d53d0463db4cb..056e5e7e34a63 100644
--- a/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
+++ b/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
@@ -136,6 +136,11 @@ gpu.func @elementwise() {
   %3 = arith.addf %0, %2
     {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
     : vector<16x16xf32>
+  %cl3 = xegpu.convert_layout %3
+    <{
+      input_layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+      target_layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+    }> : vector<16x16xf32>
   gpu.return
 }
 
@@ -145,6 +150,11 @@ gpu.func @elementwise() {
 gpu.func @arith_constant() {
   %0 = arith.constant  {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
     dense<1.0> : vector<16x16xf32>
+  %cl0 = xegpu.convert_layout %0
+    <{
+      input_layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+      target_layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+    }> : vector<16x16xf32>
   gpu.return
 }
 
@@ -363,6 +373,11 @@ gpu.func @vector_multi_reduction_dim1_distributed_dim1_reduction(%laneid: index)
         layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [1]>
       }
       [1] : vector<2x16xf32> to vector<2xf32>
+  %cl1 = xegpu.convert_layout %1
+    <{
+      input_layout = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [1]>,
+      target_layout = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [1]>
+    }> : vector<2xf32>
   gpu.return
 }
 
@@ -428,6 +443,11 @@ gpu.func @vector_multi_reduction_dim0_distributed_dim0_reduction(%laneid: index)
         layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>, dims = [0]>
       }
       [0] : vector<16x2xf32> to vector<2xf32>
+  %cl1 = xegpu.convert_layout %1
+    <{
+      input_layout = #xegpu.slice<#xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>, dims = [0]>,
+      target_layout = #xegpu.slice<#xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>, dims = [0]>
+    }> : vector<2xf32>
   gpu.return
 }
 
@@ -449,6 +469,11 @@ gpu.func @vector_multi_reduction_dim1_distributed_dim0_reduction(%laneid: index)
         layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>
       }
       [0] : vector<4x16xf32> to vector<16xf32>
+  %cl1 = xegpu.convert_layout %1
+    <{
+      input_layout = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>,
+      target_layout = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>
+    }> : vector<16xf32>
   gpu.return
 }
 
@@ -470,6 +495,11 @@ gpu.func @vector_multi_reduction_dim0_distributed_dim1_reduction(%laneid: index)
         layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>, dims = [1]>
       }
       [1] : vector<16x12xf32> to vector<16xf32>
+  %cl1 = xegpu.convert_layout %1
+    <{
+      input_layout = #xegpu.slice<#xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>, dims = [1]>,
+      target_layout = #xegpu.slice<#xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>, dims = [1]>
+    }> : vector<16xf32>
   gpu.return
 }
 
@@ -524,6 +554,11 @@ gpu.func @create_mask_1d(%m0: index) {
   %mask = vector.create_mask %m0
     {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
     : vector<16xi1>
+  %mask_cl = xegpu.convert_layout %mask
+    <{
+      input_layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
+      target_layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>
+    }> : vector<16xi1>
   gpu.return
 }
 
@@ -539,6 +574,11 @@ gpu.func @constant_mask_1d() {
   %mask = vector.constant_mask [4]
     {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
     : vector<16xi1>
+  %mask_cl = xegpu.convert_layout %mask
+    <{
+      input_layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
+      target_layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>
+    }> : vector<16xi1>
   gpu.return
 }
 
@@ -560,6 +600,11 @@ gpu.func @create_mask_2d(%m0: index, %m1: index) {
   %mask = vector.create_mask %m0, %m1
     {layout_result_0 = #xegpu.layout<lane_layout = [8, 2], lane_data = [1, 1]>}
     : vector<8x4xi1>
+  %mask_cl = xegpu.convert_layout %mask
+    <{
+      input_layout = #xegpu.layout<lane_layout = [8, 2], lane_data = [1, 1]>,
+      target_layout = #xegpu.layout<lane_layout = [8, 2], lane_data = [1, 1]>
+    }> : vector<8x4xi1>
   gpu.return
 }
 
@@ -582,6 +627,11 @@ gpu.func @constant_mask_2d() {
   %mask = vector.constant_mask [2, 3]
     {layout_result_0 = #xegpu.layout<lane_layout = [8, 2], lane_data = [1, 1]>}
     : vector<8x4xi1>
+      %mask_cl = xegpu.convert_layout %mask
+        <{
+          input_layout = #xegpu.layout<lane_layout = [8, 2], lane_data = [1, 1]>,
+          target_layout = #xegpu.layout<lane_layout = [8, 2], lane_data = [1, 1]>
+        }> : vector<8x4xi1>
       gpu.return
 }
 
@@ -603,6 +653,11 @@ gpu.func @vector_multi_reduction_3d_leading_unit_dim_lane_local() {
         layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>, dims = [1]>
       }
       [1] : vector<1x16x32xf32> to vector<1x32xf32>
+  %cl1 = xegpu.convert_layout %1
+    <{
+      input_layout = #xegpu.slice<#xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>, dims = [1]>,
+      target_layout = #xegpu.slice<#xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>, dims = [1]>
+    }> : vector<1x32xf32>
   gpu.return
 }
 
@@ -636,6 +691,11 @@ gpu.func @vector_multi_reduction_3d_leading_unit_dim_cross_lane() {
         layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16, 1], lane_data = [1, 1, 1]>, dims = [1]>
       }
       [1] : vector<1x16x2xf32> to vector<1x2xf32>
+  %cl1 = xegpu.convert_layout %1
+    <{
+      input_layout = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16, 1], lane_data = [1, 1, 1]>, dims = [1]>,
+      target_layout = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16, 1], lane_data = [1, 1, 1]>, dims = [1]>
+    }> : vector<1x2xf32>
   gpu.return
 }
 
@@ -648,6 +708,11 @@ gpu.func @vector_extract_from_2d() {
   %0 = vector.extract %src[0]
     {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
     : vector<16xf32> from vector<4x16xf32>
+  %cl0 = xegpu.convert_layout %0
+    <{
+      input_layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
+      target_layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>
+    }> : vector<16xf32>
   gpu.return
 }
 
@@ -660,6 +725,11 @@ gpu.func @vector_extract_from_2d_offset2() {
   %0 = vector.extract %src[2]
     {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
     : vector<16xf32> from vector<8x16xf32>
+  %cl0 = xegpu.convert_layout %0
+    <{
+      input_layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
+      target_layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>
+    }> : vector<16xf32>
   gpu.return
 }
 
@@ -675,6 +745,11 @@ gpu.func @vector_insert_into_2d() {
   %0 = vector.insert %val, %dst[0]
     {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
     : vector<16xf32> into vector<4x16xf32>
+  %cl0 = xegpu.convert_layout %0
+    <{
+      input_layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+      target_layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+    }> : vector<4x16xf32>
   gpu.return
 }
 
@@ -690,6 +765,11 @@ gpu.func @vector_insert_into_2d_offset2() {
   %0 = vector.insert %val, %dst[2]
     {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
     : vector<16xf32> into vector<8x16xf32>
+  %cl0 = xegpu.convert_layout %0
+    <{
+      input_layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+      target_layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+    }> : vector<8x16xf32>
   gpu.return
 }
 
@@ -703,6 +783,11 @@ gpu.func @vector_extract_strided_slice_distributed_dim_fully_extracted() {
       layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
     }
     : vector<24x16xf32> to vector<8x16xf32>
+  %cl1 = xegpu.convert_layout %1
+    <{
+      input_layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+      target_layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+    }> : vector<8x16xf32>
   gpu.return
 }
 
@@ -716,6 +801,11 @@ gpu.func @vector_extract_strided_slice_inner_distributed() {
       layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
     }
     : vector<24x64xf32> to vector<8x16xf32>
+  %cl1 = xegpu.convert_layout %1
+    <{
+      input_layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+      target_layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+    }> : vector<8x16xf32>
   gpu.return
 }
 
@@ -729,6 +819,11 @@ gpu.func @vector_extract_strided_slice_outer_distributed() {
       layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>
     }
     : vector<32x16xf32> to vector<16x16xf32>
+  %cl1 = xegpu.convert_layout %1
+    <{
+      input_layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>,
+      target_layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>
+    }> : vector<16x16xf32>
   gpu.return
 }
 
@@ -742,6 +837,11 @@ gpu.func @vector_extract_strided_slice_1d() {
       layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>
     }
     : vector<64xf32> to vector<32xf32>
+  %cl1 = xegpu.convert_layout %1
+    <{
+      input_layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
+      target_layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>
+    }> : vector<32xf32>
   gpu.return
 }
 
@@ -755,6 +855,11 @@ gpu.func @vector_extract_strided_slice_partial_offsets() {
       layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
     }
     : vector<24x16xf32> to vector<8x16xf32>
+  %cl1 = xegpu.convert_layout %1
+    <{
+      input_layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+      target_layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+    }> : vector<8x16xf32>
   gpu.return
 }
 
@@ -771,6 +876,11 @@ gpu.func @vector_insert_strided_slice_distributed_dim_fully_inserted() {
       layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
     }
     : vector<16x16xf32> into vector<64x16xf32>
+  %cl2 = xegpu.convert_layout %2
+    <{
+      input_layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+      target_layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+    }> : vector<64x16xf32>
   gpu.return
 }
 
@@ -787,6 +897,11 @@ gpu.func @vector_insert_strided_slice_inner_distributed() {
       layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
     }
     : vector<16x16xf32> into vector<64x32xf32>
+  %cl2 = xegpu.convert_layout %2
+    <{
+      input_layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+      target_layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+    }> : vector<64x32xf32>
   gpu.return
 }
 
@@ -803,6 +918,11 @@ gpu.func @vector_insert_strided_slice_outer_distributed() {
       layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>
     }
     : vector<16x16xf32> into vector<48x32xf32>
+  %cl2 = xegpu.convert_layout %2
+    <{
+      input_layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>,
+      target_layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>
+    }> : vector<48x32xf32>
   gpu.return
 }
 
@@ -819,6 +939,11 @@ gpu.func @vector_insert_strided_slice_1d() {
       layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>
     }
     : vector<16xf32> into vector<48xf32>
+  %cl2 = xegpu.convert_layout %2
+    <{
+      input_layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
+      target_layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>
+    }> : vector<48xf32>
   gpu.return
 }
 
@@ -835,6 +960,11 @@ gpu.func @vector_insert_strided_slice_different_ranks() {
       layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
     }
     : vector<16xf32> into vector<64x16xf32>
+  %cl2 = xegpu.convert_layout %2
+    <{
+      input_layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+      target_layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+    }> : vector<64x16xf32>
   gpu.return
 }
 
@@ -953,6 +1083,11 @@ gpu.func @elementwise_wrap_around_dim() {
     : () -> vector<16x1xf16>
   %1 = arith.negf %0 {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
     : vector<16x1xf16>
+   %cl1 = xegpu.convert_layout %1
+     <{
+       input_layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+       target_layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+     }> : vector<16x1xf16>
    gpu.return
 }
 }
@@ -967,6 +1102,11 @@ gpu.module @xevm_module {
 // CHECK:         %[[VEC:.*]] = vector.from_elements %[[REM2]] : vector<1xindex>
 gpu.func @vector_step_slice() {
   %0 = vector.step {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 1, 1, 16], lane_data = [1, 1, 1, 1]>, dims = [0, 1, 2]>} : vector<16xindex>
+  %cl0 = xegpu.convert_layout %0
+    <{
+      input_layout = #xegpu.slice<#xegpu.layout<lane_layout = [1, 1, 1, 16], lane_data = [1, 1, 1, 1]>, dims = [0, 1, 2]>,
+      target_layout = #xegpu.slice<#xegpu.layout<lane_layout = [1, 1, 1, 16], lane_data = [1, 1, 1, 1]>, dims = [0, 1, 2]>
+    }> : vector<16xindex>
   gpu.return
 }
 }
@@ -977,6 +1117,11 @@ gpu.module @xevm_module {
 // CHECK:         %[[VEC:.*]] = vector.from_elements %{{.*}} : vector<1xindex>
 gpu.func @vector_step_slice_unit() {
   %0 = vector.step {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 1, 1, 16], lane_data = [1, 1, 1, 1]>, dims = [0, 1, 3]>} : vector<1xindex>
+  %cl0 = xegpu.convert_layout %0
+    <{
+      input_layout = #xegpu.slice<#xegpu.layout<lane_layout = [1, 1, 1, 16], lane_data = [1, 1, 1, 1]>, dims = [0, 1, 3]>,
+      target_layout = #xegpu.slice<#xegpu.layout<lane_layout = [1, 1, 1, 16], lane_data = [1, 1, 1, 1]>, dims = [0, 1, 3]>
+    }> : vector<1xindex>
   gpu.return
 }
 }
@@ -994,6 +1139,11 @@ gpu.module @xevm_module {
 // CHECK:         %[[VEC:.*]] = vector.from_elements %[[V0]], %[[V1]], %[[V2]], %[[V3]] : vector<4xindex>
 gpu.func @vector_step_slice_multi_dist() {
   %0 = vector.step {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [2, 4, 2], lane_data = [1, 2, 1]>, dims = [0, 2]>} : vector<16xindex>
+  %cl0 = xegpu.convert_layout %0
+    <{
+      input_layout = #xegpu.slice<#xegpu.layout<lane_layout = [2, 4, 2], lane_data = [1, 2, 1]>, dims = [0, 2]>,
+      target_layout = #xegpu.slice<#xegpu.layout<lane_layout = [2, 4, 2], lane_data = [1, 2, 1]>, dims = [0, 2]>
+    }> : vector<16xindex>
   gpu.return
 }
 }
@@ -1011,6 +1161,11 @@ gpu.func @vector_shapecast_rank_increasing() {
       layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
     }
     : vector<16xf32> to vector<1x16xf32>
+  %cast_cl = xegpu.convert_layout %cast
+    <{
+      input_layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+      target_layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+    }> : vector<1x16xf32>
   gpu.return
 }
 }
@@ -1028,6 +1183,11 @@ gpu.func @vector_shapecast_rank_reducing() {
       layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>
     }
     : vector<1x16xf32> to vector<16xf32>
+  %cast_cl = xegpu.convert_layout %cast
+    <{
+      input_layout = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>,
+      target_layout = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>
+    }> : vector<16xf32>
   gpu.return
 }
 }
@@ -1045,6 +1205,11 @@ gpu.func @vector_shapecast_rank_increasing_without_slicing_layout() {
       layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
     }
     : vector<16xf32> to vector<1x16xf32>
+  %cast_cl = xegpu.convert_layout %cast
+    <{
+      input_layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+      target_layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+    }> : vector<1x16xf32>
   gpu.return
 }
 }
@@ -1071,6 +1236,11 @@ gpu.module @xevm_module {
 gpu.func @constant_wrap_around_dim() {
   %0 = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
     dense<1.0> : vector<16x1xf16>
+  %cl0 = xegpu.convert_layout %0
+    <{
+      input_layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+      target_layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+    }> : vector<16x1xf16>
   gpu.return
 }
 }
@@ -1160,6 +1330,11 @@ gpu.func @vector_multi_reduction_1d_to_scalar() {
         layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [16], lane_data = [1]>, dims = [0]>
       }
       [0] : vector<32xf32> to f32
+  %cl1 = xegpu.convert_layout %1
+    <{
+      input_layout = #xegpu.slice<#xegpu.layout<lane_layout = [16], lane_data = [1]>, dims = [0]>,
+      target_layout = #xegpu.slice<#xegpu.layout<lane_layout = [16], lane_data = [1]>, dims = [0]>
+    }> : f32
   gpu.return
 }
 }

>From df3106ac7ba17724f662c838bb3a8878624d1f11 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Thu, 9 Apr 2026 17:54:43 +0000
Subject: [PATCH 11/21] fix issues

---
 .../include/mlir/Dialect/XeGPU/IR/XeGPUOps.td |  2 +-
 .../XeGPU/Transforms/XeGPUPropagateLayout.cpp | 87 ++++++++++++++++++-
 2 files changed, 84 insertions(+), 5 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index e001419257d8f..72584f40681f4 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -1494,7 +1494,7 @@ def XeGPU_FenceOp: XeGPU_Op<"fence", []> {
   let extraClassDeclaration = extraBaseClassDeclaration;
 }
 
-def XeGPU_ConvertLayoutOp: XeGPU_Op<"convert_layout", [Pure, AllTypesMatch<["source", "result"]>, AnchorLayoutInterface]> {
+def XeGPU_ConvertLayoutOp: XeGPU_Op<"convert_layout", [AllTypesMatch<["source", "result"]>, MemoryEffects<[MemRead, MemWrite]>, AnchorLayoutInterface]> {
     let summary = "Convert the layout of the input operand";
     let description = [{
       `convert_layout` redistribute data across subgroups and/or lanes from the `input_layout` to
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 6aeabacf4b30f..c53e865729eb2 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -396,6 +396,10 @@ class LayoutInfoPropagation
                            ArrayRef<LayoutInfoLattice *> operands,
                            ArrayRef<const LayoutInfoLattice *> results);
 
+  void visitConvertLayoutOp(xegpu::ConvertLayoutOp convertLayout,
+                            ArrayRef<LayoutInfoLattice *> operands,
+                            ArrayRef<const LayoutInfoLattice *> results);
+
   bool hasParamsOfLayoutKind(xegpu::DistributeLayoutAttr anchorLayout);
 
 public:
@@ -483,6 +487,9 @@ LogicalResult LayoutInfoPropagation::visitOperation(
       .Case([&](xegpu::StoreMatrixOp storeMatrixOp) {
         visitStoreMatrixOp(storeMatrixOp, operands, results);
       })
+      .Case([&](xegpu::ConvertLayoutOp convertLayoutOp) {
+        visitConvertLayoutOp(convertLayoutOp, operands, results);
+      })
       // All other ops.
       .Default([&](Operation *op) {
         for (const LayoutInfoLattice *resultInfo : results) {
@@ -936,6 +943,17 @@ void LayoutInfoPropagation::visitLoadNdOp(
   propagateIfChanged(operands[0], operands[0]->meet(loadLayout));
 }
 
+/// For ConvertLayoutOp, meet the lattice of operand 0 with the op's
+/// `target_layout` attribute; the `results` lattices are not consulted.
+void LayoutInfoPropagation::visitConvertLayoutOp(
+    xegpu::ConvertLayoutOp convert, ArrayRef<LayoutInfoLattice *> operands,
+    ArrayRef<const LayoutInfoLattice *> results) {
+  xegpu::DistributeLayoutAttr anchorLayout = convert.getTargetLayoutAttr();
+  LayoutInfo convertLayout(anchorLayout);
+  // Propagate the new layout to the tensor descriptor operand.
+  propagateIfChanged(operands[0], operands[0]->meet(convertLayout));
+}
+
 /// For vector::TransposeOp, the layout of the result is transposed and
 /// propagated to the operand.
 void LayoutInfoPropagation::visitTransposeOp(
@@ -1346,7 +1364,7 @@ LogicalResult ResolveLayoutConflicts::run() {
     // as anchor op for the reduction op's layout.
     if (isa<vector::MultiDimReductionOp>(op) || isa<vector::ReductionOp>(op)) {
       for (OpResult result : op->getResults()) {
-        if (result.getType().isIntOrFloat() || result.use_empty()) {
+        if (result.getType().isIntOrFloat()) {
           auto res = assignResultLayout(result);
           if (failed(res)) {
             DBGS() << "Failed to resolve vector consumer for multi-reduction "
@@ -1356,6 +1374,22 @@ LogicalResult ResolveLayoutConflicts::run() {
         }
       }
     }
+    if (isa<RegionBranchOpInterface>(op)) {
+      auto regionBranchOp = dyn_cast<RegionBranchOpInterface>(op);
+      unsigned numResults = regionBranchOp->getNumResults();
+      for (unsigned i = 0; i < numResults; ++i) {
+        OpResult result = regionBranchOp->getResult(i);
+        if (result.use_empty()) {
+          auto res = assignResultLayout(result);
+          if (failed(res)) {
+            DBGS() << "Failed to resolve vector consumer for loop/switch "
+                      "result with no use: "
+                   << *op << "\n";
+            return WalkResult::interrupt();
+          }
+        }
+      }
+    }
     for (OpOperand &operand : op->getOpOperands()) {
       // Handle conflicts in tensor descriptor operands.
       Type operandType = operand.get().getType();
@@ -1369,6 +1403,8 @@ LogicalResult ResolveLayoutConflicts::run() {
         }
       }
       // Handle conflicts in vector operands.
+      LLVM_DEBUG(DBGS() << "Handling vector operand #" << operand.getOperandNumber()
+                        << ": " << operand.get() << " in operation: " << *op << "\n");
       if (isa<VectorType>(operandType)) {
         auto res = resolveVectorConsumer(operand);
         if (failed(res)) {
@@ -1507,7 +1543,7 @@ static LogicalResult updateOp(mlir::OpBuilder &builder, mlir::Operation *op,
     }
     // If the result is a vector type, add a temporary layout attribute to the
     // op.
-    xegpu::setDistributeLayoutAttr(result, layout);
+    // xegpu::setDistributeLayoutAttr(result, layout);
   }
   return success();
 }
@@ -1537,7 +1573,8 @@ static LogicalResult updateOp(mlir::OpBuilder &builder, mlir::Operation *op,
 static LogicalResult
 updateControlFlowOps(mlir::OpBuilder &builder,
                      mlir::RegionBranchTerminatorOpInterface terminator,
-                     GetLayoutFnTy getLayoutOfValue) {
+                     GetLayoutFnTy getLayoutOfValue,
+                     xegpu::LayoutKind layoutKind) {
   LLVM_DEBUG(DBGS() << "updateControlFlowOps: processing terminator: "
                     << *terminator << "\n");
   // Only process if the terminator is inside a region branch op.
@@ -1570,6 +1607,12 @@ updateControlFlowOps(mlir::OpBuilder &builder,
             DBGS() << "    skipping: not a TensorDescType or VectorType\n");
         continue;
       }
+
+      // debug print successorInput and successorOperand
+      LLVM_DEBUG(DBGS() << "    successor input: " << successorInput << "\n");
+      LLVM_DEBUG(DBGS() << "    successor operand: " << successorOperand->get()
+                        << "\n");
+
       xegpu::DistributeLayoutAttr successorInputLayout =
           getLayoutOfValue(successorInput);
       xegpu::DistributeLayoutAttr successorOperandLayout =
@@ -1601,6 +1644,15 @@ updateControlFlowOps(mlir::OpBuilder &builder,
       }
       // Get tensor descriptor type with the layout.
       if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(inputType)) {
+        // if (successorInputLayout != successorOperandLayout) {
+        //   LLVM_DEBUG(DBGS()
+        //              << "    FAILURE: Conflicting layouts for region "
+        //                 "argument and operand forwarded as the argument: "
+        //              << successorInputLayout << " vs " <<
+        //              successorOperandLayout
+        //              << "\n");
+        //   return failure();
+        // }
         auto newTdescTy = xegpu::TensorDescType::get(
             tdescTy.getContext(), tdescTy.getShape(), tdescTy.getElementType(),
             tdescTy.getEncoding(), successorOperandLayout);
@@ -1609,6 +1661,22 @@ updateControlFlowOps(mlir::OpBuilder &builder,
         successorInput.setType(newTdescTy);
         continue;
       }
+
+      // if (auto vectorTy = dyn_cast<VectorType>(inputType)) {
+      //   SmallVector<int64_t> vectorShape(vectorTy.getShape().begin(),
+      //                                    vectorTy.getShape().end());
+      //   if (!successorInputLayout.isCompatibleWith(successorOperandLayout,
+      //                                              vectorShape, layoutKind))
+      //                                              {
+      //     LLVM_DEBUG(DBGS()
+      //                << "    FAILURE: Conflicting layouts for region "
+      //                   "argument and operand forwarded as the argument: "
+      //                << successorInputLayout << " vs " <<
+      //                successorOperandLayout
+      //                << "\n");
+      //     return failure();
+      //   }
+      // }
       // If the type is a vector type and this region argument is an OpResult,
       // set the layout attribute on the OpResult.
       if (auto result = dyn_cast<OpResult>(successorInput)) {
@@ -1720,7 +1788,7 @@ LogicalResult xegpu::propagateLayouts(OpBuilder &builder, Operation *target,
       TypeSwitch<Operation *>(&op)
           .Case([&](mlir::RegionBranchTerminatorOpInterface branchTermOp) {
             r = updateControlFlowOps(builder, branchTermOp,
-                                     getXeGPULayoutForValue);
+                                     getXeGPULayoutForValue, layoutKind);
           })
           .Case([&](mlir::FunctionOpInterface funcOp) {
             r = updateFunctionOpInterface(builder, funcOp,
@@ -1748,6 +1816,17 @@ LogicalResult xegpu::resolveLayoutConflicts(Operation *target) {
 }
 
 void XeGPUPropagateLayoutPass::runOnOperation() {
+  // Remove layout attributes from SCF ops
+  getOperation()->walk([](Operation *op) {
+    SmallVector<StringAttr> attrsToRemove;
+    for (auto namedAttr : op->getDiscardableAttrs()) {
+      if (isa<xegpu::DistributeLayoutAttr>(namedAttr.getValue()))
+        attrsToRemove.push_back(namedAttr.getName());
+    }
+    for (auto attrName : attrsToRemove)
+      op->removeDiscardableAttr(attrName);
+  });
+
   xegpu::LayoutKind layoutKind;
   if (this->layoutKind == "lane") {
     layoutKind = xegpu::LayoutKind::Lane;

>From ad40b1f79e13b9def43a29ca53e8096308f894d9 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Thu, 9 Apr 2026 20:56:45 +0000
Subject: [PATCH 12/21] fix convert_layout pattern

---
 .../XeGPU/Transforms/XeGPUPropagateLayout.cpp        | 12 +++++++++---
 1 file changed, 9 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index c53e865729eb2..84ecaf34606cd 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -948,7 +948,7 @@ void LayoutInfoPropagation::visitLoadNdOp(
 void LayoutInfoPropagation::visitConvertLayoutOp(
     xegpu::ConvertLayoutOp convert, ArrayRef<LayoutInfoLattice *> operands,
     ArrayRef<const LayoutInfoLattice *> results) {
-  xegpu::DistributeLayoutAttr anchorLayout = convert.getTargetLayoutAttr();
+  xegpu::DistributeLayoutAttr anchorLayout = convert.getInputLayoutAttr();
   LayoutInfo convertLayout(anchorLayout);
   // Propagate the new layout to the tensor descriptor operand.
   propagateIfChanged(operands[0], operands[0]->meet(convertLayout));
@@ -1543,7 +1543,7 @@ static LogicalResult updateOp(mlir::OpBuilder &builder, mlir::Operation *op,
     }
     // If the result is a vector type, add a temporary layout attribute to the
     // op.
-    // xegpu::setDistributeLayoutAttr(result, layout);
+    xegpu::setDistributeLayoutAttr(result, layout);
   }
   return success();
 }
@@ -1780,8 +1780,10 @@ LogicalResult xegpu::propagateLayouts(OpBuilder &builder, Operation *target,
 
     return cast<xegpu::LayoutAttr>(layoutAttr);
   };
-
+  // dump the op before update
+  llvm::dbgs() << "Before layout propagation and conflict resolution:\n";
   Operation *op = target;
+  op->dump();
   auto walkResult = op->walk([&](mlir::Block *block) -> WalkResult {
     for (mlir::Operation &op : llvm::reverse(block->getOperations())) {
       LogicalResult r = success();
@@ -1804,6 +1806,10 @@ LogicalResult xegpu::propagateLayouts(OpBuilder &builder, Operation *target,
     }
     return WalkResult::advance();
   });
+  // dump the op after update, print some headline
+  llvm::dbgs() << "After layout propagation and conflict resolution:\n";
+  op->dump();
+
   if (walkResult.wasInterrupted())
     return failure();
 

>From e2f581b14bb8a3f4a5277febe05c1e906a9325bf Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Fri, 10 Apr 2026 06:12:47 +0000
Subject: [PATCH 13/21] fix, all test passed now

---
 .../XeGPU/Transforms/XeGPUPropagateLayout.cpp | 57 ++++++++++++-------
 1 file changed, 37 insertions(+), 20 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 84ecaf34606cd..4f1680648fcdf 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -1374,22 +1374,22 @@ LogicalResult ResolveLayoutConflicts::run() {
         }
       }
     }
-    if (isa<RegionBranchOpInterface>(op)) {
-      auto regionBranchOp = dyn_cast<RegionBranchOpInterface>(op);
-      unsigned numResults = regionBranchOp->getNumResults();
-      for (unsigned i = 0; i < numResults; ++i) {
-        OpResult result = regionBranchOp->getResult(i);
-        if (result.use_empty()) {
-          auto res = assignResultLayout(result);
-          if (failed(res)) {
-            DBGS() << "Failed to resolve vector consumer for loop/switch "
-                      "result with no use: "
-                   << *op << "\n";
-            return WalkResult::interrupt();
-          }
-        }
-      }
-    }
+    // if (isa<RegionBranchOpInterface>(op)) {
+    //   auto regionBranchOp = dyn_cast<RegionBranchOpInterface>(op);
+    //   unsigned numResults = regionBranchOp->getNumResults();
+    //   for (unsigned i = 0; i < numResults; ++i) {
+    //     OpResult result = regionBranchOp->getResult(i);
+    //     if (result.use_empty()) {
+    //       auto res = assignResultLayout(result);
+    //       if (failed(res)) {
+    //         DBGS() << "Failed to resolve vector consumer for loop/switch "
+    //                   "result with no use: "
+    //                << *op << "\n";
+    //         return WalkResult::interrupt();
+    //       }
+    //     }
+    //   }
+    // }
     for (OpOperand &operand : op->getOpOperands()) {
       // Handle conflicts in tensor descriptor operands.
       Type operandType = operand.get().getType();
@@ -1403,8 +1403,9 @@ LogicalResult ResolveLayoutConflicts::run() {
         }
       }
       // Handle conflicts in vector operands.
-      LLVM_DEBUG(DBGS() << "Handling vector operand #" << operand.getOperandNumber()
-                        << ": " << operand.get() << " in operation: " << *op << "\n");
+      LLVM_DEBUG(DBGS() << "Handling vector operand #"
+                        << operand.getOperandNumber() << ": " << operand.get()
+                        << " in operation: " << *op << "\n");
       if (isa<VectorType>(operandType)) {
         auto res = resolveVectorConsumer(operand);
         if (failed(res)) {
@@ -1758,8 +1759,7 @@ LogicalResult xegpu::propagateLayouts(OpBuilder &builder, Operation *target,
   // Helper to convert LayoutInfo to xegpu::LayoutAttr.
   auto getXeGPULayoutForValue = [&](Value val) -> xegpu::DistributeLayoutAttr {
     LayoutInfo layout = analysis.getLayoutInfo(val);
-    if (!layout.isAssigned())
-      return {};
+
     if (auto opResult = dyn_cast<OpResult>(val)) {
 
       Operation *defOp = opResult.getDefiningOp();
@@ -1773,6 +1773,8 @@ LogicalResult xegpu::propagateLayouts(OpBuilder &builder, Operation *target,
       if (requiredResLayoutAttr != nullptr)
         return requiredResLayoutAttr;
     }
+    if (!layout.isAssigned())
+      return {};
     xegpu::DistributeLayoutAttr layoutAttr =
         cast<xegpu::DistributeLayoutAttr>(layout.get());
     if (layout.isSliceLayout())
@@ -1787,6 +1789,21 @@ LogicalResult xegpu::propagateLayouts(OpBuilder &builder, Operation *target,
   auto walkResult = op->walk([&](mlir::Block *block) -> WalkResult {
     for (mlir::Operation &op : llvm::reverse(block->getOperations())) {
       LogicalResult r = success();
+
+      //  for (OpOperand &operand : op.getOpOperands()) {
+      //    Type operandType = operand.get().getType();
+      //    // We only need to operate on tensor descriptor or vector types.
+      //    if (!isa<xegpu::TensorDescType, VectorType>(operandType))
+      //      continue;
+      //    xegpu::DistributeLayoutAttr consumerLayout =
+      //        getXeGPULayoutForValue(operand.get());
+      //    // dump the layout info for debugging
+      //    // dump operation and operand info for debugging
+      //    DBGS() << "Processing operation: " << op << "\n";
+      //    llvm::outs() << "operand #" << operand.getOperandNumber() << " = "
+      //                 << operand.get() << "\n";
+      //    DBGS() << "Layout: " << consumerLayout << "\n";
+      //  }
       TypeSwitch<Operation *>(&op)
           .Case([&](mlir::RegionBranchTerminatorOpInterface branchTermOp) {
             r = updateControlFlowOps(builder, branchTermOp,

>From cb4aeb2b3b0ec20ef180147ff610d03ca93ec550 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Mon, 13 Apr 2026 20:39:57 +0000
Subject: [PATCH 14/21] remove separated PR

---
 .../XeGPU/Transforms/XeGPULayoutImpl.h        |   10 +-
 mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp    |    2 +-
 .../XeGPU/Transforms/XeGPULayoutImpl.cpp      |   40 +-
 .../XeGPU/Transforms/XeGPUPropagateLayout.cpp |  233 +--
 .../XeGPUSgToWiDistributeExperimental.cp      | 1744 +++++++++++++++++
 .../XeGPUSgToWiDistributeExperimental.cpp     |    1 -
 .../Transforms/XeGPUSubgroupDistribute.cpp    |   41 +-
 .../Transforms/XeGPUWgToSgDistribute.cpp      |    1 -
 mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp   |    1 -
 9 files changed, 1799 insertions(+), 274 deletions(-)
 create mode 100644 mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cp

diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h b/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h
index c36201c2f0d9e..651b001a9d931 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h
@@ -112,11 +112,10 @@ inferInsertStridedSliceSourceLayout(DistributeLayoutAttr resLayout,
                                     ArrayRef<int64_t> resShape,
                                     ArrayRef<int64_t> srcShape);
 
-/// Infers the layout attribute for mask and offset operand for Chunked load
-/// and store, given the anchor layout attribute for the value being load/store.
+/// Infers the source layout attribute for an operand using result layout
+/// attribute
 DistributeLayoutAttr
-inferMaskOffsetLayoutForScatterIO(DistributeLayoutAttr payloadLayout,
-                                  int chunkSize);
+inferSourceLayoutFromResult(OpOperand &operand, DistributeLayoutAttr resLayout);
 
 /// Sets up layout for Multi-Reduction operations by creating a SliceAttr for
 /// the result.
@@ -189,9 +188,6 @@ setupDpasLayout(LayoutKind layoutKind, VectorType aTy, VectorType bTy,
                 VectorType cdTy, DistributeLayoutAttr consumerLayout, int numSg,
                 const uArch::uArch *uArch);
 
-DistributeLayoutAttr
-inferSourceLayoutFromResult(OpOperand &operand, DistributeLayoutAttr resLayout);
-
 /// Gets the expected layout for a given consumer operand. This will check if
 /// the owning operation of the consumer operand is one of the special layout
 /// users and determine the expected layout accordingly.
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index eaa43c02946d8..64c56b5adf5d7 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -585,7 +585,7 @@ DistributeLayoutAttr LayoutAttr::dropDims(SmallVector<int64_t> dimGroup) {
     int64_t offset = llvm::count_if(dimGroup, [&](int64_t s) { return s < d; });
     newOrder.push_back(d - offset);
   }
-  if ((sgLayout.empty() && laneLayout.empty()) || newOrder.size() == 1)
+  if (sgLayout.empty() && laneLayout.empty())
     newOrder.clear();
 
   auto toAttr = [&](ArrayRef<int64_t> v) -> DenseI32ArrayAttr {
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
index 47c60eaf7d4e0..65fab3e60166d 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
@@ -26,11 +26,11 @@
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/LoopLikeInterface.h"
 #include "mlir/Transforms/DialectConversion.h"
-#include "llvm/Support/Debug.h"
 #include "llvm/Support/FormatVariadic.h"
 #include <cstdint>
 #include <numeric>
 
+#include "llvm/Support/Debug.h"
 #define DEBUG_TYPE "xegpu-layout-recovery"
 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
 
@@ -375,33 +375,6 @@ bool xegpu::recoverTemporaryLayouts(Operation *rootOp) {
   return true;
 }
 
-// // Attach layout attributes to all vector-type operands of operations within
-// // the given operation's region. Reports an error if any vector operand lacks
-// // a layout attribute.
-// bool xegpu::recoverTemporaryLayouts(Operation *rootOp) {
-//   auto result = rootOp->walk([&](Operation *op) {
-//     for (OpOperand &operand : op->getOpOperands()) {
-//       // Layouts are needed for vector type only.
-//       if (!isa<VectorType>(operand.get().getType()))
-//         continue;
-//       // Skip block arguments since they don't have defining ops to attach
-//       // layout attributes to.
-//       if (isa<BlockArgument>(operand.get()))
-//         continue;
-//       auto layout = xegpu::getDistributeLayoutAttr(operand.get());
-//       if (!layout) {
-//         op->emitWarning("Could not find layout attribute for operand ")
-//             << operand.getOperandNumber() << " of operation " <<
-//             op->getName();
-//         xegpu::setTemporaryLayout(operand, layout);
-//         continue;
-//       }
-//     }
-//     return WalkResult::advance();
-//   });
-//   return !result.wasInterrupted();
-// }
-
 template <typename T, typename>
 void xegpu::removeLayoutAttr(const T &operandOrResult) {
   Operation *owner = operandOrResult.getOwner();
@@ -677,17 +650,6 @@ xegpu::inferShapeCastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
   return nullptr;
 }
 
-/// Infers the layout attribute for mask and offset operand for Chunked load
-/// and store, given the anchor layout attribute for the value being load/store.
-xegpu::DistributeLayoutAttr xegpu::inferMaskOffsetLayoutForScatterIO(
-    xegpu::DistributeLayoutAttr payloadLayout, int chunkSize) {
-  auto rank = payloadLayout.getRank();
-  if (chunkSize > 1)
-    return payloadLayout.dropDims(
-        llvm::to_vector(llvm::seq<int64_t>(rank - 1, rank)));
-  return payloadLayout;
-}
-
 /// Sets up layout for reduction operations by creating a SliceAttr for the
 /// result.
 ///
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 4f1680648fcdf..4c30dacae8850 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -396,10 +396,6 @@ class LayoutInfoPropagation
                            ArrayRef<LayoutInfoLattice *> operands,
                            ArrayRef<const LayoutInfoLattice *> results);
 
-  void visitConvertLayoutOp(xegpu::ConvertLayoutOp convertLayout,
-                            ArrayRef<LayoutInfoLattice *> operands,
-                            ArrayRef<const LayoutInfoLattice *> results);
-
   bool hasParamsOfLayoutKind(xegpu::DistributeLayoutAttr anchorLayout);
 
 public:
@@ -487,9 +483,6 @@ LogicalResult LayoutInfoPropagation::visitOperation(
       .Case([&](xegpu::StoreMatrixOp storeMatrixOp) {
         visitStoreMatrixOp(storeMatrixOp, operands, results);
       })
-      .Case([&](xegpu::ConvertLayoutOp convertLayoutOp) {
-        visitConvertLayoutOp(convertLayoutOp, operands, results);
-      })
       // All other ops.
       .Default([&](Operation *op) {
         for (const LayoutInfoLattice *resultInfo : results) {
@@ -943,17 +936,6 @@ void LayoutInfoPropagation::visitLoadNdOp(
   propagateIfChanged(operands[0], operands[0]->meet(loadLayout));
 }
 
-/// Propagate the layout of the value to the tensor descriptor operand in
-/// ConvertLayoutOp.
-void LayoutInfoPropagation::visitConvertLayoutOp(
-    xegpu::ConvertLayoutOp convert, ArrayRef<LayoutInfoLattice *> operands,
-    ArrayRef<const LayoutInfoLattice *> results) {
-  xegpu::DistributeLayoutAttr anchorLayout = convert.getInputLayoutAttr();
-  LayoutInfo convertLayout(anchorLayout);
-  // Propagate the new layout to the tensor descriptor operand.
-  propagateIfChanged(operands[0], operands[0]->meet(convertLayout));
-}
-
 /// For vector::TransposeOp, the layout of the result is transposed and
 /// propagated to the operand.
 void LayoutInfoPropagation::visitTransposeOp(
@@ -1045,7 +1027,7 @@ void LayoutInfoPropagation::visitLoadGatherOp(
   const uArch *uArch = getUArch(getChipStr(load).value_or(""));
   if (!uArch)
     return;
-  // auto subgroupSize = uArch->getSubgroupSize();
+  auto subgroupSize = uArch->getSubgroupSize();
   VectorType resVecTy = load.getValueType();
   int chunkSize = load.getChunkSize().value_or(1);
 
@@ -1067,24 +1049,20 @@ void LayoutInfoPropagation::visitLoadGatherOp(
     load.setLayoutAttr(requiredAnchorLayoutAttr);
   }
 
-  assert((chunkSize <= 1) || (layoutKind != xegpu::LayoutKind::Subgroup));
-  auto maskLayoutAttr = xegpu::inferMaskOffsetLayoutForScatterIO(
-      requiredAnchorLayoutAttr, chunkSize);
-
-  // // Special handling mask layout for chunked ops: Enforce the default xegpu
-  // 1D
-  // // layout for mask.
-  // if (chunkSize > 1) {
-  //   if (layoutKind == xegpu::LayoutKind::InstData)
-  //     maskLayoutAttr =
-  //         xegpu::LayoutAttr::get(load->getContext(), {subgroupSize});
-  //   else if (layoutKind == xegpu::LayoutKind::Lane)
-  //     maskLayoutAttr =
-  //         xegpu::LayoutAttr::get(load->getContext(), {subgroupSize}, {1});
-  //   else
-  //     assert(false &&
-  //            "chunked StoreScatterOp should not be used at workgroup level");
-  // }
+  auto maskLayoutAttr = requiredAnchorLayoutAttr;
+  // Special handling mask layout for chunked ops: Enforce the default xegpu 1D
+  // layout for mask.
+  if (chunkSize > 1) {
+    if (layoutKind == xegpu::LayoutKind::InstData)
+      maskLayoutAttr =
+          xegpu::LayoutAttr::get(load->getContext(), {subgroupSize});
+    else if (layoutKind == xegpu::LayoutKind::Lane)
+      maskLayoutAttr =
+          xegpu::LayoutAttr::get(load->getContext(), {subgroupSize}, {1});
+    else
+      assert(false &&
+             "chunked StoreScatterOp should not be used at workgroup level");
+  }
 
   LayoutInfo maskLayoutInfo = LayoutInfo(maskLayoutAttr);
   auto loadLayoutInfo = LayoutInfo(requiredAnchorLayoutAttr);
@@ -1127,7 +1105,7 @@ void LayoutInfoPropagation::visitStoreScatterOp(
   const uArch *uArch = getUArch(getChipStr(storeScatter).value_or(""));
   if (!uArch)
     return;
-  // auto subgroupSize = uArch->getSubgroupSize();
+  auto subgroupSize = uArch->getSubgroupSize();
   VectorType srcVecTy = storeScatter.getValueType();
   int chunkSize = storeScatter.getChunkSize().value_or(1);
 
@@ -1144,26 +1122,22 @@ void LayoutInfoPropagation::visitStoreScatterOp(
   }
 
   LayoutInfo srcLayoutInfo = LayoutInfo(requiredAnchorLayoutAttr);
-  assert((chunkSize <= 1) || (layoutKind != xegpu::LayoutKind::Subgroup));
-  auto maskLayoutAttr = xegpu::inferMaskOffsetLayoutForScatterIO(
-      requiredAnchorLayoutAttr, chunkSize);
-  LayoutInfo maskLayoutInfo = LayoutInfo(maskLayoutAttr);
+  auto maskLayoutAttr = requiredAnchorLayoutAttr;
+  // Special handling mask layout for chunked ops: Enforce the default xegpu 1D
+  // layout for mask.
+  if (chunkSize > 1) {
+    if (layoutKind == xegpu::LayoutKind::InstData)
+      maskLayoutAttr =
+          xegpu::LayoutAttr::get(storeScatter->getContext(), {subgroupSize});
+    else if (layoutKind == xegpu::LayoutKind::Lane)
+      maskLayoutAttr = xegpu::LayoutAttr::get(storeScatter->getContext(),
+                                              {subgroupSize}, {1});
+    else
+      assert(false &&
+             "chunked StoreScatterOp should not be used at workgroup level");
+  }
 
-  // auto maskLayoutAttr = requiredAnchorLayoutAttr;
-  // // Special handling mask layout for chunked ops: Enforce the default xegpu
-  // 1D
-  // // layout for mask.
-  // if (chunkSize > 1) {
-  //   if (layoutKind == xegpu::LayoutKind::InstData)
-  //     maskLayoutAttr =
-  //         xegpu::LayoutAttr::get(storeScatter->getContext(), {subgroupSize});
-  //   else if (layoutKind == xegpu::LayoutKind::Lane)
-  //     maskLayoutAttr = xegpu::LayoutAttr::get(storeScatter->getContext(),
-  //                                             {subgroupSize}, {1});
-  //   else
-  //     assert(false &&
-  //            "chunked StoreScatterOp should not be used at workgroup level");
-  // }
+  LayoutInfo maskLayoutInfo = LayoutInfo(maskLayoutAttr);
 
   // Propagate the payload operand layout
   propagateIfChanged(operands[0], operands[0]->meet(srcLayoutInfo));
@@ -1374,22 +1348,6 @@ LogicalResult ResolveLayoutConflicts::run() {
         }
       }
     }
-    // if (isa<RegionBranchOpInterface>(op)) {
-    //   auto regionBranchOp = dyn_cast<RegionBranchOpInterface>(op);
-    //   unsigned numResults = regionBranchOp->getNumResults();
-    //   for (unsigned i = 0; i < numResults; ++i) {
-    //     OpResult result = regionBranchOp->getResult(i);
-    //     if (result.use_empty()) {
-    //       auto res = assignResultLayout(result);
-    //       if (failed(res)) {
-    //         DBGS() << "Failed to resolve vector consumer for loop/switch "
-    //                   "result with no use: "
-    //                << *op << "\n";
-    //         return WalkResult::interrupt();
-    //       }
-    //     }
-    //   }
-    // }
     for (OpOperand &operand : op->getOpOperands()) {
       // Handle conflicts in tensor descriptor operands.
       Type operandType = operand.get().getType();
@@ -1403,9 +1361,6 @@ LogicalResult ResolveLayoutConflicts::run() {
         }
       }
       // Handle conflicts in vector operands.
-      LLVM_DEBUG(DBGS() << "Handling vector operand #"
-                        << operand.getOperandNumber() << ": " << operand.get()
-                        << " in operation: " << *op << "\n");
       if (isa<VectorType>(operandType)) {
         auto res = resolveVectorConsumer(operand);
         if (failed(res)) {
@@ -1574,125 +1529,56 @@ static LogicalResult updateOp(mlir::OpBuilder &builder, mlir::Operation *op,
 static LogicalResult
 updateControlFlowOps(mlir::OpBuilder &builder,
                      mlir::RegionBranchTerminatorOpInterface terminator,
-                     GetLayoutFnTy getLayoutOfValue,
-                     xegpu::LayoutKind layoutKind) {
-  LLVM_DEBUG(DBGS() << "updateControlFlowOps: processing terminator: "
-                    << *terminator << "\n");
+                     GetLayoutFnTy getLayoutOfValue) {
   // Only process if the terminator is inside a region branch op.
   auto branchOp = dyn_cast<RegionBranchOpInterface>(terminator->getParentOp());
-  if (!branchOp) {
-    LLVM_DEBUG(
-        DBGS() << "  parent is not a RegionBranchOpInterface, skipping\n");
+  if (!branchOp)
     return success();
-  }
-  LLVM_DEBUG(DBGS() << "  parent branch op: " << *branchOp << "\n");
 
   RegionBranchSuccessorMapping mapping;
   branchOp.getSuccessorOperandInputMapping(mapping,
                                            RegionBranchPoint(terminator));
-  LLVM_DEBUG(DBGS() << "  successor mapping has " << mapping.size()
-                    << " entries\n");
   for (const auto &[successorOperand, successorInputs] : mapping) {
-    LLVM_DEBUG(DBGS() << "  processing successor operand: "
-                      << successorOperand->get()
-                      << " (type: " << successorOperand->get().getType()
-                      << "), num successor inputs: " << successorInputs.size()
-                      << "\n");
     for (Value successorInput : successorInputs) {
       Type inputType = successorInput.getType();
-      LLVM_DEBUG(DBGS() << "    successor input: " << successorInput
-                        << ", type: " << inputType << "\n");
       // We only need to operate on tensor descriptor or vector types.
-      if (!isa<xegpu::TensorDescType, VectorType>(inputType)) {
-        LLVM_DEBUG(
-            DBGS() << "    skipping: not a TensorDescType or VectorType\n");
+      if (!isa<xegpu::TensorDescType, VectorType>(inputType))
         continue;
-      }
-
-      // debug print successorInput and successorOperand
-      LLVM_DEBUG(DBGS() << "    successor input: " << successorInput << "\n");
-      LLVM_DEBUG(DBGS() << "    successor operand: " << successorOperand->get()
-                        << "\n");
-
       xegpu::DistributeLayoutAttr successorInputLayout =
           getLayoutOfValue(successorInput);
       xegpu::DistributeLayoutAttr successorOperandLayout =
           getLayoutOfValue(successorOperand->get());
 
-      LLVM_DEBUG(DBGS() << "    successor input layout: ");
-      LLVM_DEBUG(if (successorInputLayout) llvm::dbgs() << successorInputLayout;
-                 else llvm::dbgs() << "<<NULL>>"; llvm::dbgs() << "\n");
-      LLVM_DEBUG(DBGS() << "    successor operand layout: ");
-      LLVM_DEBUG(if (successorOperandLayout) llvm::dbgs()
-                     << successorOperandLayout;
-                 else llvm::dbgs() << "<<NULL>>"; llvm::dbgs() << "\n");
-
       // If either of the layouts is not assigned, we cannot proceed.
       if (!successorOperandLayout) {
-        LLVM_DEBUG(DBGS() << "    FAILURE: No layout assigned for forwarded "
-                             "operand in branch terminator: "
+        LLVM_DEBUG(DBGS() << "No layout assigned for forwarded operand in "
+                             "branch terminator: "
                           << successorOperand->get() << "\n");
         return failure();
       }
       // We expect the layouts to match.
       if (successorInputLayout &&
           successorInputLayout != successorOperandLayout) {
-        LLVM_DEBUG(DBGS() << "    FAILURE: Conflicting layouts for region "
-                             "argument and operand forwarded as the argument: "
+        LLVM_DEBUG(DBGS() << "Conflicting layouts for region argument and "
+                             "operand forwarded as the argument: "
                           << successorInputLayout << " vs "
                           << successorOperandLayout << "\n");
         return failure();
       }
       // Get tensor descriptor type with the layout.
       if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(inputType)) {
-        // if (successorInputLayout != successorOperandLayout) {
-        //   LLVM_DEBUG(DBGS()
-        //              << "    FAILURE: Conflicting layouts for region "
-        //                 "argument and operand forwarded as the argument: "
-        //              << successorInputLayout << " vs " <<
-        //              successorOperandLayout
-        //              << "\n");
-        //   return failure();
-        // }
         auto newTdescTy = xegpu::TensorDescType::get(
             tdescTy.getContext(), tdescTy.getShape(), tdescTy.getElementType(),
             tdescTy.getEncoding(), successorOperandLayout);
-        LLVM_DEBUG(DBGS() << "    updating tensor desc type: " << tdescTy
-                          << " -> " << newTdescTy << "\n");
         successorInput.setType(newTdescTy);
         continue;
       }
-
-      // if (auto vectorTy = dyn_cast<VectorType>(inputType)) {
-      //   SmallVector<int64_t> vectorShape(vectorTy.getShape().begin(),
-      //                                    vectorTy.getShape().end());
-      //   if (!successorInputLayout.isCompatibleWith(successorOperandLayout,
-      //                                              vectorShape, layoutKind))
-      //                                              {
-      //     LLVM_DEBUG(DBGS()
-      //                << "    FAILURE: Conflicting layouts for region "
-      //                   "argument and operand forwarded as the argument: "
-      //                << successorInputLayout << " vs " <<
-      //                successorOperandLayout
-      //                << "\n");
-      //     return failure();
-      //   }
-      // }
       // If the type is a vector type and this region argument is an OpResult,
       // set the layout attribute on the OpResult.
-      if (auto result = dyn_cast<OpResult>(successorInput)) {
-        LLVM_DEBUG(DBGS() << "    setting layout on OpResult #"
-                          << result.getResultNumber() << " of "
-                          << *result.getOwner() << " to "
-                          << successorOperandLayout << "\n");
+      if (auto result = dyn_cast<OpResult>(successorInput))
         xegpu::setDistributeLayoutAttr(result, successorOperandLayout);
-      } else {
-        LLVM_DEBUG(DBGS() << "    successor input is a BlockArgument, "
-                             "not setting layout attribute\n");
-      }
     }
   }
-  LLVM_DEBUG(DBGS() << "  updateControlFlowOps: success\n");
   return success();
 }
 
@@ -1759,7 +1645,8 @@ LogicalResult xegpu::propagateLayouts(OpBuilder &builder, Operation *target,
   // Helper to convert LayoutInfo to xegpu::LayoutAttr.
   auto getXeGPULayoutForValue = [&](Value val) -> xegpu::DistributeLayoutAttr {
     LayoutInfo layout = analysis.getLayoutInfo(val);
-
+    if (!layout.isAssigned())
+      return {};
     if (auto opResult = dyn_cast<OpResult>(val)) {
 
       Operation *defOp = opResult.getDefiningOp();
@@ -1773,8 +1660,6 @@ LogicalResult xegpu::propagateLayouts(OpBuilder &builder, Operation *target,
       if (requiredResLayoutAttr != nullptr)
         return requiredResLayoutAttr;
     }
-    if (!layout.isAssigned())
-      return {};
     xegpu::DistributeLayoutAttr layoutAttr =
         cast<xegpu::DistributeLayoutAttr>(layout.get());
     if (layout.isSliceLayout())
@@ -1782,32 +1667,15 @@ LogicalResult xegpu::propagateLayouts(OpBuilder &builder, Operation *target,
 
     return cast<xegpu::LayoutAttr>(layoutAttr);
   };
-  // dump the op before update
-  llvm::dbgs() << "Before layout propagation and conflict resolution:\n";
+
   Operation *op = target;
-  op->dump();
   auto walkResult = op->walk([&](mlir::Block *block) -> WalkResult {
     for (mlir::Operation &op : llvm::reverse(block->getOperations())) {
       LogicalResult r = success();
-
-      //  for (OpOperand &operand : op.getOpOperands()) {
-      //    Type operandType = operand.get().getType();
-      //    // We only need to operate on tensor descriptor or vector types.
-      //    if (!isa<xegpu::TensorDescType, VectorType>(operandType))
-      //      continue;
-      //    xegpu::DistributeLayoutAttr consumerLayout =
-      //        getXeGPULayoutForValue(operand.get());
-      //    // dump the layout info for debugging
-      //    // dump operation and operand info for debugging
-      //    DBGS() << "Processing operation: " << op << "\n";
-      //    llvm::outs() << "operand #" << operand.getOperandNumber() << " = "
-      //                 << operand.get() << "\n";
-      //    DBGS() << "Layout: " << consumerLayout << "\n";
-      //  }
       TypeSwitch<Operation *>(&op)
           .Case([&](mlir::RegionBranchTerminatorOpInterface branchTermOp) {
             r = updateControlFlowOps(builder, branchTermOp,
-                                     getXeGPULayoutForValue, layoutKind);
+                                     getXeGPULayoutForValue);
           })
           .Case([&](mlir::FunctionOpInterface funcOp) {
             r = updateFunctionOpInterface(builder, funcOp,
@@ -1823,10 +1691,6 @@ LogicalResult xegpu::propagateLayouts(OpBuilder &builder, Operation *target,
     }
     return WalkResult::advance();
   });
-  // dump the op after update, print some headline
-  llvm::dbgs() << "After layout propagation and conflict resolution:\n";
-  op->dump();
-
   if (walkResult.wasInterrupted())
     return failure();
 
@@ -1839,17 +1703,6 @@ LogicalResult xegpu::resolveLayoutConflicts(Operation *target) {
 }
 
 void XeGPUPropagateLayoutPass::runOnOperation() {
-  // Remove layout attributes from SCF ops
-  getOperation()->walk([](Operation *op) {
-    SmallVector<StringAttr> attrsToRemove;
-    for (auto namedAttr : op->getDiscardableAttrs()) {
-      if (isa<xegpu::DistributeLayoutAttr>(namedAttr.getValue()))
-        attrsToRemove.push_back(namedAttr.getName());
-    }
-    for (auto attrName : attrsToRemove)
-      op->removeDiscardableAttr(attrName);
-  });
-
   xegpu::LayoutKind layoutKind;
   if (this->layoutKind == "lane") {
     layoutKind = xegpu::LayoutKind::Lane;
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cp
new file mode 100644
index 0000000000000..e3227c7f5b149
--- /dev/null
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cp
@@ -0,0 +1,1744 @@
+//===- XeGPUSgToWiDistributeExperimental.cpp - XeGPU SG to WI Pass --------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/Index/IR/IndexDialect.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/Transforms/Patterns.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
+#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
+#include "mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h"
+#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
+#include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/Value.h"
+#include "mlir/IR/ValueRange.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/Support/LogicalResult.h"
+#include "llvm/Support/raw_ostream.h"
+#include <optional>
+
+namespace mlir {
+namespace xegpu {
+#define GEN_PASS_DEF_XEGPUSGTOWIDISTRIBUTEEXPERIMENTAL
+#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
+} // namespace xegpu
+} // namespace mlir
+
+using namespace mlir;
+
+#define DEBUG_TYPE "xegpu-sg-to-wi-distribute-experimental"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+
+namespace {
+
+/// Casts the given vector value `v` to the expected vector type `expectedTy`.
+static Value castValueTo(ConversionPatternRewriter &rewriter,
+                         TypedValue<VectorType> v, VectorType expectedTy) {
+  // If the type matches, simply return the value itself.
+  if (v.getType() == expectedTy)
+    return v;
+  // If only shape differs, use shape cast.
+  if (isa<VectorType>(v.getType()) &&
+      v.getType().getNumElements() == expectedTy.getNumElements())
+    return vector::ShapeCastOp::create(rewriter, v.getLoc(), expectedTy, v);
+
+  // Else create an unrealized cast.
+  auto newOp = UnrealizedConversionCastOp::create(rewriter, v.getLoc(),
+                                                  expectedTy, ValueRange{v});
+  return newOp.getResult(0);
+}
+
+/// A vector::MultiDimReductionOp at subgroup level is in the expected form if
+/// it has exactly 1 reduction dimension, it has a valid result layout
+/// attribute, and the result type can be distributed to lanes using the layout.
+static bool isValidSubgroupMultiReductionOp(vector::MultiDimReductionOp op) {
+  auto resLayout = xegpu::getTemporaryLayout(op->getOpResult(0));
+  // Without a subgroup-level result layout, the op is not valid.
+  if (!resLayout || !resLayout.isForSubgroup())
+    return false;
+  // A scalar result (e.g., vector<32xf32> to f32) only needs the dim check.
+  if (op.getType().isIntOrFloat())
+    return op.getReductionDims().size() == 1;
+  VectorType resTy = dyn_cast<VectorType>(op.getType());
+  if (!resTy)
+    return false;
+  // Compute the distributed result vector type based on the layout.
+  FailureOr<VectorType> resDistTypeOrFailure =
+      getDistVecTypeBasedOnLaneLayout(resLayout, resTy);
+  if (failed(resDistTypeOrFailure))
+    return false;
+  return op.getReductionDims().size() == 1;
+}
+
+/// A vector::MultiDimReductionOp is doing lane-local reduction if each workitem
+/// is doing its own local reduction. In this case the result layout ensures
+/// that result vector is distributed to lanes, i.e. the result vector type is
+/// different from the distributed result vector type.
+static bool isReductionLaneLocal(vector::MultiDimReductionOp op) {
+  // Must be valid MultiDimReductionOp.
+  assert(isValidSubgroupMultiReductionOp(op) && "Expecting a valid subgroup "
+                                                "MultiDimReductionOp");
+  auto resLayout = xegpu::getTemporaryLayout(op->getOpResult(0));
+  VectorType resTy = dyn_cast<VectorType>(op.getType());
+  auto resDistTypeOrFailure = getDistVecTypeBasedOnLaneLayout(resLayout, resTy);
+  return resTy != resDistTypeOrFailure.value();
+}
+
+/// Given a vector type and its distributed vector type, return the list of
+/// dimensions that are distributed.
+static SmallVector<int64_t> getDistributedDims(VectorType originalType,
+                                               VectorType distributedType) {
+  assert(originalType.getRank() == distributedType.getRank() &&
+         "original and distributed vector types must have the same rank");
+  SmallVector<int64_t> distributedDims;
+  for (int64_t i = 0; i < originalType.getRank(); ++i) {
+    if (distributedType.getDimSize(i) != originalType.getDimSize(i))
+      distributedDims.push_back(i);
+  }
+  return distributedDims;
+}
+
+/// Distributes a subgroup-level CreateNdDesc op to workitem-level CreateNdDesc
+/// op. This simply drops the layout attribute from the tensor descriptor type.
+struct SgToWiCreateNdDesc : public OpConversionPattern<xegpu::CreateNdDescOp> {
+  using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(xegpu::CreateNdDescOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    xegpu::TensorDescType resultType = op.getType();
+    // If no layout, nothing to do.
+    if (!resultType.getLayout())
+      return failure();
+
+    auto newOp = xegpu::CreateNdDescOp::create(
+        rewriter, op.getLoc(), resultType.dropLayouts(), op.getOperands(),
+        op->getAttrs());
+    rewriter.replaceOp(op, newOp.getResult());
+    return success();
+  }
+};
+
+/// Distributes a subgroup-level LoadNd op to workitem-level LoadNd op. Output
+/// of workitem-level LoadNd op is 1D. ShapeCast is added to restore the
+/// original rank.
+struct SgToWiLoadNd : public OpConversionPattern<xegpu::LoadNdOp> {
+  using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(xegpu::LoadNdOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    xegpu::DistributeLayoutAttr layout = op.getAnchorLayout();
+    // If no layout, nothing to do.
+    if (!layout)
+      return failure();
+    // Check if the layout attached to the tensor descriptor is same as the
+    // anchor layout. Otherwise, this is a conflict.
+    if (op.getTensorDescType().getLayout() != layout)
+      return rewriter.notifyMatchFailure(
+          op, "conflicting layout attributes on tensor descriptor and anchor");
+    auto uArch = getUArch(xegpu::getChipStr(op).value_or(""));
+    if (!uArch)
+      return rewriter.notifyMatchFailure(
+          op, "xegpu::LoadNdOp require target attribute attached to "
+              "determine transpose "
+              "requirement");
+    auto supportedWiResultTyOrFailure =
+        xegpu::getDistributedVectorType(op.getTensorDescType());
+    auto expectedWiResultTyOrFailure =
+        xegpu::getDistVecTypeBasedOnLaneLayout(layout, op.getType());
+    if (failed(supportedWiResultTyOrFailure))
+      return rewriter.notifyMatchFailure(
+          op, "unable to compute the workitem vector type for LoadNdOp");
+    if (failed(expectedWiResultTyOrFailure))
+      return rewriter.notifyMatchFailure(
+          op,
+          "unable to compute expected workitem vector type from lane layout");
+    auto newOp = xegpu::LoadNdOp::create(
+        rewriter, op.getLoc(), supportedWiResultTyOrFailure.value(),
+        adaptor.getTensorDesc(), op.getMixedOffsets(), op.getPackedAttr(),
+        op.getTransposeAttr(), op.getL1HintAttr(), op.getL2HintAttr(),
+        op.getL3HintAttr(), /**layout**/ nullptr);
+    // Set the packed attribute if the layout requires it.
+    newOp.setPacked(xegpu::requirePacked(cast<xegpu::LayoutAttr>(layout)));
+    // Set the transpose attribute if the layout requires it.
+    if (xegpu::requireTranspose(cast<xegpu::LayoutAttr>(layout), uArch))
+      newOp.setTranspose(DenseI64ArrayAttr::get(rewriter.getContext(), {1, 0}));
+    rewriter.replaceOp(op, castValueTo(rewriter, newOp.getResult(),
+                                       expectedWiResultTyOrFailure.value()));
+    return success();
+  }
+};
+
+/// Distributes a subgroup-level StoreNd op to workitem-level StoreNd op. Stored
+/// value in workitem-level StoreNd op is 1D. ShapeCast is added to cast the
+/// incoming value to 1D.
+struct SgToWiStoreNd : public OpConversionPattern<xegpu::StoreNdOp> {
+  using OpConversionPattern<xegpu::StoreNdOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(xegpu::StoreNdOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    xegpu::DistributeLayoutAttr layout = op.getAnchorLayout();
+    // If no layout, nothing to do.
+    if (!layout)
+      return failure();
+    // Check if the layout attached to the tensor descriptor and value layout is
+    // same as the anchor layout. Otherwise, this is a conflict.
+    if (op.getTensorDescType().getLayout() != layout)
+      return rewriter.notifyMatchFailure(
+          op, "conflicting layout attributes on tensor descriptor and anchor");
+    auto valueLayout = xegpu::getDistributeLayoutAttr(op->getOpOperand(0));
+    if (valueLayout != layout)
+      return rewriter.notifyMatchFailure(
+          op, "conflicting layout attributes on value and anchor");
+    auto supportedWiValueTyOrFailure =
+        xegpu::getDistributedVectorType(op.getTensorDescType());
+    if (failed(supportedWiValueTyOrFailure))
+      return rewriter.notifyMatchFailure(
+          op,
+          "unable to compute wi vector type for StoreNdOp value from tensor "
+          "descriptor");
+
+    xegpu::StoreNdOp::create(
+        rewriter, op.getLoc(),
+        castValueTo(rewriter, cast<TypedValue<VectorType>>(adaptor.getValue()),
+                    supportedWiValueTyOrFailure.value()),
+        adaptor.getTensorDesc(), op.getMixedOffsets(), op.getL1HintAttr(),
+        op.getL2HintAttr(), op.getL3HintAttr(), /**layout**/ nullptr);
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+/// Distributes a subgroup-level Dpas op to workitem-level Dpas op. All inputs
+/// and output of workitem-level Dpas op are 1D. Necessary casts are added to
+/// convert the inputs and output to/from 1D.
+struct SgToWiDpas : public OpConversionPattern<xegpu::DpasOp> {
+  using OpConversionPattern<xegpu::DpasOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(xegpu::DpasOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // The A, B and CD layouts may be missing or of another layout kind. Use a
+    // null-tolerant cast so the check below rejects the op instead of asserting.
+    auto layoutA = dyn_cast_if_present<xegpu::LayoutAttr>(op.getLayoutAAttr());
+    auto layoutB = dyn_cast_if_present<xegpu::LayoutAttr>(op.getLayoutBAttr());
+    auto layoutCd =
+        dyn_cast_if_present<xegpu::LayoutAttr>(op.getLayoutCdAttr());
+    if (!layoutA || !layoutB || !layoutCd)
+      return failure();
+    // Compute the workitem-level operand and result types from the layouts.
+    auto wiResultTyOrFailure =
+        xegpu::getDistributedVectorType(op.getType(), layoutCd);
+    auto wiATypeOrFailure =
+        xegpu::getDistributedVectorType(op.getLhs().getType(), layoutA);
+    auto wiBTypeOrFailure =
+        xegpu::getDistributedVectorType(op.getRhs().getType(), layoutB);
+    auto expectedWiResultTyOrFailure =
+        xegpu::getDistVecTypeBasedOnLaneLayout(layoutCd, op.getType());
+    if (failed(wiResultTyOrFailure) || failed(wiATypeOrFailure) ||
+        failed(wiBTypeOrFailure))
+      return rewriter.notifyMatchFailure(
+          op, "failed to calculate supported workitem vector types for DpasOp "
+              "from layouts");
+    if (failed(expectedWiResultTyOrFailure))
+      return rewriter.notifyMatchFailure(
+          op, "unable to compute expected workitem vector type for DpasOp from "
+              "lane layout");
+    auto newOp = xegpu::DpasOp::create(
+        rewriter, op->getLoc(), wiResultTyOrFailure.value(),
+        castValueTo(rewriter, cast<TypedValue<VectorType>>(adaptor.getLhs()),
+                    wiATypeOrFailure.value()),
+        castValueTo(rewriter, cast<TypedValue<VectorType>>(adaptor.getRhs()),
+                    wiBTypeOrFailure.value()),
+        castValueTo(rewriter, cast<TypedValue<VectorType>>(adaptor.getAcc()),
+                    wiResultTyOrFailure.value()),
+        /*layoutA=*/nullptr, /*layoutB=*/nullptr, /*layoutCd=*/nullptr);
+    // Cast the new result to the type expected from the lane layout.
+    rewriter.replaceOp(op, castValueTo(rewriter, newOp.getResult(),
+                                       expectedWiResultTyOrFailure.value()));
+    return success();
+  }
+};
+
+/// Distributes elementwise ops to workitem-level elementwise ops. This
+/// currently handles elementwise ops with single result only.
+struct SgToWiElementWise : public ConversionPattern {
+  SgToWiElementWise(TypeConverter &typeConverter, MLIRContext *ctx)
+      : ConversionPattern(MatchAnyOpTypeTag(), /*benefit=*/1, ctx) {}
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Only match ops with elementwise trait and single result.
+    if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1)
+      return failure();
+
+    auto resultType = dyn_cast<VectorType>(op->getResult(0).getType());
+    if (!resultType)
+      return rewriter.notifyMatchFailure(
+          op, "operation result is not a vector type");
+
+    xegpu::DistributeLayoutAttr layout =
+        xegpu::getTemporaryLayout(llvm::cast<OpResult>(op->getResult(0)));
+    if (!layout || !layout.isForSubgroup())
+      return rewriter.notifyMatchFailure(
+          op, "operation result does not have subgroup distribute layout");
+
+    auto wiShapeOrFailure =
+        xegpu::getDistVecTypeBasedOnLaneLayout(layout, resultType);
+
+    if (failed(wiShapeOrFailure))
+      return rewriter.notifyMatchFailure(
+          op, "unable to compute workitem vector type from the layout");
+
+    VectorType newResultType = wiShapeOrFailure.value();
+    OperationState state(op->getLoc(), op->getName());
+    state.addOperands(operands);
+    state.addTypes(newResultType);
+    // Copy all attributes except for DistributeLayoutAttr.
+    for (auto attr : op->getAttrs()) {
+      if (!isa<xegpu::DistributeLayoutAttr>(attr.getValue()))
+        state.addAttribute(attr.getName(), attr.getValue());
+    }
+    Operation *newOp = rewriter.create(state);
+
+    rewriter.replaceOp(op, newOp->getResult(0));
+    return success();
+  }
+};
+
+/// Distributes a subgroup-level arith ConstantOp to workitem-level arith
+/// ConstantOp.
+struct SgToWiArithConstant : public OpConversionPattern<arith::ConstantOp> {
+  using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto resultType = dyn_cast<VectorType>(op.getType());
+    if (!resultType)
+      return failure();
+
+    // Only handle splat vector constants; other constants are rejected.
+    auto dense = dyn_cast<SplatElementsAttr>(op.getValue());
+    if (!dense)
+      return rewriter.notifyMatchFailure(
+          op, "only dense splat vector constants are supported");
+
+    xegpu::DistributeLayoutAttr layout =
+        xegpu::getTemporaryLayout(llvm::cast<OpResult>(op.getResult()));
+    if (!layout || !layout.isForSubgroup())
+      return rewriter.notifyMatchFailure(
+          op, "operation result does not have subgroup distribute layout");
+
+    auto wiShapeOrFailure =
+        xegpu::getDistVecTypeBasedOnLaneLayout(layout, resultType);
+
+    if (failed(wiShapeOrFailure))
+      return rewriter.notifyMatchFailure(
+          op, "unable to compute workitem vector type from the layout");
+
+    VectorType newResultType = wiShapeOrFailure.value();
+    auto sclarValue = dense.getSplatValue<Attribute>();
+    auto newDenseAttr = DenseElementsAttr::get(newResultType, sclarValue);
+
+    auto newOp = arith::ConstantOp::create(rewriter, op.getLoc(), newResultType,
+                                           newDenseAttr);
+    rewriter.replaceOp(op, newOp.getResult());
+    return success();
+  }
+};
+
+/// Distributes a subgroup-level PrefetchNd op to workitem-level PrefetchNd op.
+struct SgToWiPrefetchNd : public OpConversionPattern<xegpu::PrefetchNdOp> {
+  using OpConversionPattern<xegpu::PrefetchNdOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(xegpu::PrefetchNdOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    xegpu::DistributeLayoutAttr layout = op.getAnchorLayout();
+    // If no layout, nothing to do.
+    if (!layout)
+      return failure();
+
+    xegpu::PrefetchNdOp::create(rewriter, op.getLoc(), adaptor.getTensorDesc(),
+                                op.getMixedOffsets(), op.getL1HintAttr(),
+                                op.getL2HintAttr(), op.getL3HintAttr(),
+                                /**layout**/ nullptr);
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+/// Distributes a subgroup-level LoadGather (xegpu.load) op to workitem-level.
+///
+/// Example 1 (1D, no chunk size):
+///   layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>
+///   %mask = producer_op : vector<16xi1>
+///   %offset = producer_op : vector<16xindex>
+///   %0 = xegpu.load %src[%offset], %mask : memref<256xf16>,
+///     vector<16xindex>, vector<16xi1> -> vector<16xf16>
+/// Distributed to:
+///   %mask = producer_op : vector<1xi1>
+///   %offset = producer_op : vector<1xindex>
+///   %0 = xegpu.load %src[%offset], %mask : memref<256xf16>,
+///     vector<1xindex>, vector<1xi1> -> vector<1xf16>
+///
+/// Example 2 (2D with chunk size, same mask & offset):
+///   layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>
+///   %0 = xegpu.load %src[%offset], %mask <{chunk_size=8}> :
+///     memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
+/// Distributed to:
+///   %0 = xegpu.load %src[%offset], %mask <{chunk_size=8}> :
+///     memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<8xf16>
+///
+/// Example 3 (3D with leading unit dims):
+///   layout = #xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>
+///   %mask = producer_op : vector<1x1x16xi1>
+///   %offset = producer_op : vector<1x1x16xindex>
+///   %0 = xegpu.load %src[%offset], %mask : memref<256xf16>,
+///     vector<1x1x16xindex>, vector<1x1x16xi1> -> vector<1x1x16xf16>
+/// Distributed to:
+///   %mask = producer_op : vector<1x1x1xi1>
+///   %offset = producer_op : vector<1x1x1xindex>
+///   %0 = xegpu.load %src[%offset], %mask : memref<256xf16>,
+///     vector<1xindex>, vector<1xi1> -> vector<1xf16>
+struct SgToWiLoadGather : public OpConversionPattern<xegpu::LoadGatherOp> {
+  using OpConversionPattern<xegpu::LoadGatherOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(xegpu::LoadGatherOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    xegpu::DistributeLayoutAttr layout = op.getAnchorLayout();
+    if (!layout)
+      return failure();
+
+    VectorType origResultTy = op.getValueType();
+    if (!origResultTy)
+      return failure();
+
+    // Check that leading dimensions are unit.
+    int chunkSize = op.getChunkSize().value_or(1);
+    int effectiveVecRank = (chunkSize == 1) ? 1 : 2;
+    ArrayRef<int64_t> shape = origResultTy.getShape();
+    if (llvm::any_of(
+            shape.take_front(origResultTy.getRank() - effectiveVecRank),
+            [](int64_t d) { return d != 1; }))
+      return rewriter.notifyMatchFailure(
+          op, "Only unit dimensions allowed for the leading "
+              "dimensions of the load vector!");
+
+    auto distResultTyOrFailure =
+        xegpu::getDistVecTypeBasedOnLaneLayout(layout, origResultTy);
+    if (failed(distResultTyOrFailure))
+      return rewriter.notifyMatchFailure(
+          op,
+          "unable to compute expected workitem vector type from lane layout");
+
+    VectorType distResultTy = distResultTyOrFailure.value();
+    VectorType distResultTy1D = VectorType::get({distResultTy.getNumElements()},
+                                                distResultTy.getElementType());
+
+    // Flatten offsets and mask to 1D to match the 1D result type.
+    Value distOffsets = adaptor.getOffsets();
+    auto distOffsetsTy = cast<VectorType>(distOffsets.getType());
+    VectorType offsetsTy1D = VectorType::get({distOffsetsTy.getNumElements()},
+                                             distOffsetsTy.getElementType());
+    distOffsets = castValueTo(
+        rewriter, cast<TypedValue<VectorType>>(distOffsets), offsetsTy1D);
+
+    Value distMask = adaptor.getMask();
+    auto distMaskTy = cast<VectorType>(distMask.getType());
+    VectorType maskTy1D = VectorType::get({distMaskTy.getNumElements()},
+                                          distMaskTy.getElementType());
+    distMask =
+        castValueTo(rewriter, cast<TypedValue<VectorType>>(distMask), maskTy1D);
+
+    Value distSource = adaptor.getSource();
+    auto newOp = xegpu::LoadGatherOp::create(
+        rewriter, op.getLoc(), distResultTy1D, distSource, distOffsets,
+        distMask, op.getChunkSizeAttr(), op.getL1HintAttr(), op.getL2HintAttr(),
+        op.getL3HintAttr(), /*layout=*/nullptr);
+
+    Value result = newOp->getResult(0);
+    if (distResultTy1D != distResultTy)
+      result = castValueTo(rewriter, cast<TypedValue<VectorType>>(result),
+                           distResultTy);
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
+/// This pattern distributes a subgroup-level vector.reduction op to
+/// workitem-level. This requires shuffling the data across the workitems (using
+/// gpu::ShuffleOp) and reducing in stages until all workitems have the final
+/// result.
+struct SgToWiVectorReduction : public OpConversionPattern<vector::ReductionOp> {
+  using OpConversionPattern<vector::ReductionOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::ReductionOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto layout = xegpu::getDistributeLayoutAttr(op.getVector());
+
+    // Without a subgroup-level layout there is nothing to distribute.
+    if (!layout || !layout.isForSubgroup())
+      return failure();
+
+    VectorType srcVecType = op.getSourceVectorType();
+    // Only rank 1 vectors supported.
+    if (srcVecType.getRank() != 1)
+      return rewriter.notifyMatchFailure(
+          op, "Only rank 1 reductions can be distributed.");
+    // Lane layout must have the same rank as the vector.
+    if (layout.getRank() != srcVecType.getRank())
+      return rewriter.notifyMatchFailure(
+          op, "Layout rank does not match vector rank.");
+
+    // Get the subgroup size from the layout.
+    int64_t sgSize = layout.getEffectiveLaneLayoutAsInt()[0];
+    const uArch *uArch = getUArch(xegpu::getChipStr(op).value_or(""));
+    if (!uArch)
+      return rewriter.notifyMatchFailure(
+          op, "xegpu::ReductionOp require target attribute attached to "
+              "determine subgroup size");
+
+    // Lane layout must equal the subgroup size and divide the vector length.
+    if (sgSize != uArch->getSubgroupSize() ||
+        srcVecType.getShape()[0] % sgSize != 0)
+      return rewriter.notifyMatchFailure(op,
+                                         "Invalid layout or reduction vector "
+                                         "dimension must match subgroup size.");
+
+    if (!op.getType().isIntOrFloat())
+      return rewriter.notifyMatchFailure(
+          op, "Reduction distribution currently only supports floats and "
+              "integer types.");
+
+    // Get the distributed vector (per work-item portion).
+    Value laneValVec = adaptor.getVector();
+
+    // Distribute and reduce across work-items in the subgroup.
+    Value fullReduce = xegpu::subgroupReduction(
+        op.getLoc(), rewriter, laneValVec, op.getKind(), sgSize);
+
+    // If there's an accumulator, combine it with the reduced value.
+    if (adaptor.getAcc())
+      fullReduce = vector::makeArithReduction(
+          rewriter, op.getLoc(), op.getKind(), fullReduce, adaptor.getAcc());
+
+    rewriter.replaceOp(op, fullReduce);
+    return success();
+  }
+};
+
+/// This pattern distributes a subgroup-level vector.multi_reduction op to
+/// workitem-level only if the reduction is lane-local. This means that
+/// reduction dimension is not distributed to lanes and each lane does its own
+/// local reduction.
+struct SgToWiMultiDimReduction
+    : public OpConversionPattern<vector::MultiDimReductionOp> {
+  using OpConversionPattern<vector::MultiDimReductionOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::MultiDimReductionOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Value result;
+    ArrayRef<int64_t> reductionDims = op.getReductionDims();
+    assert(reductionDims.size() == 1 &&
+           "Expecting single reduction dimension for subgroup multi "
+           "reduction op");
+    // For rank > 2, ensure leading dimensions are unit.
+    VectorType sourceType = op.getSourceVectorType();
+    int64_t rank = sourceType.getRank();
+    if (rank > 2) {
+      ArrayRef<int64_t> shape = sourceType.getShape();
+      if (llvm::any_of(shape.take_front(rank - 2),
+                       [](int64_t d) { return d != 1; }))
+        return rewriter.notifyMatchFailure(
+            op, "only unit leading dimensions are supported for "
+                "multi_reduction with rank > 2");
+    }
+    // Handle scalar result: full reduction of a distributed vector to a
+    // scalar. First do a local vector reduction, then cross-lane shuffles.
+    if (op.getType().isIntOrFloat()) {
+      auto reductionDim = reductionDims[0];
+      VectorType origSourceType = op.getSourceVectorType();
+      int64_t reductionDimSize = origSourceType.getShape()[reductionDim];
+      // Local reduction to scalar, then cross-lane butterfly shuffles.
+      result =
+          xegpu::subgroupReduction(op.getLoc(), rewriter, adaptor.getSource(),
+                                   op.getKind(), reductionDimSize);
+      // Combine with accumulator if present.
+      if (adaptor.getAcc())
+        result = vector::makeArithReduction(rewriter, op.getLoc(), op.getKind(),
+                                            result, adaptor.getAcc());
+    } else if (isReductionLaneLocal(op)) {
+      auto resLayout = xegpu::getTemporaryLayout(op->getOpResult(0));
+      VectorType resVecTy = dyn_cast<VectorType>(op.getType());
+      auto resDistVecTyOrFailure =
+          getDistVecTypeBasedOnLaneLayout(resLayout, resVecTy);
+      // For lane local reduction, simply create a new MultiDimReductionOp using
+      // adaptor operands and the new result type.
+      result = vector::MultiDimReductionOp::create(
+          rewriter, op.getLoc(), resDistVecTyOrFailure.value(), op.getKind(),
+          adaptor.getSource(), adaptor.getAcc(), op.getReductionDims());
+    } else {
+      auto reductionDim = reductionDims[0];
+      VectorType sourceType = op.getSourceVectorType();
+      int64_t reductionDimSize = sourceType.getShape()[reductionDim];
+      result = xegpu::lowerCrossLaneReductionToShuffles(
+          cast<TypedValue<VectorType>>(adaptor.getSource()),
+          cast<TypedValue<VectorType>>(adaptor.getAcc()), op.getKind(),
+          reductionDim, reductionDimSize, op.getLoc(), rewriter);
+    }
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
+/// Helper to compute distributed coordinates for matrix ops.
+/// When not using subgroup_block_io, each workitem computes its own
+/// coordinates based on the layout and lane ID.
+static SmallVector<Value> computeDistributedCoordsForMatrixOp(
+    ConversionPatternRewriter &rewriter, Location loc,
+    xegpu::DistributeLayoutAttr layout, ArrayRef<int64_t> payloadShape,
+    ValueRange origOffsets) {
+  Value laneId = gpu::LaneIdOp::create(rewriter, loc, rewriter.getIndexType(),
+                                       /*upperBound=*/mlir::IntegerAttr());
+  auto maybeCoords =
+      layout.computeDistributedCoords(rewriter, loc, laneId, payloadShape);
+  if (failed(maybeCoords))
+    return {};
+  assert(maybeCoords.value().size() == 1 &&
+         "Expected one set of distributed offsets");
+  SmallVector<OpFoldResult> ofrVec = xegpu::addWithRightAligned(
+      rewriter, loc, getAsOpFoldResult(maybeCoords.value()[0]),
+      getAsOpFoldResult(origOffsets));
+  return llvm::map_to_vector(ofrVec, llvm::CastTo<Value>);
+}
+
+/// This pattern distributes a subgroup-level LoadMatrix op to workitem-level.
+struct SgToWiLoadMatrix : public OpConversionPattern<xegpu::LoadMatrixOp> {
+  using OpConversionPattern<xegpu::LoadMatrixOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(xegpu::LoadMatrixOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto layout = op.getLayoutAttr();
+    // If no layout, nothing to do.
+    if (!layout)
+      return failure();
+
+    VectorType sgPayloadTy = dyn_cast<VectorType>(op.getResult().getType());
+    if (!sgPayloadTy)
+      return rewriter.notifyMatchFailure(
+          op, "the matrix op payload must be a vector type");
+
+    auto loc = op.getLoc();
+    auto offsets = op.getMixedOffsets();
+    if (offsets.empty())
+      return rewriter.notifyMatchFailure(op, "the load op must have offsets");
+
+    FailureOr<VectorType> distPayloadTyOrFailure =
+        getDistVecTypeBasedOnLaneLayout(layout, sgPayloadTy);
+    if (failed(distPayloadTyOrFailure))
+      return rewriter.notifyMatchFailure(
+          op, "Failed to distribute matrix op payload based on layout.");
+
+    SmallVector<Value> offsetsAsValues =
+        vector::getAsValues(rewriter, loc, offsets);
+
+    SmallVector<Value> newCoords = offsetsAsValues;
+    if (!op.getSubgroupBlockIoAttr()) {
+      newCoords = computeDistributedCoordsForMatrixOp(
+          rewriter, loc, layout, sgPayloadTy.getShape(), offsetsAsValues);
+      if (newCoords.empty())
+        return rewriter.notifyMatchFailure(
+            op, "Failed to compute distributed coordinates.");
+    }
+
+    SmallVector<int64_t> newConstOffsets(op.getConstOffsets().size(),
+                                         ShapedType::kDynamic);
+    DenseI64ArrayAttr newConstOffsetsAttr =
+        rewriter.getDenseI64ArrayAttr(newConstOffsets);
+
+    auto newOp = xegpu::LoadMatrixOp::create(
+        rewriter, loc, *distPayloadTyOrFailure, adaptor.getMemDesc(),
+        ValueRange(newCoords), newConstOffsetsAttr, op.getSubgroupBlockIoAttr(),
+        xegpu::DistributeLayoutAttr{});
+    rewriter.replaceOp(op, newOp.getResult());
+    return success();
+  }
+};
+
+/// Lowers a subgroup-level vector.transpose to workitem level.
+///
+/// The pattern requires a temporary layout on both the source operand and the
+/// result, and the result layout must be the lane-level transpose of the
+/// source layout under the op's permutation. The new transpose consumes the
+/// converted source; its result is then cast to the distributed result type
+/// derived from the result layout.
+struct SgToWiVectorTranspose : public OpConversionPattern<vector::TransposeOp> {
+  using OpConversionPattern<vector::TransposeOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::TransposeOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    xegpu::DistributeLayoutAttr srcLayout =
+        xegpu::getTemporaryLayout(op->getOpOperand(0));
+    xegpu::DistributeLayoutAttr resLayout =
+        xegpu::getTemporaryLayout(op->getOpResult(0));
+    const bool bothLayoutsPresent = srcLayout && resLayout;
+    if (!bothLayoutsPresent)
+      return rewriter.notifyMatchFailure(
+          op, "the source or result vector of the transpose op lacks layout "
+              "attribute");
+
+    // The lane-level result layout has to be the permuted source layout.
+    ArrayRef<int64_t> permutation = op.getPermutation();
+    if (!resLayout.isTransposeOf(srcLayout, permutation,
+                                 xegpu::LayoutKind::Lane))
+      return rewriter.notifyMatchFailure(
+          op, "the source or result vector layouts must be transposes of "
+              "each other");
+
+    FailureOr<VectorType> wiResultTy =
+        getDistVecTypeBasedOnLaneLayout(resLayout, op.getResultVectorType());
+    if (failed(wiResultTy))
+      return rewriter.notifyMatchFailure(
+          op, "Failed to distribute the result vector type in "
+              "vector::Transpose op");
+
+    // Transpose the converted source, then reconcile the inferred result type
+    // with the type computed from the layout.
+    auto wiTranspose = vector::TransposeOp::create(
+        rewriter, op.getLoc(), adaptor.getVector(), permutation);
+    rewriter.replaceOp(
+        op, castValueTo(rewriter, wiTranspose.getResult(), *wiResultTy));
+    return success();
+  }
+};
+
+/// Lowers a subgroup-level vector.bitcast to workitem level.
+///
+/// A bitcast only changes the innermost dimension of the source/result
+/// vectors. The pattern computes the distributed result type from the
+/// result's temporary layout and re-creates the bitcast on the converted
+/// source with that type.
+struct SgToWiVectorBitcast : public OpConversionPattern<vector::BitCastOp> {
+  using OpConversionPattern<vector::BitCastOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::BitCastOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    xegpu::DistributeLayoutAttr resLayout =
+        xegpu::getTemporaryLayout(op->getOpResult(0));
+    if (!resLayout)
+      return rewriter.notifyMatchFailure(
+          op, "result vector of the bitcast op lacks layout attribute");
+
+    FailureOr<VectorType> wiResultTy =
+        getDistVecTypeBasedOnLaneLayout(resLayout, op.getResultVectorType());
+    if (failed(wiResultTy))
+      return rewriter.notifyMatchFailure(
+          op, "Failed to distribute the result vector type in "
+              "vector::BitCast op");
+
+    auto wiBitcast = vector::BitCastOp::create(rewriter, op.getLoc(),
+                                               *wiResultTy, adaptor.getSource());
+    rewriter.replaceOp(op, wiBitcast.getResult());
+    return success();
+  }
+};
+
+/// Distributes a subgroup-level vector.create_mask or vector.constant_mask op
+/// to workitem-level. Uses `computeDistributedCoords()` to obtain the
+/// coordinates each workitem owns, then compares each coordinate against the
+/// original mask bounds using `arith.cmpi slt`. The per-element boolean
+/// results are assembled into the distributed mask vector.
+///
+/// For multi-dimensional masks, the element is in-bounds when ALL dimensions
+/// satisfy `coord[i] < bound[i]`.
+///
+/// For vector.create_mask the bounds are taken from the adaptor, i.e. from the
+/// operands as remapped by the conversion driver. For vector.constant_mask
+/// they are materialized as `arith.constant` index values.
+///
+/// Example (1D):
+///   layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>
+///   %mask = vector.create_mask %m0 : vector<16xi1>
+/// For lane k, computeDistributedCoords gives coord = [k], so:
+///   %in_bounds = arith.cmpi slt, %coord, %m0  →  i1
+///   %mask = vector.broadcast %in_bounds : i1 to vector<1xi1>
+///
+/// Example (2D):
+///   layout = #xegpu.layout<lane_layout = [8, 2], lane_data = [1, 1]>
+///   %mask = vector.create_mask %m0, %m1 : vector<8x4xi1>
+/// Each WI owns a 1x2 slice. computeDistributedCoords returns 2 coords:
+///   [[r0, c0], [r0, c1]]
+/// For each coord: in_bounds = (r < m0) && (c < m1)
+///   %mask = vector.from_elements %bit0, %bit1 : vector<1x2xi1>
+template <typename OpType,
+          typename = std::enable_if_t<llvm::is_one_of<
+              OpType, vector::CreateMaskOp, vector::ConstantMaskOp>::value>>
+struct SgToWiCreateMask : public OpConversionPattern<OpType> {
+  using OpConversionPattern<OpType>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    xegpu::DistributeLayoutAttr layout =
+        xegpu::getTemporaryLayout(op->getOpResult(0));
+    if (!layout || !layout.isForSubgroup())
+      return rewriter.notifyMatchFailure(
+          op, "operation result does not have subgroup distribute layout");
+
+    VectorType origType = op.getType();
+    FailureOr<VectorType> distTypeOrFailure =
+        getDistVecTypeBasedOnLaneLayout(layout, origType);
+    if (failed(distTypeOrFailure))
+      return rewriter.notifyMatchFailure(
+          op, "unable to compute workitem vector type from the layout");
+
+    VectorType distType = distTypeOrFailure.value();
+    Location loc = op.getLoc();
+
+    // Materialize the original mask bounds as Values.
+    SmallVector<Value> origBounds;
+    if constexpr (std::is_same_v<OpType, vector::CreateMaskOp>) {
+      // Read the bounds through the adaptor so that the operands as remapped
+      // by the conversion driver are used, not the original op's operands.
+      ValueRange remappedBounds = adaptor.getOperands();
+      origBounds.append(remappedBounds.begin(), remappedBounds.end());
+    } else {
+      auto dimSizes = op.getMaskDimSizesAttr().asArrayRef();
+      for (auto dimSize : dimSizes)
+        origBounds.push_back(
+            arith::ConstantIndexOp::create(rewriter, loc, dimSize).getResult());
+    }
+
+    ArrayRef<int64_t> origShape = origType.getShape();
+
+    // Use computeDistributedCoords to get the coordinates each WI owns.
+    Value laneId = gpu::LaneIdOp::create(rewriter, loc, rewriter.getIndexType(),
+                                         /*upperBound=*/mlir::IntegerAttr());
+    auto maybeCoordsVec =
+        layout.computeDistributedCoords(rewriter, loc, laneId, origShape);
+    if (failed(maybeCoordsVec))
+      return rewriter.notifyMatchFailure(
+          op, "failed to compute distributed coordinates from layout");
+
+    // Bind by reference: the nested vectors are only read below, so there is
+    // no need to deep-copy them.
+    const auto &coordsVec = *maybeCoordsVec;
+    int64_t numElements = distType.getNumElements();
+    assert(static_cast<int64_t>(coordsVec.size()) == numElements &&
+           "number of coordinate sets must match number of distributed "
+           "elements");
+
+    // For each element, compare all coordinates against bounds.
+    Value trueVal =
+        arith::ConstantIntOp::create(rewriter, loc, /*value=*/1, /*width=*/1);
+    SmallVector<Value> maskBits;
+    maskBits.reserve(coordsVec.size());
+    for (const auto &coords : coordsVec) {
+      // Every coordinate is paired with the bound of the same dimension.
+      assert(coords.size() == origBounds.size() &&
+             "coordinate rank must match the number of mask bounds");
+      Value inBounds = trueVal;
+      for (size_t i = 0; i < coords.size(); ++i) {
+        Value cmp = arith::CmpIOp::create(
+            rewriter, loc, arith::CmpIPredicate::slt, coords[i], origBounds[i]);
+        inBounds = arith::AndIOp::create(rewriter, loc, inBounds, cmp);
+      }
+      maskBits.push_back(inBounds);
+    }
+
+    // Build the distributed mask vector: a single bit is broadcast, several
+    // bits are assembled element by element.
+    Value result;
+    if (numElements == 1) {
+      result =
+          vector::BroadcastOp::create(rewriter, loc, distType, maskBits[0]);
+    } else {
+      result =
+          vector::FromElementsOp::create(rewriter, loc, distType, maskBits);
+    }
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
+/// Lowers a subgroup-level xegpu.store_matrix to workitem level.
+///
+/// The payload is cast to the vector type obtained by distributing it
+/// according to the op's layout attribute. Unless the op carries the
+/// subgroup-block-io attribute, the offsets are replaced by the coordinates
+/// returned by `computeDistributedCoordsForMatrixOp`. All constant offsets of
+/// the new op are marked dynamic and the new op carries no layout.
+struct SgToWiStoreMatrix : public OpConversionPattern<xegpu::StoreMatrixOp> {
+  using OpConversionPattern<xegpu::StoreMatrixOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(xegpu::StoreMatrixOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Without a layout there is nothing to distribute.
+    auto layout = op.getLayoutAttr();
+    if (!layout)
+      return failure();
+
+    auto sgPayloadTy = dyn_cast<VectorType>(op.getData().getType());
+    if (!sgPayloadTy)
+      return rewriter.notifyMatchFailure(
+          op, "the matrix op payload must be a vector type");
+
+    auto mixedOffsets = op.getMixedOffsets();
+    if (mixedOffsets.empty())
+      return rewriter.notifyMatchFailure(op, "the store op must have offsets");
+
+    FailureOr<VectorType> wiPayloadTy =
+        getDistVecTypeBasedOnLaneLayout(layout, sgPayloadTy);
+    if (failed(wiPayloadTy))
+      return rewriter.notifyMatchFailure(
+          op, "Failed to distribute matrix op payload based on layout.");
+
+    auto loc = op.getLoc();
+    SmallVector<Value> offsetValues =
+        vector::getAsValues(rewriter, loc, mixedOffsets);
+
+    // With subgroup block IO the offsets are forwarded unchanged; otherwise
+    // they are replaced by the distributed coordinates.
+    SmallVector<Value> wiCoords = offsetValues;
+    if (!op.getSubgroupBlockIoAttr()) {
+      wiCoords = computeDistributedCoordsForMatrixOp(
+          rewriter, loc, layout, sgPayloadTy.getShape(), offsetValues);
+      if (wiCoords.empty())
+        return rewriter.notifyMatchFailure(
+            op, "Failed to compute distributed coordinates.");
+    }
+
+    // Every offset is now passed as an SSA value, so each constant-offset
+    // slot is marked dynamic.
+    SmallVector<int64_t> dynamicOffsets(op.getConstOffsets().size(),
+                                        ShapedType::kDynamic);
+    DenseI64ArrayAttr dynamicOffsetsAttr =
+        rewriter.getDenseI64ArrayAttr(dynamicOffsets);
+
+    auto wiPayload =
+        castValueTo(rewriter, cast<TypedValue<VectorType>>(adaptor.getData()),
+                    *wiPayloadTy);
+    xegpu::StoreMatrixOp::create(
+        rewriter, loc, TypeRange{}, wiPayload, adaptor.getMemDesc(),
+        ValueRange(wiCoords), dynamicOffsetsAttr, op.getSubgroupBlockIoAttr(),
+        xegpu::DistributeLayoutAttr{});
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+/// Lowers a subgroup-level StoreScatter (xegpu.store) op to workitem level.
+///
+/// The stored value, the offsets and the mask are all flattened to 1D vectors
+/// before the new store is created. Every dimension of the value vector in
+/// front of the trailing one (no chunk size) or trailing two (with chunk
+/// size) must be a unit dimension.
+///
+/// Example 1 (1D, no chunk size):
+///   layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>
+///   %mask = producer_op : vector<16xi1>
+///   %offset = producer_op : vector<16xindex>
+///   xegpu.store %payload, %src[%offset], %mask : vector<16xf16>,
+///     memref<256xf16>, vector<16xindex>, vector<16xi1>
+/// Distributed to:
+///   %mask = producer_op : vector<1xi1>
+///   %offset = producer_op : vector<1xindex>
+///   xegpu.store %payload, %src[%offset], %mask : vector<1xf16>,
+///     memref<256xf16>, vector<1xindex>, vector<1xi1>
+///
+/// Example 2 (2D with chunk size, same mask & offset):
+///   layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>
+///   xegpu.store %payload, %src[%offset], %mask <{chunk_size=8}> :
+///     vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
+/// Distributed to:
+///   xegpu.store %payload, %src[%offset], %mask <{chunk_size=8}> :
+///     vector<8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
+///
+/// Example 3 (3D with leading unit dims):
+///   layout = #xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>
+///   %mask = producer_op : vector<1x1x16xi1>
+///   %offset = producer_op : vector<1x1x16xindex>
+///   xegpu.store %payload, %src[%offset], %mask : vector<1x1x16xf16>,
+///     memref<256xf16>, vector<1x1x16xindex>, vector<1x1x16xi1>
+/// Distributed to:
+///   %mask = producer_op : vector<1x1x1xi1>
+///   %offset = producer_op : vector<1x1x1xindex>
+///   xegpu.store %payload, %src[%offset], %mask : vector<1xf16>,
+///     memref<256xf16>, vector<1xindex>, vector<1xi1>
+struct SgToWiStoreScatter : public OpConversionPattern<xegpu::StoreScatterOp> {
+  using OpConversionPattern<xegpu::StoreScatterOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(xegpu::StoreScatterOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    xegpu::DistributeLayoutAttr layout = op.getAnchorLayout();
+    if (!layout)
+      return failure();
+
+    VectorType sgValueTy = op.getValueType();
+    if (!sgValueTy)
+      return failure();
+
+    // The trailing one (no chunk) or two (chunked) dims hold the data; all
+    // dims in front of them have to be unit dims.
+    int chunkSize = op.getChunkSize().value_or(1);
+    int trailingRank = (chunkSize == 1) ? 1 : 2;
+    ArrayRef<int64_t> leadingDims =
+        sgValueTy.getShape().take_front(sgValueTy.getRank() - trailingRank);
+    if (!llvm::all_of(leadingDims, [](int64_t d) { return d == 1; }))
+      return rewriter.notifyMatchFailure(
+          op, "Only unit dimensions allowed for the leading "
+              "dimensions of the store vector!");
+
+    auto wiValueTyOrFailure =
+        xegpu::getDistVecTypeBasedOnLaneLayout(layout, sgValueTy);
+    if (failed(wiValueTyOrFailure))
+      return rewriter.notifyMatchFailure(
+          op,
+          "unable to compute expected workitem vector type from lane layout");
+
+    // The stored value is flattened to a 1D vector with the same element
+    // count as the distributed type; a cast is only emitted when needed.
+    VectorType wiValueTy = *wiValueTyOrFailure;
+    VectorType flatValueTy = VectorType::get({wiValueTy.getNumElements()},
+                                             wiValueTy.getElementType());
+    Value wiValue = adaptor.getValue();
+    if (wiValue.getType() != flatValueTy)
+      wiValue = castValueTo(rewriter, cast<TypedValue<VectorType>>(wiValue),
+                            flatValueTy);
+
+    // Offsets and mask are flattened to 1D as well, based on their own
+    // (already converted) types.
+    auto flattenTo1D = [&](Value vec) -> Value {
+      auto vecTy = cast<VectorType>(vec.getType());
+      VectorType flatTy =
+          VectorType::get({vecTy.getNumElements()}, vecTy.getElementType());
+      return castValueTo(rewriter, cast<TypedValue<VectorType>>(vec), flatTy);
+    };
+    Value wiOffsets = flattenTo1D(adaptor.getOffsets());
+    Value wiMask = flattenTo1D(adaptor.getMask());
+
+    // Chunk size and cache hints are carried over; the layout is dropped.
+    xegpu::StoreScatterOp::create(rewriter, op.getLoc(), wiValue,
+                                  adaptor.getDest(), wiOffsets, wiMask,
+                                  op.getChunkSizeAttr(), op.getL1HintAttr(),
+                                  op.getL2HintAttr(), op.getL3HintAttr(),
+                                  /*layout=*/nullptr);
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+/// Lowers a subgroup-level vector.step to workitem level.
+///
+/// The step op is resolved completely: for every lane_data-sized block owned
+/// by the workitem, the start coordinate returned by
+/// `computeDistributedCoords()` is taken and the remaining values of the block
+/// are produced by adding constant offsets. Only entry 0 of the effective lane
+/// data and of each coordinate is consulted, i.e. the layout is expected to
+/// have exactly one effective lane dimension.
+struct SgToWiVectorStep : public OpConversionPattern<vector::StepOp> {
+  using OpConversionPattern<vector::StepOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::StepOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    xegpu::DistributeLayoutAttr layout =
+        xegpu::getTemporaryLayout(op->getResult(0));
+    if (!layout || !layout.isForSubgroup())
+      return rewriter.notifyMatchFailure(
+          op, "the result vector of the step op lacks subgroup layout");
+
+    auto sgStepTy = op.getResult().getType();
+    auto wiStepTyOrFailure =
+        xegpu::getDistVecTypeBasedOnLaneLayout(layout, sgStepTy);
+    if (failed(wiStepTyOrFailure))
+      return rewriter.notifyMatchFailure(
+          op, "unable to compute workitem vector type from the layout");
+    VectorType wiStepTy = *wiStepTyOrFailure;
+
+    auto loc = op.getLoc();
+    Value laneId = gpu::LaneIdOp::create(rewriter, loc, rewriter.getIndexType(),
+                                         /*upperBound=*/mlir::IntegerAttr());
+    auto maybeBlockCoords = layout.computeDistributedCoords(
+        rewriter, loc, laneId, sgStepTy.getShape());
+    if (failed(maybeBlockCoords))
+      return rewriter.notifyMatchFailure(
+          op, "failed to compute lane data block coordinates");
+
+    const auto &blockCoordsList = *maybeBlockCoords;
+    auto blockLength = layout.getEffectiveLaneDataAsInt()[0];
+    assert(static_cast<int64_t>(blockCoordsList.size()) ==
+           wiStepTy.getNumElements() / blockLength);
+
+    // For each lane_data block, rebuild its sub-range of the SG-level
+    // vector.step.
+    // Example: vector.step
+    // {slice<layout<lane_layout=[2,4,2], lane_data=[1,2,1]>, dims=[0,2]>} :
+    // vector<16xindex>
+    // Each logical lane holds 4 elements as 2 blocks of 2 elements each.
+    // The blocks are round-robin distributed, so logical lane id 0
+    // holds values [0,1, 8,9].
+    SmallVector<Value> stepValues;
+    for (const auto &blockCoords : blockCoordsList) {
+      Value blockStart = blockCoords[0];
+      stepValues.push_back(blockStart);
+      for (int64_t delta = 1; delta < blockLength; ++delta) {
+        auto deltaCst = arith::ConstantIndexOp::create(rewriter, loc, delta);
+        stepValues.push_back(
+            arith::AddIOp::create(rewriter, loc, blockStart, deltaCst));
+      }
+    }
+    assert(static_cast<int64_t>(stepValues.size()) ==
+               wiStepTy.getNumElements() &&
+           "Expecting the number of step values to match the number of "
+           "elements in the vector");
+
+    auto wiStep =
+        vector::FromElementsOp::create(rewriter, loc, wiStepTy, stepValues);
+    rewriter.replaceOp(op, wiStep);
+    return success();
+  }
+};
+
+/// Lowers a subgroup-level vector.extract to workitem level.
+///
+/// Only sub-vector extraction is handled (the result must be a VectorType).
+/// The pattern further requires that every entry of the effective lane layout
+/// except the last one is 1, and then re-creates the extract on the converted
+/// source with the original position.
+struct SgToWiVectorExtract : public OpConversionPattern<vector::ExtractOp> {
+  using OpConversionPattern<vector::ExtractOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::ExtractOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Scalar extraction is out of scope for this pattern.
+    if (!isa<VectorType>(op.getType()))
+      return rewriter.notifyMatchFailure(op, "scalar extract not supported");
+
+    xegpu::DistributeLayoutAttr resLayout =
+        xegpu::getTemporaryLayout(op->getOpResult(0));
+    if (!resLayout || !resLayout.isForSubgroup())
+      return failure();
+
+    // Distribution is assumed to happen on the innermost dimension only, so
+    // lane_layout[0...n-2] must all be unit.
+    auto lanes = resLayout.getEffectiveLaneLayoutAsInt();
+    ArrayRef<int64_t> outerLanes = ArrayRef<int64_t>(lanes).drop_back(1);
+    if (!llvm::all_of(outerLanes, [](int64_t v) { return v == 1; }))
+      return rewriter.notifyMatchFailure(
+          op, "only innermost dimension distribution is supported for "
+              "vector.extract");
+
+    auto wiExtract = vector::ExtractOp::create(
+        rewriter, op.getLoc(), adaptor.getSource(), op.getMixedPosition());
+    rewriter.replaceOp(op, wiExtract.getResult());
+    return success();
+  }
+};
+
+/// Lowers a subgroup-level vector.shape_cast to workitem level.
+///
+/// The distributed result type is computed from the result's temporary
+/// layout, and the shape_cast is re-created on the converted source with that
+/// type.
+struct SgToWiVectorShapeCast : public OpConversionPattern<vector::ShapeCastOp> {
+  using OpConversionPattern<vector::ShapeCastOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::ShapeCastOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    xegpu::DistributeLayoutAttr resLayout =
+        xegpu::getTemporaryLayout(op->getOpResult(0));
+    const bool hasSubgroupLayout = resLayout && resLayout.isForSubgroup();
+    if (!hasSubgroupLayout)
+      return rewriter.notifyMatchFailure(
+          op, "the result vector of the shape_cast op lacks subgroup layout");
+
+    auto wiResultTy = xegpu::getDistVecTypeBasedOnLaneLayout(
+        resLayout, op.getResultVectorType());
+    if (failed(wiResultTy))
+      return rewriter.notifyMatchFailure(
+          op, "failed to get distributed vector type for result");
+
+    auto wiShapeCast = vector::ShapeCastOp::create(
+        rewriter, op.getLoc(), *wiResultTy, adaptor.getSource());
+    rewriter.replaceOp(op, wiShapeCast);
+    return success();
+  }
+};
+
+/// Lowers a subgroup-level vector.extract_strided_slice to workitem level.
+///
+/// Offsets, sizes and strides are first padded to the full source rank. If
+/// the result is distributed (along exactly one dimension), the size along
+/// that dimension becomes the distributed result size and the offset is
+/// divided by the subgroup size. This requires a known target, a source
+/// layout with unit lane data along the distributed dimension, and both the
+/// source size and the offset along that dimension to be multiples of the
+/// subgroup size.
+struct SgToWiVectorExtractStridedSlice
+    : public OpConversionPattern<vector::ExtractStridedSliceOp> {
+  using OpConversionPattern<vector::ExtractStridedSliceOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::ExtractStridedSliceOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    xegpu::DistributeLayoutAttr resLayout =
+        xegpu::getTemporaryLayout(op->getOpResult(0));
+    if (!resLayout || !resLayout.isForSubgroup())
+      return failure();
+
+    VectorType sgResultTy = op.getType();
+    auto wiResultTyOrFailure =
+        xegpu::getDistVecTypeBasedOnLaneLayout(resLayout, sgResultTy);
+    if (failed(wiResultTyOrFailure))
+      return rewriter.notifyMatchFailure(
+          op, "unable to compute distributed vector type from lane layout");
+    VectorType wiResultTy = *wiResultTyOrFailure;
+
+    SmallVector<int64_t> distDims = getDistributedDims(sgResultTy, wiResultTy);
+
+    // Start from the op's own sizes/offsets/strides and pad them up to the
+    // source rank: full extent, zero offset, unit stride.
+    VectorType sgSourceTy = op.getSourceVectorType();
+    ArrayAttr sizesAttr = op.getSizes();
+    ArrayAttr offsetsAttr = op.getOffsets();
+    ArrayAttr stridesAttr = op.getStrides();
+    SmallVector<Attribute> newSizes(sizesAttr.begin(), sizesAttr.end());
+    SmallVector<Attribute> newOffsets(offsetsAttr.begin(), offsetsAttr.end());
+    SmallVector<Attribute> newStrides(stridesAttr.begin(), stridesAttr.end());
+    const int64_t sourceRank = sgSourceTy.getRank();
+    for (int64_t dim = sizesAttr.size(); dim < sourceRank; ++dim) {
+      newSizes.push_back(rewriter.getI64IntegerAttr(sgSourceTy.getDimSize(dim)));
+      newOffsets.push_back(rewriter.getI64IntegerAttr(0));
+      newStrides.push_back(rewriter.getI64IntegerAttr(1));
+    }
+
+    if (distDims.size() > 1)
+      return rewriter.notifyMatchFailure(
+          op, "only single dimension distribution is supported");
+
+    // A distributed result needs its size and offset along the distributed
+    // dimension rewritten in workitem terms.
+    if (distDims.size() == 1) {
+      const int64_t distDim = distDims.front();
+      const auto *target = getUArch(xegpu::getChipStr(op).value_or(""));
+      if (!target)
+        return rewriter.notifyMatchFailure(
+            op, "target attribute required to determine subgroup size");
+      int subgroupSize = target->getSubgroupSize();
+
+      auto srcLayout = xegpu::getTemporaryLayout(op->getOpOperand(0));
+      if (!srcLayout || srcLayout.getEffectiveLaneLayoutAsInt().empty())
+        return rewriter.notifyMatchFailure(
+            op, "source of extract_strided_slice lacks distribution layout");
+
+      int srcDistDimSize = sgSourceTy.getShape()[distDim];
+      if (srcDistDimSize % subgroupSize != 0)
+        return rewriter.notifyMatchFailure(
+            op, "source size along distributed dim is not a multiple of "
+                "subgroup size");
+
+      // Only the distributed dimension's lane_data is checked; the other
+      // dimensions may have non-unit lane_data (e.g., packed layouts).
+      auto srcLaneData = srcLayout.getEffectiveLaneDataAsInt();
+      const bool dimHasLaneData =
+          distDim < static_cast<int64_t>(srcLaneData.size());
+      if (dimHasLaneData && srcLaneData[distDim] != 1)
+        return rewriter.notifyMatchFailure(
+            op, "expecting unit lane data along the distributed dimension");
+
+      int64_t sgOffset = cast<IntegerAttr>(newOffsets[distDim]).getInt();
+      if (sgOffset % subgroupSize != 0)
+        return rewriter.notifyMatchFailure(
+            op, "offset along distributed dim is not a multiple of "
+                "subgroup size");
+
+      newSizes[distDim] =
+          rewriter.getI64IntegerAttr(wiResultTy.getDimSize(distDim));
+      newOffsets[distDim] = rewriter.getI64IntegerAttr(sgOffset / subgroupSize);
+    }
+
+    MLIRContext *ctx = rewriter.getContext();
+    auto wiSlice = vector::ExtractStridedSliceOp::create(
+        rewriter, op.getLoc(), wiResultTy, adaptor.getSource(),
+        ArrayAttr::get(ctx, newOffsets), ArrayAttr::get(ctx, newSizes),
+        ArrayAttr::get(ctx, newStrides));
+    rewriter.replaceOp(op, wiSlice.getResult());
+    return success();
+  }
+};
+
+/// This pattern distributes a subgroup-level `vector.broadcast` op to
+/// workitem-level. The pattern supports three cases:
+///
+/// 1) Broadcast a low-rank vector to high-rank vector: The low-rank input
+///    vector is expected to have a slice layout of the result; a warning is
+///    emitted (and the rewrite still proceeds) when it does not. If the
+///    distributed source and target vector types are identical, this lowers
+///    to a no-op; otherwise, it remains a broadcast but operates on
+///    distributed vectors.
+///
+/// 2) Broadcast a same-rank vector with identical layouts for source and
+///    target: The source vector must carry a layout and have unit dimensions,
+///    and the layout data must be unit size for those unit dims (checked by
+///    an assertion). The match fails if the source has no layout.
+///
+/// 3) Broadcast a scalar with no layout: This always lowers to a broadcast
+///    from scalar to distributed result type.
+///
+/// Example 1 (low-rank to high-rank broadcast):
+/// ```
+///   %0 = "some_op"() {layout_result_0 =
+///     #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+///     dims = [0]>} : () -> vector<16xf16>
+///   %1 = vector.broadcast %0 {layout_result_0 =
+///     #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+///     : vector<16xf16> to vector<16x16xf16>
+/// ```
+/// is distributed to:
+/// ```
+///   %0 = "some_op"() : () -> vector<1xf16>
+///   %1 = vector.broadcast %0 : vector<1xf16> to vector<16x1xf16>
+/// ```
+///
+/// Example 2 (same-rank broadcast, no-op):
+/// ```
+///   %0 = "some_op"() {layout_result_0 =
+///     #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+///     : () -> vector<16x1xf16>
+///   %1 = vector.broadcast %0 {layout_result_0 =
+///     #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+///     : vector<16x1xf16> to vector<16x16xf16>
+/// ```
+/// is distributed to (no-op, source already matches distributed result type):
+/// ```
+///   %0 = "some_op"() : () -> vector<16x1xf16>
+///   // broadcast is eliminated, %0 is used directly
+/// ```
+///
+/// Example 3 (scalar to vector broadcast):
+/// ```
+///   %0 = "some_op"() : () -> f16
+///   %1 = vector.broadcast %0 {layout_result_0 =
+///     #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+///     : f16 to vector<16x16xf16>
+/// ```
+/// is distributed to:
+/// ```
+///   %0 = "some_op"() : f16
+///   %1 = vector.broadcast %0 : f16 to vector<16x1xf16>
+/// ```
+struct SgToWiBroadcast : public OpConversionPattern<vector::BroadcastOp> {
+  using OpConversionPattern<vector::BroadcastOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::BroadcastOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    xegpu::DistributeLayoutAttr resultLayout =
+        xegpu::getTemporaryLayout(cast<OpResult>(op.getResult()));
+    if (!resultLayout || !resultLayout.isForSubgroup())
+      return rewriter.notifyMatchFailure(
+          op, "result does not have subgroup distribute layout");
+
+    VectorType destType = op.getResultVectorType();
+    // Null when the broadcast source is a scalar.
+    VectorType sourceType = dyn_cast<VectorType>(op.getSourceType());
+
+    xegpu::DistributeLayoutAttr sourceLayout =
+        xegpu::getTemporaryLayout(op->getOpOperand(0));
+
+    if (sourceType) {
+      int64_t rankDiff = destType.getRank() - sourceType.getRank();
+      if (rankDiff > 0) {
+        // Case 1: Low-rank to high-rank broadcast. A mismatching source
+        // layout is only reported; the rewrite continues regardless.
+        if (!sourceLayout || !sourceLayout.isSliceOf(resultLayout))
+          op.emitWarning(
+              "broadcast source layout must be a slice of result layout");
+      } else if (rankDiff == 0) {
+        // Case 2: Same-rank broadcast. The layout is queried below, so a
+        // missing source layout has to be rejected before it is dereferenced.
+        if (!sourceLayout)
+          return rewriter.notifyMatchFailure(
+              op, "same-rank broadcast source lacks layout attribute");
+        auto broadcastUnitDimsSet = op.computeBroadcastedUnitDims();
+        // Only consumed by the assertion below.
+        [[maybe_unused]] SmallVector<int64_t> broadcastUnitDims(
+            broadcastUnitDimsSet.begin(), broadcastUnitDimsSet.end());
+        assert(sourceLayout.isEqualTo(
+                   sourceLayout.setUnitDimData(broadcastUnitDims)) &&
+               "The sg_data for unit dimensions should be set as 1");
+      }
+    } else {
+      // Case 3: Scalar to vector broadcast.
+      if (sourceLayout)
+        return rewriter.notifyMatchFailure(
+            op, "broadcast from scalar must not have a layout attribute");
+    }
+
+    auto destDistType =
+        xegpu::getDistVecTypeBasedOnLaneLayout(resultLayout, destType);
+    if (failed(destDistType))
+      return rewriter.notifyMatchFailure(
+          op, "failed to distribute the result vector type");
+
+    Value source = adaptor.getSource();
+    // If the adapted source already matches the dest dist type, it's a no-op.
+    if (source.getType() == destDistType.value()) {
+      rewriter.replaceOp(op, source);
+      return success();
+    }
+
+    auto newOp = vector::BroadcastOp::create(rewriter, op.getLoc(),
+                                             destDistType.value(), source);
+    rewriter.replaceOp(op, newOp);
+    return success();
+  }
+};
+
+/// Lowers a subgroup-level vector.insert_strided_slice to workitem level.
+///
+/// If the destination is distributed (along exactly one dimension), the
+/// offset along that dimension is divided by the subgroup size. This requires
+/// a known target, the distributed dimension to fall inside the source's
+/// dimensions, source and destination layouts with unit lane data along the
+/// distributed dimension, and both the source size and the destination offset
+/// along that dimension to be multiples of the subgroup size. Strides are
+/// forwarded unchanged.
+struct SgToWiVectorInsertStridedSlice
+    : public OpConversionPattern<vector::InsertStridedSliceOp> {
+  using OpConversionPattern<vector::InsertStridedSliceOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::InsertStridedSliceOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    xegpu::DistributeLayoutAttr resLayout =
+        xegpu::getTemporaryLayout(op->getOpResult(0));
+    if (!resLayout || !resLayout.isForSubgroup())
+      return failure();
+
+    VectorType sgDestTy = op.getDestVectorType();
+    auto wiDestTyOrFailure =
+        xegpu::getDistVecTypeBasedOnLaneLayout(resLayout, sgDestTy);
+    if (failed(wiDestTyOrFailure))
+      return rewriter.notifyMatchFailure(
+          op, "unable to compute distributed vector type from lane layout");
+    VectorType wiDestTy = *wiDestTyOrFailure;
+
+    SmallVector<int64_t> distDims = getDistributedDims(sgDestTy, wiDestTy);
+
+    ArrayAttr sgOffsets = op.getOffsets();
+    SmallVector<Attribute> newOffsets(sgOffsets.begin(), sgOffsets.end());
+
+    if (distDims.size() > 1)
+      return rewriter.notifyMatchFailure(
+          op, "only single dimension distribution is supported");
+
+    if (distDims.size() == 1) {
+      const int64_t destDistDim = distDims.front();
+
+      const auto *target = getUArch(xegpu::getChipStr(op).value_or(""));
+      if (!target)
+        return rewriter.notifyMatchFailure(
+            op, "target attribute required to determine subgroup size");
+      int subgroupSize = target->getSubgroupSize();
+
+      // The source covers the last k (= source rank) dims of the dest, so the
+      // distributed dim has to map to a non-negative source dim.
+      VectorType sgSourceTy = op.getSourceVectorType();
+      const int64_t rankGap = sgDestTy.getRank() - sgSourceTy.getRank();
+      const int64_t srcDistDim = destDistDim - rankGap;
+      if (srcDistDim < 0)
+        return rewriter.notifyMatchFailure(
+            op, "distributed dimension must be in the last k dims of dest");
+
+      auto destLayout = xegpu::getTemporaryLayout(op->getOpOperand(1));
+      auto srcLayout = xegpu::getTemporaryLayout(op->getOpOperand(0));
+      if (!destLayout || !srcLayout ||
+          destLayout.getEffectiveLaneLayoutAsInt().empty() ||
+          srcLayout.getEffectiveLaneLayoutAsInt().empty())
+        return rewriter.notifyMatchFailure(
+            op, "source or dest of insert_strided_slice lacks distribution "
+                "layout");
+
+      // Only the distributed dimension's lane_data is checked; the other
+      // dimensions may have non-unit lane_data (e.g., packed layouts).
+      auto destLaneData = destLayout.getEffectiveLaneDataAsInt();
+      auto srcLaneData = srcLayout.getEffectiveLaneDataAsInt();
+      const bool destNonUnit =
+          destDistDim < static_cast<int64_t>(destLaneData.size()) &&
+          destLaneData[destDistDim] != 1;
+      const bool srcNonUnit =
+          srcDistDim < static_cast<int64_t>(srcLaneData.size()) &&
+          srcLaneData[srcDistDim] != 1;
+      if (destNonUnit || srcNonUnit)
+        return rewriter.notifyMatchFailure(
+            op, "expecting unit lane data along the distributed dimension");
+
+      int64_t srcDistDimSize = sgSourceTy.getDimSize(srcDistDim);
+      if (srcDistDimSize % subgroupSize != 0)
+        return rewriter.notifyMatchFailure(
+            op, "source distributed dim size is not a multiple of "
+                "subgroup size");
+
+      int64_t sgDestOffset = cast<IntegerAttr>(sgOffsets[destDistDim]).getInt();
+      if (sgDestOffset % subgroupSize != 0)
+        return rewriter.notifyMatchFailure(
+            op, "offset along distributed dim is not a multiple of "
+                "subgroup size");
+
+      // Express the offset along the distributed dimension per workitem.
+      newOffsets[destDistDim] =
+          rewriter.getI64IntegerAttr(sgDestOffset / subgroupSize);
+    }
+
+    auto wiInsert = vector::InsertStridedSliceOp::create(
+        rewriter, op.getLoc(), wiDestTy, adaptor.getValueToStore(),
+        adaptor.getDest(), ArrayAttr::get(rewriter.getContext(), newOffsets),
+        op.getStrides());
+    rewriter.replaceOp(op, wiInsert.getResult());
+    return success();
+  }
+};
+
+/// Distributes a subgroup-level vector.insert op to workitem-level. Only
+/// handles sub-vector insertion (value to store is VectorType, not scalar).
+///
+/// The pattern applies only when the temporary layout on the result is a
+/// subgroup layout whose lane layout is unit in every dimension except the
+/// innermost one. In that case the op is recreated from the already
+/// converted operands with its insert position left unchanged.
+struct SgToWiVectorInsert : public OpConversionPattern<vector::InsertOp> {
+  using OpConversionPattern<vector::InsertOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::InsertOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Only handle vector value-to-store (not scalar insertion).
+    if (!isa<VectorType>(op.getValueToStoreType()))
+      return rewriter.notifyMatchFailure(op, "scalar insert not supported");
+
+    xegpu::DistributeLayoutAttr layout =
+        xegpu::getTemporaryLayout(op->getOpResult(0));
+    if (!layout || !layout.isForSubgroup())
+      return failure();
+
+    // Reject an empty lane layout up front: drop_back(1) below requires at
+    // least one element.
+    auto laneLayout = layout.getEffectiveLaneLayoutAsInt();
+    if (laneLayout.empty())
+      return rewriter.notifyMatchFailure(
+          op, "result of vector.insert lacks lane layout");
+
+    // Verify that every dimension except the innermost one has a unit
+    // lane_layout, i.e. only the innermost dimension is distributed.
+    if (llvm::any_of(ArrayRef<int64_t>(laneLayout).drop_back(1),
+                     [](int64_t v) { return v != 1; }))
+      return rewriter.notifyMatchFailure(
+          op, "only innermost dimension distribution is supported for "
+              "vector.insert");
+
+    // Rebuild the insert from the converted operands; the position is reused
+    // as is.
+    auto newOp = vector::InsertOp::create(
+        rewriter, op.getLoc(), adaptor.getValueToStore(), adaptor.getDest(),
+        op.getMixedPosition());
+    rewriter.replaceOp(op, newOp.getResult());
+    return success();
+  }
+};
+
+/// Folds a subgroup-level ConvertLayout op with compatible lane layouts.
+///
+/// A convert_layout with a scalar (int or float) result is replaced by its
+/// source unconditionally. For a vector result the op is replaced by its
+/// source only when the input layout is compatible with the target layout at
+/// lane level for the result shape; otherwise the pattern reports a match
+/// failure.
+struct SgToWiConvertLayout
+    : public OpConversionPattern<xegpu::ConvertLayoutOp> {
+  using OpConversionPattern<xegpu::ConvertLayoutOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(xegpu::ConvertLayoutOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto inputLayout = op.getInputLayoutAttr();
+    auto targetLayout = op.getTargetLayoutAttr();
+    Type valType = op.getResult().getType();
+
+    // Scalars carry no distribution, so the conversion is a no-op. Forward
+    // the adaptor value (as the vector path below does) so that the
+    // replacement is the converted operand rather than the original one.
+    if (valType.isIntOrFloat()) {
+      rewriter.replaceOp(op, adaptor.getSource());
+      return success();
+    }
+
+    auto resShape = cast<VectorType>(valType).getShape();
+    SmallVector<int64_t> resShapeVec(resShape.begin(), resShape.end());
+    if (!inputLayout.isCompatibleWith(targetLayout, resShapeVec,
+                                      xegpu::LayoutKind::Lane)) {
+      return rewriter.notifyMatchFailure(
+          op, "lowering incompatible convert_layout not yet supported");
+    }
+
+    rewriter.replaceOp(op, adaptor.getSource());
+    return success();
+  }
+};
+
+/// Pass class built on the XeGPUSgToWiDistributeExperimentalBase base class.
+/// It only declares the runOnOperation() override; the implementation is
+/// defined out of line after the anonymous namespace.
+struct XeGPUSgToWiDistributeExperimentalPass
+    : public xegpu::impl::XeGPUSgToWiDistributeExperimentalBase<
+          XeGPUSgToWiDistributeExperimentalPass> {
+  void runOnOperation() override;
+};
+
+} // namespace
+
+/// Runs the subgroup-to-workitem distribution on the root operation:
+///  1. recovers temporary layouts (the pass fails if that fails),
+///  2. records the UnrealizedConversionCastOps that already exist,
+///  3. applies a partial conversion with the SCF structural and XeGPU
+///     SG-to-WI patterns, materializing type mismatches as casts,
+///  4. folds newly created back-to-back casts into vector.shape_cast ops,
+///  5. erases newly created casts that ended up without uses.
+void XeGPUSgToWiDistributeExperimentalPass::runOnOperation() {
+
+  // Recover temporary operand layouts for usage in patterns.
+  Operation *root = getOperation();
+  if (!xegpu::recoverTemporaryLayouts(root)) {
+    signalPassFailure();
+    return;
+  }
+
+  // Collect existing UnrealizedConversionCastOps. These must be preserved.
+  llvm::SmallSetVector<UnrealizedConversionCastOp, 8> existingCasts;
+  root->walk(
+      [&](UnrealizedConversionCastOp castOp) { existingCasts.insert(castOp); });
+  // Perform a structural type conversion to convert structural ops to have WI
+  // types. This will insert UnrealizedConversionCastOps to make the IR
+  // valid.
+  auto materializeCast = [&](mlir::OpBuilder &builder, mlir::Type type,
+                             mlir::ValueRange inputs,
+                             mlir::Location loc) -> mlir::Value {
+    UnrealizedConversionCastOp castOp =
+        UnrealizedConversionCastOp::create(builder, loc, type, inputs);
+    return castOp.getResult(0);
+  };
+  {
+    ConversionTarget target(getContext());
+    TypeConverter typeConverter;
+    RewritePatternSet patterns(&getContext());
+    typeConverter.addSourceMaterialization(materializeCast);
+    typeConverter.addTargetMaterialization(materializeCast);
+    xegpu::populateXeGPUSgToWiDistributeTypeConversions(typeConverter);
+    scf::populateSCFStructuralTypeConversionsAndLegality(typeConverter,
+                                                         patterns, target);
+    xegpu::populateXeGPUSgToWiDistributeTypeConversionAndLegality(
+        typeConverter, patterns, target);
+    target.addLegalOp<UnrealizedConversionCastOp>();
+    // NOTE(review): the conversion result is discarded, so a failed partial
+    // conversion does not fail the pass -- confirm this is intended.
+    (void)applyPartialConversion(root, target, std::move(patterns));
+  }
+  // Structural type conversion can generate some redundant
+  // UnrealizedConversionCastOps to materialize the SG type from type converted
+  // WI type. These are redundant at this point and can be eliminated by
+  // inserting shape casts instead.
+  // Example:
+  // %1 = UnrealizedConversionCastOp %0 : vector<16x1xf32> to vector<16x16xf32>
+  // %2 = UnrealizedConversionCastOp %1 : vector<16x16xf32> to vector<16xf32>
+  // This can be replaced with:
+  // %2 = vector.shape_cast %0 : vector<16x1xf32> to vector<16xf32>
+  OpBuilder builder(root);
+  root->walk([&](UnrealizedConversionCastOp op) {
+    // If this op existed before, nothing to do.
+    if (existingCasts.contains(op))
+      return;
+    // number of inputs and outputs must be 1.
+    if (op.getNumOperands() != 1 || op.getNumResults() != 1)
+      return;
+    // Both input and output types must be vector types.
+    auto singleInput = op.getInputs()[0];
+    auto inputTy = dyn_cast<VectorType>(singleInput.getType());
+    auto outputTy = dyn_cast<VectorType>(op.getResult(0).getType());
+    if (!inputTy || !outputTy)
+      return;
+
+    // Check if the defining op of the input is also an
+    // UnrealizedConversionCastOp and it has a single user (which is this
+    // op).
+    auto definingOp = singleInput.getDefiningOp<UnrealizedConversionCastOp>();
+    if (!definingOp || !definingOp->hasOneUse())
+      return;
+    // The fold below forwards exactly one input of the defining cast, so it
+    // is only valid when that cast has exactly one input. This also keeps
+    // the getInputs()[0] access in bounds for casts without inputs.
+    if (definingOp->getNumOperands() != 1)
+      return;
+    auto inputOfDefiningOp = definingOp.getInputs()[0];
+    // If the input of the defining op and the output type are both vector
+    // types with the same number of elements, insert a shape cast.
+    auto inputOfDefiningOpTy =
+        dyn_cast<VectorType>(inputOfDefiningOp.getType());
+    if (inputOfDefiningOpTy &&
+        inputOfDefiningOpTy.getNumElements() == outputTy.getNumElements()) {
+      builder.setInsertionPoint(op);
+      auto shapeCast = vector::ShapeCastOp::create(builder, op.getLoc(),
+                                                   outputTy, inputOfDefiningOp);
+      op.replaceAllUsesWith(ValueRange{shapeCast.getResult()});
+      return;
+    }
+  });
+  // At this point, we will have some dead UnrealizedConversionCastOps. Just
+  // erase them. Erasing one cast can leave the cast feeding it without uses,
+  // so repeat until a full walk erases nothing.
+  bool changed = true;
+  while (changed) {
+    changed = false;
+    root->walk([&](UnrealizedConversionCastOp op) {
+      // Skip existing casts.
+      if (existingCasts.contains(op))
+        return;
+      if (op.use_empty()) {
+        op.erase();
+        changed = true;
+      }
+    });
+  }
+}
+
+/// Registers the subgroup-to-workitem type conversion rules on
+/// `typeConverter`: a pass-through rule for types that are neither
+/// TensorDescType nor VectorType, a rule that strips layouts from
+/// TensorDescType, and a value-based rule that distributes VectorType
+/// according to the lane layout attached to the value.
+void xegpu::populateXeGPUSgToWiDistributeTypeConversions(
+    TypeConverter &typeConverter) {
+  // Pass-through: only TensorDescType and VectorType are deferred to the
+  // dedicated rules below; every other type is legal as is.
+  typeConverter.addConversion([](Type ty) -> std::optional<Type> {
+    if (isa<TensorDescType, VectorType>(ty))
+      return std::nullopt;
+    return ty;
+  });
+  // TensorDescType: a descriptor without a layout is kept unchanged,
+  // otherwise its layouts are dropped.
+  typeConverter.addConversion([](TensorDescType descTy) -> Type {
+    if (!descTy.getLayoutAttr())
+      return descTy;
+    return descTy.dropLayouts();
+  });
+  // VectorType: look up the distribute layout attribute of the value. With a
+  // subgroup layout the vector is distributed based on its lane layout; in
+  // all other cases (no layout, non-subgroup layout, or a failed
+  // distribution) the original type is kept.
+  typeConverter.addConversion([](Value value) -> std::optional<Type> {
+    Type origTy = value.getType();
+    auto vecTy = dyn_cast<VectorType>(origTy);
+    // Non-vector values are not handled by this rule.
+    if (!vecTy)
+      return std::nullopt;
+    auto layout = xegpu::getDistributeLayoutAttr(value);
+    if (layout && layout.isForSubgroup()) {
+      auto distTyOrFailure = getDistVecTypeBasedOnLaneLayout(layout, vecTy);
+      if (!failed(distTyOrFailure))
+        return *distTyOrFailure;
+    }
+    return origTy;
+  });
+}
+
+/// Sets up everything needed for the subgroup-to-workitem conversion:
+/// registers the type conversions on `typeConverter`, marks on `target`
+/// which ops are dynamically legal, and adds the SgToWi* conversion patterns
+/// to `patterns`. Most of the legality rules below treat an op as legal when
+/// it carries no layout (anchor, temporary or distribute layout).
+void xegpu::populateXeGPUSgToWiDistributeTypeConversionAndLegality(
+    TypeConverter &typeConverter, RewritePatternSet &patterns,
+    ConversionTarget &target) {
+  populateXeGPUSgToWiDistributeTypeConversions(typeConverter);
+  // CreateNdDescOp is legal only if its result type has no layout attribute.
+  target.addDynamicallyLegalOp<xegpu::CreateNdDescOp>(
+      [&](xegpu::CreateNdDescOp op) { return !op.getType().getLayoutAttr(); });
+  // Any anchor XeGPU op is legal only if it has no anchor layout. XeGPU ops
+  // that do not implement AnchorLayoutInterface are always legal.
+  target.addDynamicallyLegalDialect<xegpu::XeGPUDialect>([](Operation *op) {
+    auto anchorOp = dyn_cast<AnchorLayoutInterface>(op);
+    if (!anchorOp)
+      return true;
+    return !anchorOp.getAnchorLayout();
+  });
+  // Arith constants are legal only if they have no temporary layout attribute.
+  target.addDynamicallyLegalOp<arith::ConstantOp>(
+      [=](arith::ConstantOp op) -> bool {
+        // If the result type is not a vector, it's legal.
+        if (!isa<VectorType>(op.getResult().getType()))
+          return true;
+        return !xegpu::getTemporaryLayout(dyn_cast<OpResult>(op.getResult()));
+      });
+  // In math and arith dialects, only handle elementwise ops with a single
+  // result and with a result layout attribute.
+  target.addDynamicallyLegalDialect<math::MathDialect, arith::ArithDialect>(
+      [=](Operation *op) -> std::optional<bool> {
+        // Only handle elementwise mappable ops
+        if (!OpTrait::hasElementwiseMappableTraits(op))
+          return true;
+        // Only handle ops with single vector result
+        if (op->getNumResults() != 1)
+          return true;
+
+        VectorType resultType =
+            dyn_cast<VectorType>(op->getResult(0).getType());
+        if (!resultType)
+          return true;
+
+        // Check if all operands are vectors of the same shape
+        for (Value operand : op->getOperands()) {
+          VectorType operandType = dyn_cast<VectorType>(operand.getType());
+          if (!operandType || operandType.getShape() != resultType.getShape()) {
+            return true;
+          }
+        }
+        // Legal only if the result has no temporary layout attribute.
+        return !xegpu::getTemporaryLayout(dyn_cast<OpResult>(op->getResult(0)));
+      });
+  // vector::ReductionOp is legal only if its source has no distribute layout
+  // attribute.
+  target.addDynamicallyLegalOp<vector::ReductionOp>(
+      [=](vector::ReductionOp op) -> bool {
+        auto layout = xegpu::getDistributeLayoutAttr(op.getVector());
+        return !layout;
+      });
+  // vector::MultiDimReductionOp is legal unless
+  // isValidSubgroupMultiReductionOp accepts it.
+  target.addDynamicallyLegalOp<vector::MultiDimReductionOp>(
+      [=](vector::MultiDimReductionOp op) -> bool {
+        return !isValidSubgroupMultiReductionOp(op);
+      });
+  // Mask creation, transpose, bitcast, shape_cast, step and broadcast ops are
+  // legal only if their first result has no temporary layout attribute.
+  target.addDynamicallyLegalOp<vector::CreateMaskOp, vector::ConstantMaskOp,
+                               vector::TransposeOp, vector::BitCastOp,
+                               vector::ShapeCastOp, vector::StepOp,
+                               vector::BroadcastOp>([=](Operation *op) -> bool {
+    return !xegpu::getTemporaryLayout(op->getOpResult(0));
+  });
+  // vector::ExtractOp with a non-vector result is always legal; with a vector
+  // result it is legal only if the result has no temporary layout attribute.
+  target.addDynamicallyLegalOp<vector::ExtractOp>(
+      [=](vector::ExtractOp op) -> bool {
+        if (!isa<VectorType>(op.getType()))
+          return true;
+        return !xegpu::getTemporaryLayout(op->getOpResult(0));
+      });
+  // vector::InsertOp, vector::ExtractStridedSliceOp and
+  // vector::InsertStridedSliceOp are legal only if their result has no
+  // temporary layout attribute.
+  target.addDynamicallyLegalOp<vector::InsertOp>(
+      [=](vector::InsertOp op) -> bool {
+        return !xegpu::getTemporaryLayout(op->getOpResult(0));
+      });
+  target.addDynamicallyLegalOp<vector::ExtractStridedSliceOp>(
+      [=](vector::ExtractStridedSliceOp op) -> bool {
+        return !xegpu::getTemporaryLayout(op->getOpResult(0));
+      });
+  target.addDynamicallyLegalOp<vector::InsertStridedSliceOp>(
+      [=](vector::InsertStridedSliceOp op) -> bool {
+        return !xegpu::getTemporaryLayout(op->getOpResult(0));
+      });
+  // Unknown ops are reported as legal by this callback.
+  target.markUnknownOpDynamicallyLegal([](Operation *op) { return true; });
+  // Register all SG-to-WI conversion patterns with the shared type converter.
+  patterns.add<SgToWiCreateNdDesc, SgToWiLoadNd, SgToWiStoreNd, SgToWiDpas,
+               SgToWiElementWise, SgToWiArithConstant, SgToWiPrefetchNd,
+               SgToWiLoadGather, SgToWiStoreScatter, SgToWiVectorReduction,
+               SgToWiMultiDimReduction, SgToWiVectorExtract, SgToWiVectorInsert,
+               SgToWiVectorExtractStridedSlice, SgToWiVectorInsertStridedSlice,
+               SgToWiLoadMatrix, SgToWiStoreMatrix, SgToWiConvertLayout,
+               SgToWiVectorTranspose, SgToWiVectorBitcast, SgToWiVectorStep,
+               SgToWiVectorShapeCast, SgToWiBroadcast,
+               SgToWiCreateMask<vector::CreateMaskOp>,
+               SgToWiCreateMask<vector::ConstantMaskOp>>(typeConverter,
+                                                         patterns.getContext());
+}
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
index c029ee1d8ae0d..6ef0a63926c8d 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
@@ -1610,7 +1610,6 @@ void XeGPUSgToWiDistributeExperimentalPass::runOnOperation() {
       }
     });
   }
-  // Remove layout attributes from SCF ops
   getOperation()->walk([](Operation *op) {
     SmallVector<StringAttr> attrsToRemove;
     for (auto namedAttr : op->getDiscardableAttrs()) {
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 012d7aefafb06..d16d0cfc6c587 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -825,10 +825,12 @@ struct StoreDistribution final : public gpu::WarpDistributionPattern {
       }
     }
 
-    auto layoutPayload = storeScatterOp.getLayoutAttr();
+    auto layoutPayload =
+        xegpu::getTemporaryLayout(storeScatterOp->getOpOperand(0));
     auto layoutOffsets =
-        xegpu::inferMaskOffsetLayoutForScatterIO(layoutPayload, chunkSize);
-    auto layoutMask = layoutOffsets;
+        xegpu::getTemporaryLayout(storeScatterOp->getOpOperand(2));
+    auto layoutMask =
+        xegpu::getTemporaryLayout(storeScatterOp->getOpOperand(3));
 
     FailureOr<VectorType> distStoreVecByWarpOpOrFailure =
         getDistVecTypeBasedOnLaneLayout(layoutPayload, storeVecTy);
@@ -1130,36 +1132,9 @@ struct LoadDistribution final : public gpu::WarpDistributionPattern {
       }
     }
 
-    auto layoutPayload = loadGatherOp.getLayoutAttr();
     auto layoutOffsets =
-        xegpu::inferMaskOffsetLayoutForScatterIO(layoutPayload, chunkSize);
-    auto layoutMask = layoutOffsets;
-
-    // print the layouts for debug
-    LLVM_DEBUG({
-      llvm::dbgs() << "In LoadDistribution pattern:\n";
-      llvm::dbgs() << "Payload layout: ";
-      if (layoutPayload)
-        llvm::dbgs() << layoutPayload;
-      else
-        llvm::dbgs() << "none";
-      llvm::dbgs() << "\nOffsets layout: ";
-      if (layoutOffsets)
-        llvm::dbgs() << layoutOffsets;
-      else
-        llvm::dbgs() << "none";
-      llvm::dbgs() << "\nMask layout: ";
-      if (layoutMask)
-        llvm::dbgs() << layoutMask;
-      else
-        llvm::dbgs() << "none";
-      llvm::dbgs() << "\n";
-    });
-
-    // auto layoutOffsets =
-    //     xegpu::getTemporaryLayout(loadGatherOp->getOpOperand(1));
-    // auto layoutMask =
-    // xegpu::getTemporaryLayout(loadGatherOp->getOpOperand(2));
+        xegpu::getTemporaryLayout(loadGatherOp->getOpOperand(1));
+    auto layoutMask = xegpu::getTemporaryLayout(loadGatherOp->getOpOperand(2));
 
     FailureOr<VectorType> distOffsetsByWarpOpOrFailure =
         getDistVecTypeBasedOnLaneLayout(layoutOffsets, offsetsTy);
@@ -2306,8 +2281,6 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
       op->erase();
     return WalkResult::advance();
   });
-
-  // Remove layout attributes from SCF ops
   getOperation()->walk([](Operation *op) {
     SmallVector<StringAttr> attrsToRemove;
     for (auto namedAttr : op->getDiscardableAttrs()) {
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 02f88828f667f..dabdcb61f0500 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -1784,7 +1784,6 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
           applyPartialConversion(getOperation(), target, std::move(patterns))))
     return signalPassFailure();
 
-  // Remove layout attributes from SCF ops
   getOperation()->walk([](Operation *op) {
     SmallVector<StringAttr> attrsToRemove;
     for (auto namedAttr : op->getDiscardableAttrs()) {
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index 0dab06d206ceb..a6f621de0dd82 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -16,7 +16,6 @@
 #include "mlir/Dialect/SCF/Transforms/Patterns.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
-#include "mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h"
 #include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/Operation.h"

>From 79764d19d6f664e3b46a5cabec803be021e64dac Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Mon, 13 Apr 2026 21:11:24 +0000
Subject: [PATCH 15/21] cleanup

---
 .../XeGPU/Transforms/XeGPULayoutImpl.cpp      | 200 ++----------------
 1 file changed, 22 insertions(+), 178 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
index c326e6c917df8..8f319bd161798 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
@@ -30,10 +30,6 @@
 #include <cstdint>
 #include <numeric>
 
-#include "llvm/Support/Debug.h"
-#define DEBUG_TYPE "xegpu-layout-recovery"
-#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
-
 using namespace mlir;
 
 void xegpu::recoverTemporaryLayoutsDeprecated(Operation *op) {
@@ -139,15 +135,6 @@ static xegpu::DistributeLayoutAttr getLayoutFromUsePoints(Value result) {
   xegpu::DistributeLayoutAttr layout = nullptr;
   for (OpOperand &use : result.getUses()) {
     if (auto tmpLayout = xegpu::getDistributeLayoutAttr(use)) {
-      // debug print the use and op, and the tmpLayout
-      LLVM_DEBUG({
-        DBGS() << "getLayoutFromUsePoints  use: " << use.getOwner()->getName()
-               << use.getOwner();
-        llvm::dbgs() << ", tmpLayout=" << tmpLayout << "\n";
-      });
-      // under debug mode, we want to check all the use points to make sure
-      // there is no conflict, so we do not break here. In release mode, we can
-      // break at the first use
       if (!layout)
         layout = tmpLayout;
     }
@@ -158,14 +145,8 @@ static xegpu::DistributeLayoutAttr getLayoutFromUsePoints(Value result) {
 // For regular operations: First the result layouts are propagated from uses.
 // Then the result layouts are propagated to uses (operands).
 static void propagateResultsToRegularOperands(Operation *op) {
-  LLVM_DEBUG(DBGS() << "propagateResultsToRegularOperands: " << op->getName()
-                    << " (" << op->getNumOperands() << " operands, "
-                    << op->getNumResults() << " results)\n");
-
-  if (op->getNumResults() == 0) {
-    LLVM_DEBUG(DBGS() << "  skipping (no results)\n");
+  if (op->getNumResults() == 0)
     return;
-  }
 
   OpResult result = op->getResult(0);
   xegpu::DistributeLayoutAttr resLayout = getLayoutFromUsePoints(result);
@@ -193,47 +174,26 @@ static void propagateResultsToRegularOperands(Operation *op) {
     // Layouts are needed for vector type only.
     xegpu::DistributeLayoutAttr operandLayout =
         xegpu::inferSourceLayoutFromResult(opr, resLayout);
-    if (!isa<VectorType>(opr.get().getType())) {
-      LLVM_DEBUG(DBGS() << "  operand #" << opr.getOperandNumber()
-                        << ": skipped (non-vector type: " << opr.get().getType()
-                        << ")\n");
+    if (!isa<VectorType>(opr.get().getType()))
       continue;
-    }
 
     xegpu::setTemporaryLayout(opr, operandLayout);
-    // debug print op
-    LLVM_DEBUG(DBGS() << "after propagateResultsToRegularOperands  op: "
-                      << op->getName() << op << "  operand #"
-                      << opr.getOperandNumber()
-                      << ": type=" << opr.get().getType());
-    llvm::dbgs() << ", temp Layout=" << xegpu::getTemporaryLayout(opr);
-    llvm::dbgs() << "\n";
   }
 }
 
 static void propagateRegionResultsToYieldOperands(
     mlir::RegionBranchTerminatorOpInterface yieldOp) {
-  LLVM_DEBUG(DBGS() << "propagateRegionResultsToYieldOperands: "
-                    << yieldOp->getName() << " (" << yieldOp->getNumOperands()
-                    << " operands), parent="
-                    << yieldOp->getParentOp()->getName() << "\n");
-
-  if (isa<func::FuncOp>(yieldOp->getParentOp())) {
-    LLVM_DEBUG(DBGS() << "  skipping (parent is FuncOp)\n");
+  if (isa<func::FuncOp>(yieldOp->getParentOp()))
     return;
-  }
 
   auto regionBranchOp =
       dyn_cast<RegionBranchOpInterface>(yieldOp->getParentOp());
-  if (!regionBranchOp) {
-    LLVM_DEBUG(DBGS() << "  skipping (parent is not RegionBranchOp)\n");
+  if (!regionBranchOp)
     return;
-  }
 
   // Gather layouts for each result of the parent region op from external
   // use points.
   unsigned numResults = regionBranchOp->getNumResults();
-  LLVM_DEBUG(DBGS() << "  parent op has " << numResults << " results\n");
   if (numResults == 0)
     return;
 
@@ -241,14 +201,8 @@ static void propagateRegionResultsToYieldOperands(
   for (unsigned i = 0; i < numResults; ++i) {
     OpResult result = regionBranchOp->getResult(i);
     resultLayouts[i] = getLayoutFromUsePoints(result);
-    if (resultLayouts[i]) {
-      LLVM_DEBUG(DBGS() << "  result #" << i << ": type=" << result.getType()
-                        << ", layout=" << resultLayouts[i] << "\n");
+    if (resultLayouts[i])
       xegpu::setTemporaryLayout(result, resultLayouts[i]);
-    } else {
-      LLVM_DEBUG(DBGS() << "  result #" << i
-                        << ": skipped (no layout from use points)\n");
-    }
   }
 
   // Use getSuccessorOperands to find which operands of the terminator
@@ -264,35 +218,15 @@ static void propagateRegionResultsToYieldOperands(
   unsigned beginIdx = succOps.getBeginOperandIndex();
   unsigned count = std::min(static_cast<unsigned>(succOps.size()), numResults);
 
-  LLVM_DEBUG(DBGS() << "  " << count << " successor operands starting at index "
-                    << beginIdx << "\n");
-
   for (unsigned i = 0; i < count; ++i) {
     if (!resultLayouts[i])
       continue;
-    LLVM_DEBUG(DBGS() << "    -> setting layout on operand #" << (beginIdx + i)
-                      << "\n");
     xegpu::setTemporaryLayout(yieldOp->getOpOperand(beginIdx + i),
                               resultLayouts[i]);
   }
-
-  LLVM_DEBUG({
-    DBGS() << " after propagateRegionResultsToYieldOperands:\n";
-    yieldOp->print(llvm::dbgs(), OpPrintingFlags().printGenericOpForm());
-    llvm::dbgs() << "\n";
-  });
 }
 
 static void propagateRegionArgsToInits(mlir::RegionBranchOpInterface regionOp) {
-  LLVM_DEBUG(DBGS() << "propagateRegionArgsToInits: " << regionOp->getName()
-                    << " (" << regionOp->getNumOperands() << " operands, "
-                    << regionOp->getNumRegions() << " regions)\n");
-  LLVM_DEBUG({
-    DBGS() << " before propagateRegionArgsToInits, Region IR:\n";
-    regionOp.print(llvm::dbgs(), OpPrintingFlags().printGenericOpForm());
-    llvm::dbgs() << "\n";
-  });
-
   // Iterate all regions of the region op. For each block argument that has a
   // layout (determined from its use points), trace back to find the
   // corresponding init operand of the regionOp and set the layout on it.
@@ -302,14 +236,8 @@ static void propagateRegionArgsToInits(mlir::RegionBranchOpInterface regionOp) {
     RegionSuccessor regionSuccessor(&region);
     for (auto [argIdx, regionArg] : llvm::enumerate(region.getArguments())) {
       auto layout = getLayoutFromUsePoints(regionArg);
-      if (!layout) {
-        LLVM_DEBUG(DBGS() << "  region #" << region.getRegionNumber()
-                          << " arg #" << argIdx << ": skipped (no layout)\n");
+      if (!layout)
         continue;
-      }
-      LLVM_DEBUG(DBGS() << "  region #" << region.getRegionNumber() << " arg #"
-                        << argIdx << ": type=" << regionArg.getType()
-                        << ", layout=" << layout << "\n");
 
       // Find all predecessor values that flow into this block argument.
       SmallVector<Value> predValues;
@@ -317,42 +245,23 @@ static void propagateRegionArgsToInits(mlir::RegionBranchOpInterface regionOp) {
       for (Value predVal : predValues) {
         // Match predecessor value to an operand of the regionOp.
         for (OpOperand &operand : regionOp->getOpOperands()) {
-          if (operand.get() == predVal) {
-            LLVM_DEBUG(DBGS() << "    -> setting layout on init operand #"
-                              << operand.getOperandNumber() << "\n");
+          if (operand.get() == predVal)
             xegpu::setTemporaryLayout(operand, layout);
-          }
         }
       }
     }
   }
-
-  LLVM_DEBUG({
-    DBGS() << " after propagateRegionArgsToInits, Region IR:\n";
-    regionOp.print(llvm::dbgs(), OpPrintingFlags().printGenericOpForm());
-    llvm::dbgs() << "\n";
-  });
 }
 
 bool xegpu::recoverTemporaryLayouts(Operation *rootOp) {
-  LLVM_DEBUG(DBGS() << "=== recoverTemporaryLayouts START ===\n");
-
   auto processFunc = [&](Region &body, StringRef funcName) {
-    LLVM_DEBUG(DBGS() << "Processing func: " << funcName << "\n");
     walkRegionBackward(body, [&](Operation *op) {
-      LLVM_DEBUG(DBGS() << "Visiting op: " << op->getName());
       if (auto regionOp = dyn_cast<mlir::RegionBranchOpInterface>(op)) {
-        // hit the region op after visiting inside region
-        LLVM_DEBUG(DBGS() << "  -> dispatching as RegionBranchOp\n");
         propagateRegionArgsToInits(regionOp);
       } else if (auto yieldOp =
                      dyn_cast<mlir::RegionBranchTerminatorOpInterface>(op)) {
-        // yield op inside region op
-        LLVM_DEBUG(DBGS() << "  -> dispatching as YieldOp\n");
         propagateRegionResultsToYieldOperands(yieldOp);
       } else if (!dyn_cast<xegpu::AnchorLayoutInterface>(op)) {
-        // if the op is regular op, calling propagateResultsToRegularOperands
-        LLVM_DEBUG(DBGS() << "  -> dispatching as regular op\n");
         propagateResultsToRegularOperands(op);
       }
     });
@@ -365,13 +274,6 @@ bool xegpu::recoverTemporaryLayouts(Operation *rootOp) {
     processFunc(func.getBody(), func.getName());
   });
 
-  LLVM_DEBUG(DBGS() << "=== recoverTemporaryLayouts END ===\n");
-  // print the root op after
-  LLVM_DEBUG({
-    DBGS() << "After recoverTemporaryLayouts, IR:\n";
-    rootOp->print(llvm::dbgs(), OpPrintingFlags().printGenericOpForm());
-    llvm::dbgs() << "\n";
-  });
   return true;
 }
 
@@ -1394,124 +1296,72 @@ xegpu::setupDpasLayout(xegpu::LayoutKind layoutKind, VectorType aTy,
 xegpu::DistributeLayoutAttr
 xegpu::inferSourceLayoutFromResult(OpOperand &operand,
                                    xegpu::DistributeLayoutAttr resLayout) {
-  if (!resLayout) {
-    LLVM_DEBUG(DBGS() << "no resLayout, returning null\n");
+  if (!resLayout)
     return xegpu::DistributeLayoutAttr();
-  }
   Operation *op = operand.getOwner();
   unsigned idx = operand.getOperandNumber();
 
   // For vector::BroadcastOp, infer the source layout from the result layout.
   if (auto broadcast = dyn_cast<vector::BroadcastOp>(op)) {
-    LLVM_DEBUG(DBGS() << "  -> BroadcastOp\n");
     auto srcTy = dyn_cast<VectorType>(broadcast.getSourceType());
-    if (!srcTy) {
-      LLVM_DEBUG(DBGS() << "     source is not VectorType, returning null\n");
+    if (!srcTy)
       return xegpu::DistributeLayoutAttr();
-    }
-    auto inferred = xegpu::inferBroadcastSourceLayout(
+    return xegpu::inferBroadcastSourceLayout(
         resLayout, broadcast.getResultVectorType().getShape(),
         srcTy.getShape());
-    LLVM_DEBUG(DBGS() << "     inferred=" << inferred << "\n");
-    return inferred;
   }
 
   // For vector::MultiDimReductionOp, infer source layout from result layout
   // using reduction dims. Acc operand is expected to have the same layout as
   // the result.
   if (auto reduction = dyn_cast<vector::MultiDimReductionOp>(op)) {
-    LLVM_DEBUG(DBGS() << "  -> MultiDimReductionOp, operand idx=" << idx
-                      << "\n");
     if (idx == 0) {
       SmallVector<int64_t> reductionDims(reduction.getReductionDims());
-      LLVM_DEBUG({
-        DBGS() << "     reductionDims=[";
-        llvm::interleaveComma(reductionDims, llvm::dbgs());
-        llvm::dbgs() << "]\n";
-      });
-      auto inferred =
-          xegpu::inferMultiReductionSourceLayout(resLayout, reductionDims);
-      LLVM_DEBUG(DBGS() << "     inferred source layout=" << inferred << "\n");
-      return inferred;
+      return xegpu::inferMultiReductionSourceLayout(resLayout, reductionDims);
     }
-    if (idx == 1) {
-      LLVM_DEBUG(DBGS() << "     acc operand, using resLayout\n");
+    if (idx == 1)
       return resLayout;
-    }
   }
 
-  if (auto reduction = dyn_cast<vector::ReductionOp>(op)) {
-    LLVM_DEBUG(DBGS() << "  -> ReductionOp\n");
-    auto inferred = xegpu::inferReductionSourceLayout(resLayout);
-    LLVM_DEBUG(DBGS() << "     inferred=" << inferred << "\n");
-    return inferred;
-  }
+  if (auto reduction = dyn_cast<vector::ReductionOp>(op))
+    return xegpu::inferReductionSourceLayout(resLayout);
 
   // For vector::BitCastOp, infer source layout from result layout using
   // element type bitwidths.
   if (auto bitcast = dyn_cast<vector::BitCastOp>(op)) {
-    LLVM_DEBUG(DBGS() << "  -> BitCastOp\n");
     int resElemBitWidth =
         bitcast.getResultVectorType().getElementType().getIntOrFloatBitWidth();
     int srcElemBitWidth =
         bitcast.getSourceVectorType().getElementType().getIntOrFloatBitWidth();
-    LLVM_DEBUG(DBGS() << "     resBitWidth=" << resElemBitWidth
-                      << ", srcBitWidth=" << srcElemBitWidth << "\n");
-    auto inferred = xegpu::inferBitCastSourceLayout(resLayout, resElemBitWidth,
-                                                    srcElemBitWidth);
-    LLVM_DEBUG(DBGS() << "     inferred=" << inferred << "\n");
-    return inferred;
+    return xegpu::inferBitCastSourceLayout(resLayout, resElemBitWidth,
+                                           srcElemBitWidth);
   }
 
   // For vector::ShapeCastOp, infer source layout from result layout using
   // shapes.
   if (auto shapeCast = dyn_cast<vector::ShapeCastOp>(op)) {
-    LLVM_DEBUG({
-      DBGS() << "  -> ShapeCastOp: resShape=[";
-      llvm::interleaveComma(shapeCast.getResultVectorType().getShape(),
-                            llvm::dbgs());
-      llvm::dbgs() << "], srcShape=[";
-      llvm::interleaveComma(shapeCast.getSourceVectorType().getShape(),
-                            llvm::dbgs());
-      llvm::dbgs() << "]\n";
-    });
-    auto inferred = xegpu::inferShapeCastSourceLayout(
+    return xegpu::inferShapeCastSourceLayout(
         resLayout, shapeCast.getResultVectorType().getShape(),
         shapeCast.getSourceVectorType().getShape());
-    LLVM_DEBUG(DBGS() << "     inferred=" << inferred << "\n");
-    return inferred;
   }
 
   // For vector::InsertStridedSliceOp, infer source layout from result layout.
   // Dest vector must have the same layout as the result.
   if (auto insertSlice = dyn_cast<vector::InsertStridedSliceOp>(op)) {
-    LLVM_DEBUG(DBGS() << "  -> InsertStridedSliceOp, operand idx=" << idx
-                      << "\n");
     if (idx == 0) {
-      auto inferred = xegpu::inferInsertStridedSliceSourceLayout(
+      return xegpu::inferInsertStridedSliceSourceLayout(
           resLayout, insertSlice.getDestVectorType().getShape(),
           insertSlice.getSourceVectorType().getShape());
-      LLVM_DEBUG(DBGS() << "     inferred source layout=" << inferred << "\n");
-      return inferred;
     }
-    if (idx == 1) {
-      LLVM_DEBUG(DBGS() << "     dest operand, using resLayout\n");
+    if (idx == 1)
       return resLayout;
-    }
   }
 
   // For vector::TransposeOp, infer source layout from result layout using
   // permutation.
   if (auto transpose = dyn_cast<vector::TransposeOp>(op)) {
-    LLVM_DEBUG({
-      DBGS() << "  -> TransposeOp, perm=[";
-      llvm::interleaveComma(transpose.getPermutation(), llvm::dbgs());
-      llvm::dbgs() << "]\n";
-    });
-    auto inferred = xegpu::inferTransposeSourceLayout(
-        resLayout, transpose.getPermutation());
-    LLVM_DEBUG(DBGS() << "     inferred=" << inferred << "\n");
-    return inferred;
+    return xegpu::inferTransposeSourceLayout(resLayout,
+                                             transpose.getPermutation());
   }
 
   if (isa<VectorType>(operand.get().getType()) &&
@@ -1520,8 +1370,6 @@ xegpu::inferSourceLayoutFromResult(OpOperand &operand,
     // result.
     // if (OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() ==
     // 1) {
-    LLVM_DEBUG(DBGS() << "  -> other vector or tensorDesc ops using resLayout="
-                      << (resLayout ? resLayout : nullptr) << "\n");
     return resLayout;
   }
   return xegpu::DistributeLayoutAttr();
@@ -1537,9 +1385,5 @@ xegpu::DistributeLayoutAttr xegpu::getConsumerLayoutAt(OpOperand &operand) {
     return inferredOperandLayout;
   // By default, assume no layout conflict and return the current layout of
   // the operand.
-  auto fallback = xegpu::getDistributeLayoutAttr(operand.get());
-  LLVM_DEBUG(DBGS() << "  -> fallback (unhandled op " << op->getName()
-                    << "), returning operand layout="
-                    << (fallback ? fallback : nullptr) << "\n");
-  return fallback;
+  return xegpu::getDistributeLayoutAttr(operand.get());
 }

>From eb6048ed52046e0aa13d566e34d49ec22e1a13ac Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Mon, 13 Apr 2026 21:14:53 +0000
Subject: [PATCH 16/21] clean up

---
 .../XeGPUSgToWiDistributeExperimental.cp      | 1744 -----------------
 1 file changed, 1744 deletions(-)
 delete mode 100644 mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cp

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cp
deleted file mode 100644
index e3227c7f5b149..0000000000000
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cp
+++ /dev/null
@@ -1,1744 +0,0 @@
-//===- XeGPUSgToWiDistributeExperimental.cpp - XeGPU SG to WI Pass --------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-#include "mlir/Dialect/GPU/IR/GPUDialect.h"
-#include "mlir/Dialect/Index/IR/IndexDialect.h"
-#include "mlir/Dialect/Math/IR/Math.h"
-#include "mlir/Dialect/MemRef/IR/MemRef.h"
-#include "mlir/Dialect/SCF/Transforms/Patterns.h"
-#include "mlir/Dialect/Vector/IR/VectorOps.h"
-#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
-#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
-#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
-#include "mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h"
-#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
-#include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/BuiltinAttributes.h"
-#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/MLIRContext.h"
-#include "mlir/IR/Operation.h"
-#include "mlir/IR/Value.h"
-#include "mlir/IR/ValueRange.h"
-#include "mlir/Transforms/DialectConversion.h"
-#include "llvm/ADT/SetVector.h"
-#include "llvm/Support/LogicalResult.h"
-#include "llvm/Support/raw_ostream.h"
-#include <optional>
-
-namespace mlir {
-namespace xegpu {
-#define GEN_PASS_DEF_XEGPUSGTOWIDISTRIBUTEEXPERIMENTAL
-#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
-} // namespace xegpu
-} // namespace mlir
-
-using namespace mlir;
-
-#define DEBUG_TYPE "xegpu-sg-to-wi-distribute-experimental"
-#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
-
-namespace {
-
-/// Casts the given vector value `v` to the expected vector type `expectedTy`.
-static Value castValueTo(ConversionPatternRewriter &rewriter,
-                         TypedValue<VectorType> v, VectorType expectedTy) {
-  // If the type matches, simply return the value itself.
-  if (v.getType() == expectedTy)
-    return v;
-  // If only shape differs, use shape cast.
-  if (isa<VectorType>(v.getType()) &&
-      v.getType().getNumElements() == expectedTy.getNumElements())
-    return vector::ShapeCastOp::create(rewriter, v.getLoc(), expectedTy, v);
-
-  // Else create an unrealized cast.
-  auto newOp = UnrealizedConversionCastOp::create(rewriter, v.getLoc(),
-                                                  expectedTy, ValueRange{v});
-  return newOp.getResult(0);
-}
-
-/// A vector::MultiDimReductionOp at subgroup level in expected form if, it has
-/// exactly 1 reduction dimension, it had valid result layout attribute, and
-/// result type can be distributed to lanes using the layout.
-static bool isValidSubgroupMultiReductionOp(vector::MultiDimReductionOp op) {
-  auto resLayout = xegpu::getTemporaryLayout(op->getOpResult(0));
-  // If no layout, not valid.
-  if (!resLayout || !resLayout.isForSubgroup())
-    return false;
-  // Scalar result (e.g., vector<32xf32> to f32) is valid.
-  if (op.getType().isIntOrFloat())
-    return op.getReductionDims().size() == 1;
-  VectorType resTy = dyn_cast<VectorType>(op.getType());
-  if (!resTy)
-    return false;
-  // Compute the distributed result vector type based on the layout.
-  FailureOr<VectorType> resDistTypeOrFailure =
-      getDistVecTypeBasedOnLaneLayout(resLayout, resTy);
-  if (failed(resDistTypeOrFailure))
-    return false;
-  return op.getReductionDims().size() == 1;
-}
-
-/// A vector::MultiDimReductionOp is doing lane-local reduction if each workitem
-/// is doing its own local reduction. In this case the result layout ensures
-/// that result vector is distributed to lanes, i.e. the result vector type is
-/// different from the distributed result vector type.
-static bool isReductionLaneLocal(vector::MultiDimReductionOp op) {
-  // Must be valid MultiDimReductionOp.
-  assert(isValidSubgroupMultiReductionOp(op) && "Expecting a valid subgroup "
-                                                "MultiDimReductionOp");
-  auto resLayout = xegpu::getTemporaryLayout(op->getOpResult(0));
-  VectorType resTy = dyn_cast<VectorType>(op.getType());
-  auto resDistTypeOrFailure = getDistVecTypeBasedOnLaneLayout(resLayout, resTy);
-  return resTy != resDistTypeOrFailure.value();
-}
-
-/// Given a vector type and its distributed vector type, return the list of
-/// dimensions that are distributed.
-static SmallVector<int64_t> getDistributedDims(VectorType originalType,
-                                               VectorType distributedType) {
-  assert(originalType.getRank() == distributedType.getRank() &&
-         "original and distributed vector types must have the same rank");
-  SmallVector<int64_t> distributedDims;
-  for (int64_t i = 0; i < originalType.getRank(); ++i) {
-    if (distributedType.getDimSize(i) != originalType.getDimSize(i))
-      distributedDims.push_back(i);
-  }
-  return distributedDims;
-}
-
-/// Distributes a subgroup-level CreateNdDesc op to workitem-level CreateNdDesc
-/// op. This simply drops the layout attribute from the tensor descriptor type.
-struct SgToWiCreateNdDesc : public OpConversionPattern<xegpu::CreateNdDescOp> {
-  using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(xegpu::CreateNdDescOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    xegpu::TensorDescType resultType = op.getType();
-    // If no layout, nothing to do.
-    if (!resultType.getLayout())
-      return failure();
-
-    auto newOp = xegpu::CreateNdDescOp::create(
-        rewriter, op.getLoc(), resultType.dropLayouts(), op.getOperands(),
-        op->getAttrs());
-    rewriter.replaceOp(op, newOp.getResult());
-    return success();
-  }
-};
-
-/// Distributes a subgroup-level LoadNd op to workitem-level LoadNd op. Output
-/// of workitem-level LoadNd op is 1D. ShapeCast is added to restore the
-/// original rank.
-struct SgToWiLoadNd : public OpConversionPattern<xegpu::LoadNdOp> {
-  using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(xegpu::LoadNdOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    xegpu::DistributeLayoutAttr layout = op.getAnchorLayout();
-    // If no layout, nothing to do.
-    if (!layout)
-      return failure();
-    // Check if the layout attached to the tensor descriptor is same as the
-    // anchor layout. Otherwise, this is a conflict.
-    if (op.getTensorDescType().getLayout() != layout)
-      return rewriter.notifyMatchFailure(
-          op, "conflicting layout attributes on tensor descriptor and anchor");
-    auto uArch = getUArch(xegpu::getChipStr(op).value_or(""));
-    if (!uArch)
-      return rewriter.notifyMatchFailure(
-          op, "xegpu::LoadNdOp require target attribute attached to "
-              "determine transpose "
-              "requirement");
-    auto supportedWiResultTyOrFailure =
-        xegpu::getDistributedVectorType(op.getTensorDescType());
-    auto expectedWiResultTyOrFailure =
-        xegpu::getDistVecTypeBasedOnLaneLayout(layout, op.getType());
-    if (failed(supportedWiResultTyOrFailure))
-      return rewriter.notifyMatchFailure(
-          op, "unable to compute the workitem vector type for LoadNdOp");
-    if (failed(expectedWiResultTyOrFailure))
-      return rewriter.notifyMatchFailure(
-          op,
-          "unable to compute expected workitem vector type from lane layout");
-    auto newOp = xegpu::LoadNdOp::create(
-        rewriter, op.getLoc(), supportedWiResultTyOrFailure.value(),
-        adaptor.getTensorDesc(), op.getMixedOffsets(), op.getPackedAttr(),
-        op.getTransposeAttr(), op.getL1HintAttr(), op.getL2HintAttr(),
-        op.getL3HintAttr(), /**layout**/ nullptr);
-    // Set the packed attribute if the layout requires it.
-    newOp.setPacked(xegpu::requirePacked(cast<xegpu::LayoutAttr>(layout)));
-    // Set the transpose attribute if the layout requires it.
-    if (xegpu::requireTranspose(cast<xegpu::LayoutAttr>(layout), uArch))
-      newOp.setTranspose(DenseI64ArrayAttr::get(rewriter.getContext(), {1, 0}));
-    rewriter.replaceOp(op, castValueTo(rewriter, newOp.getResult(),
-                                       expectedWiResultTyOrFailure.value()));
-    return success();
-  }
-};
-
-/// Distributes a subgroup-level StoreNd op to workitem-level StoreNd op. Stored
-/// value in workitem-level StoreNd op is 1D. ShapeCast is added to cast the
-/// incoming value to 1D.
-struct SgToWiStoreNd : public OpConversionPattern<xegpu::StoreNdOp> {
-  using OpConversionPattern<xegpu::StoreNdOp>::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(xegpu::StoreNdOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    xegpu::DistributeLayoutAttr layout = op.getAnchorLayout();
-    // If no layout, nothing to do.
-    if (!layout)
-      return failure();
-    // Check if the layout attached to the tensor descriptor and value layout is
-    // same as the anchor layout. Otherwise, this is a conflict.
-    if (op.getTensorDescType().getLayout() != layout)
-      return rewriter.notifyMatchFailure(
-          op, "conflicting layout attributes on tensor descriptor and anchor");
-    auto valueLayout = xegpu::getDistributeLayoutAttr(op->getOpOperand(0));
-    if (valueLayout != layout)
-      return rewriter.notifyMatchFailure(
-          op, "conflicting layout attributes on value and anchor");
-    auto supportedWiValueTyOrFailure =
-        xegpu::getDistributedVectorType(op.getTensorDescType());
-    if (failed(supportedWiValueTyOrFailure))
-      return rewriter.notifyMatchFailure(
-          op,
-          "unable to compute wi vector type for StoreNdOp value from tensor "
-          "descriptor");
-
-    xegpu::StoreNdOp::create(
-        rewriter, op.getLoc(),
-        castValueTo(rewriter, cast<TypedValue<VectorType>>(adaptor.getValue()),
-                    supportedWiValueTyOrFailure.value()),
-        adaptor.getTensorDesc(), op.getMixedOffsets(), op.getL1HintAttr(),
-        op.getL2HintAttr(), op.getL3HintAttr(), /**layout**/ nullptr);
-    rewriter.eraseOp(op);
-    return success();
-  }
-};
-
-/// Distributes a subgroup-level Dpas op to workitem-level Dpas op. All inpputs
-/// and output of workitem-level Dpas op are 1D. Necessary casts are added to
-/// convert the inputs and output to/from 1D.
-struct SgToWiDpas : public OpConversionPattern<xegpu::DpasOp> {
-  using OpConversionPattern<xegpu::DpasOp>::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(xegpu::DpasOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    // llvm::errs() << "DpasOpPattern matchAndRewrite called\n";
-    // Check if the op has A, B and CD layouts attached.
-    auto layoutA = cast<xegpu::LayoutAttr>(op.getLayoutAAttr());
-    auto layoutB = cast<xegpu::LayoutAttr>(op.getLayoutBAttr());
-    auto layoutCd = cast<xegpu::LayoutAttr>(op.getLayoutCdAttr());
-    if (!layoutA || !layoutB || !layoutCd)
-      return failure();
-    // llvm::errs() << "tryning to calculate wi types for dpas op\n";
-    auto wiResultTyOrFailure =
-        xegpu::getDistributedVectorType(op.getType(), layoutCd);
-    auto wiATypeOrFailure =
-        xegpu::getDistributedVectorType(op.getLhs().getType(), layoutA);
-    auto wiBTypeOrFailure =
-        xegpu::getDistributedVectorType(op.getRhs().getType(), layoutB);
-    auto expectedWiResultTyOrFailure =
-        xegpu::getDistVecTypeBasedOnLaneLayout(layoutCd, op.getType());
-    if (failed(wiResultTyOrFailure) || failed(wiATypeOrFailure) ||
-        failed(wiBTypeOrFailure))
-      return rewriter.notifyMatchFailure(
-          op, "failed to calculate supported workitem vector types for DpasOp "
-              "from layouts");
-    if (failed(expectedWiResultTyOrFailure))
-      return rewriter.notifyMatchFailure(
-          op, "unable to compute expected workitem vector type for DpasOp from "
-              "lane layout");
-    auto newOp = xegpu::DpasOp::create(
-        rewriter, op->getLoc(), wiResultTyOrFailure.value(),
-        castValueTo(rewriter, cast<TypedValue<VectorType>>(adaptor.getLhs()),
-                    wiATypeOrFailure.value()),
-        castValueTo(rewriter, cast<TypedValue<VectorType>>(adaptor.getRhs()),
-                    wiBTypeOrFailure.value()),
-        castValueTo(rewriter, cast<TypedValue<VectorType>>(adaptor.getAcc()),
-                    wiResultTyOrFailure.value()),
-        /** layoutA**/ nullptr,
-        /** layoutB**/ nullptr, /** layoutCd**/ nullptr);
-    // Explicitly set the new types to enable correct type materializations.
-    rewriter.replaceOp(op, castValueTo(rewriter, newOp.getResult(),
-                                       expectedWiResultTyOrFailure.value()));
-    return success();
-  }
-};
-
-/// Distributes elementwise ops to workitem-level elementwise ops. This
-/// currently handles elementwise ops with single result only.
-struct SgToWiElementWise : public ConversionPattern {
-  SgToWiElementWise(TypeConverter &typeConverter, MLIRContext *ctx)
-      : ConversionPattern(MatchAnyOpTypeTag(), /*benefit=*/1, ctx) {}
-
-  LogicalResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const override {
-    // Only match ops with elementwise trait and single result.
-    if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1)
-      return failure();
-
-    auto resultType = dyn_cast<VectorType>(op->getResult(0).getType());
-    if (!resultType)
-      return rewriter.notifyMatchFailure(
-          op, "operation result is not a vector type");
-
-    xegpu::DistributeLayoutAttr layout =
-        xegpu::getTemporaryLayout(llvm::cast<OpResult>(op->getResult(0)));
-    if (!layout || !layout.isForSubgroup())
-      return rewriter.notifyMatchFailure(
-          op, "operation result does not have subgroup distribute layout");
-
-    auto wiShapeOrFailure =
-        xegpu::getDistVecTypeBasedOnLaneLayout(layout, resultType);
-
-    if (failed(wiShapeOrFailure))
-      return rewriter.notifyMatchFailure(
-          op, "unable to compute workitem vector type from the layout");
-
-    VectorType newResultType = wiShapeOrFailure.value();
-    OperationState state(op->getLoc(), op->getName());
-    state.addOperands(operands);
-    state.addTypes(newResultType);
-    // Copy all attributes except for DistributeLayoutAttr.
-    for (auto attr : op->getAttrs()) {
-      if (!isa<xegpu::DistributeLayoutAttr>(attr.getValue()))
-        state.addAttribute(attr.getName(), attr.getValue());
-    }
-    Operation *newOp = rewriter.create(state);
-
-    rewriter.replaceOp(op, newOp->getResult(0));
-    return success();
-  }
-};
-
-/// Distributes a subgroup-level arith ConstantOp to workitem-level arith
-/// ConstantOp.
-struct SgToWiArithConstant : public OpConversionPattern<arith::ConstantOp> {
-  using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    auto resultType = dyn_cast<VectorType>(op.getType());
-    if (!resultType)
-      return failure();
-
-    // Only handle dense vector constants
-    auto dense = dyn_cast<SplatElementsAttr>(op.getValue());
-    if (!dense)
-      return rewriter.notifyMatchFailure(
-          op, "only dense splat vector constants are supported");
-
-    xegpu::DistributeLayoutAttr layout =
-        xegpu::getTemporaryLayout(llvm::cast<OpResult>(op.getResult()));
-    if (!layout || !layout.isForSubgroup())
-      return rewriter.notifyMatchFailure(
-          op, "operation result does not have subgroup distribute layout");
-
-    auto wiShapeOrFailure =
-        xegpu::getDistVecTypeBasedOnLaneLayout(layout, resultType);
-
-    if (failed(wiShapeOrFailure))
-      return rewriter.notifyMatchFailure(
-          op, "unable to compute workitem vector type from the layout");
-
-    VectorType newResultType = wiShapeOrFailure.value();
-    auto sclarValue = dense.getSplatValue<Attribute>();
-    auto newDenseAttr = DenseElementsAttr::get(newResultType, sclarValue);
-
-    auto newOp = arith::ConstantOp::create(rewriter, op.getLoc(), newResultType,
-                                           newDenseAttr);
-    rewriter.replaceOp(op, newOp.getResult());
-    return success();
-  }
-};
-
-/// Distributes a subgroup-level PrefetchNd op to workitem-level PrefetchNd op.
-struct SgToWiPrefetchNd : public OpConversionPattern<xegpu::PrefetchNdOp> {
-  using OpConversionPattern<xegpu::PrefetchNdOp>::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(xegpu::PrefetchNdOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    xegpu::DistributeLayoutAttr layout = op.getAnchorLayout();
-    // If no layout, nothing to do.
-    if (!layout)
-      return failure();
-
-    xegpu::PrefetchNdOp::create(rewriter, op.getLoc(), adaptor.getTensorDesc(),
-                                op.getMixedOffsets(), op.getL1HintAttr(),
-                                op.getL2HintAttr(), op.getL3HintAttr(),
-                                /**layout**/ nullptr);
-    rewriter.eraseOp(op);
-    return success();
-  }
-};
-
-/// Distributes a subgroup-level LoadGather (xegpu.load) op to workitem-level.
-///
-/// Example 1 (1D, no chunk size):
-///   layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>
-///   %mask = producer_op : vector<16xi1>
-///   %offset = producer_op : vector<16xindex>
-///   %0 = xegpu.load %src[%offset], %mask : memref<256xf16>,
-///     vector<16xindex>, vector<16xi1> -> vector<16xf16>
-/// Distributed to:
-///   %mask = producer_op : vector<1xi1>
-///   %offset = producer_op : vector<1xindex>
-///   %0 = xegpu.load %src[%offset], %mask : memref<256xf16>,
-///     vector<1xindex>, vector<1xi1> -> vector<1xf16>
-///
-/// Example 2 (2D with chunk size, same mask & offset):
-///   layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>
-///   %0 = xegpu.load %src[%offset], %mask <{chunk_size=8}> :
-///     memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
-/// Distributed to:
-///   %0 = xegpu.load %src[%offset], %mask <{chunk_size=8}> :
-///     memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<8xf16>
-///
-/// Example 3 (3D with leading unit dims):
-///   layout = #xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>
-///   %mask = producer_op : vector<1x1x16xi1>
-///   %offset = producer_op : vector<1x1x16xindex>
-///   %0 = xegpu.load %src[%offset], %mask : memref<256xf16>,
-///     vector<1x1x16xindex>, vector<1x1x16xi1> -> vector<1x1x16xf16>
-/// Distributed to:
-///   %mask = producer_op : vector<1x1x1xi1>
-///   %offset = producer_op : vector<1x1x1xindex>
-///   %0 = xegpu.load %src[%offset], %mask : memref<256xf16>,
-///     vector<1xindex>, vector<1xi1> -> vector<1xf16>
-struct SgToWiLoadGather : public OpConversionPattern<xegpu::LoadGatherOp> {
-  using OpConversionPattern<xegpu::LoadGatherOp>::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(xegpu::LoadGatherOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    xegpu::DistributeLayoutAttr layout = op.getAnchorLayout();
-    if (!layout)
-      return failure();
-
-    VectorType origResultTy = op.getValueType();
-    if (!origResultTy)
-      return failure();
-
-    // Check that leading dimensions are unit.
-    int chunkSize = op.getChunkSize().value_or(1);
-    int effectiveVecRank = (chunkSize == 1) ? 1 : 2;
-    ArrayRef<int64_t> shape = origResultTy.getShape();
-    if (llvm::any_of(
-            shape.take_front(origResultTy.getRank() - effectiveVecRank),
-            [](int64_t d) { return d != 1; }))
-      return rewriter.notifyMatchFailure(
-          op, "Only unit dimensions allowed for the leading "
-              "dimensions of the load vector!");
-
-    auto distResultTyOrFailure =
-        xegpu::getDistVecTypeBasedOnLaneLayout(layout, origResultTy);
-    if (failed(distResultTyOrFailure))
-      return rewriter.notifyMatchFailure(
-          op,
-          "unable to compute expected workitem vector type from lane layout");
-
-    VectorType distResultTy = distResultTyOrFailure.value();
-    VectorType distResultTy1D = VectorType::get({distResultTy.getNumElements()},
-                                                distResultTy.getElementType());
-
-    // Flatten offsets and mask to 1D to match the 1D result type.
-    Value distOffsets = adaptor.getOffsets();
-    auto distOffsetsTy = cast<VectorType>(distOffsets.getType());
-    VectorType offsetsTy1D = VectorType::get({distOffsetsTy.getNumElements()},
-                                             distOffsetsTy.getElementType());
-    distOffsets = castValueTo(
-        rewriter, cast<TypedValue<VectorType>>(distOffsets), offsetsTy1D);
-
-    Value distMask = adaptor.getMask();
-    auto distMaskTy = cast<VectorType>(distMask.getType());
-    VectorType maskTy1D = VectorType::get({distMaskTy.getNumElements()},
-                                          distMaskTy.getElementType());
-    distMask =
-        castValueTo(rewriter, cast<TypedValue<VectorType>>(distMask), maskTy1D);
-
-    Value distSource = adaptor.getSource();
-    auto newOp = xegpu::LoadGatherOp::create(
-        rewriter, op.getLoc(), distResultTy1D, distSource, distOffsets,
-        distMask, op.getChunkSizeAttr(), op.getL1HintAttr(), op.getL2HintAttr(),
-        op.getL3HintAttr(), /*layout=*/nullptr);
-
-    Value result = newOp->getResult(0);
-    if (distResultTy1D != distResultTy)
-      result = castValueTo(rewriter, cast<TypedValue<VectorType>>(result),
-                           distResultTy);
-    rewriter.replaceOp(op, result);
-    return success();
-  }
-};
-
-/// This pattern distributes a subgroup-level vector.reduction op to
-/// workitem-level. This require shuffling the data across the workitems (using
-/// gpu::ShuffleOp) and reducing in stages until all workitems have the final
-/// result.
-struct SgToWiVectorReduction : public OpConversionPattern<vector::ReductionOp> {
-  using OpConversionPattern<vector::ReductionOp>::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(vector::ReductionOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    auto layout = xegpu::getDistributeLayoutAttr(op.getVector());
-
-    // If no layout, nothing to do.
-    if (!layout || !layout.isForSubgroup())
-      return failure();
-
-    VectorType srcVecType = op.getSourceVectorType();
-    // Only rank 1 vectors supported.
-    if (srcVecType.getRank() != 1)
-      return rewriter.notifyMatchFailure(
-          op, "Only rank 1 reductions can be distributed.");
-    // Lane layout must have the same rank as the vector.
-    if (layout.getRank() != srcVecType.getRank())
-      return rewriter.notifyMatchFailure(
-          op, "Layout rank does not match vector rank.");
-
-    // Get the subgroup size from the layout.
-    int64_t sgSize = layout.getEffectiveLaneLayoutAsInt()[0];
-    const uArch *uArch = getUArch(xegpu::getChipStr(op).value_or(""));
-    if (!uArch)
-      return rewriter.notifyMatchFailure(
-          op, "xegpu::ReductionOp require target attribute attached to "
-              "determine subgroup size");
-
-    // Only subgroup-sized vectors supported.
-    if (sgSize != uArch->getSubgroupSize() ||
-        srcVecType.getShape()[0] % sgSize != 0)
-      return rewriter.notifyMatchFailure(op,
-                                         "Invalid layout or reduction vector "
-                                         "dimension must match subgroup size.");
-
-    if (!op.getType().isIntOrFloat())
-      return rewriter.notifyMatchFailure(
-          op, "Reduction distribution currently only supports floats and "
-              "integer types.");
-
-    // Get the distributed vector (per work-item portion).
-    Value laneValVec = adaptor.getVector();
-
-    // Distribute and reduce across work-items in the subgroup.
-    Value fullReduce = xegpu::subgroupReduction(
-        op.getLoc(), rewriter, laneValVec, op.getKind(), sgSize);
-
-    // If there's an accumulator, combine it with the reduced value.
-    if (adaptor.getAcc())
-      fullReduce = vector::makeArithReduction(
-          rewriter, op.getLoc(), op.getKind(), fullReduce, adaptor.getAcc());
-
-    rewriter.replaceOp(op, fullReduce);
-    return success();
-  }
-};
-
-/// This pattern distributes a subgroup-level vector.multi_reduction op to
-/// workitem-level only if the reduction is lane-local. This means that
-/// reduction dimension is not distributed to lanes and each lane does its own
-/// local reduction.
-struct SgToWiMultiDimReduction
-    : public OpConversionPattern<vector::MultiDimReductionOp> {
-  using OpConversionPattern<vector::MultiDimReductionOp>::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(vector::MultiDimReductionOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    Value result;
-    ArrayRef<int64_t> reductionDims = op.getReductionDims();
-    assert(reductionDims.size() == 1 &&
-           "Expecting single reduction dimension for subgroup multi "
-           "reduction op");
-    // For rank > 2, ensure leading dimensions are unit.
-    VectorType sourceType = op.getSourceVectorType();
-    int64_t rank = sourceType.getRank();
-    if (rank > 2) {
-      ArrayRef<int64_t> shape = sourceType.getShape();
-      if (llvm::any_of(shape.take_front(rank - 2),
-                       [](int64_t d) { return d != 1; }))
-        return rewriter.notifyMatchFailure(
-            op, "only unit leading dimensions are supported for "
-                "multi_reduction with rank > 2");
-    }
-    // Handle scalar result: full reduction of a distributed vector to a
-    // scalar. First do a local vector reduction, then cross-lane shuffles.
-    if (op.getType().isIntOrFloat()) {
-      auto reductionDim = reductionDims[0];
-      VectorType origSourceType = op.getSourceVectorType();
-      int64_t reductionDimSize = origSourceType.getShape()[reductionDim];
-      // Local reduction to scalar, then cross-lane butterfly shuffles.
-      result =
-          xegpu::subgroupReduction(op.getLoc(), rewriter, adaptor.getSource(),
-                                   op.getKind(), reductionDimSize);
-      // Combine with accumulator if present.
-      if (adaptor.getAcc())
-        result = vector::makeArithReduction(rewriter, op.getLoc(), op.getKind(),
-                                            result, adaptor.getAcc());
-    } else if (isReductionLaneLocal(op)) {
-      auto resLayout = xegpu::getTemporaryLayout(op->getOpResult(0));
-      VectorType resVecTy = dyn_cast<VectorType>(op.getType());
-      auto resDistVecTyOrFailure =
-          getDistVecTypeBasedOnLaneLayout(resLayout, resVecTy);
-      // For lane local reduction, simply create a new MultiDimReductionOp using
-      // adaptor operands and the new result type.
-      result = vector::MultiDimReductionOp::create(
-          rewriter, op.getLoc(), resDistVecTyOrFailure.value(), op.getKind(),
-          adaptor.getSource(), adaptor.getAcc(), op.getReductionDims());
-    } else {
-      auto reductionDim = reductionDims[0];
-      VectorType sourceType = op.getSourceVectorType();
-      int64_t reductionDimSize = sourceType.getShape()[reductionDim];
-      result = xegpu::lowerCrossLaneReductionToShuffles(
-          cast<TypedValue<VectorType>>(adaptor.getSource()),
-          cast<TypedValue<VectorType>>(adaptor.getAcc()), op.getKind(),
-          reductionDim, reductionDimSize, op.getLoc(), rewriter);
-    }
-    rewriter.replaceOp(op, result);
-    return success();
-  }
-};
-
-/// Helper to compute distributed coordinates for matrix ops.
-/// When not using subgroup_block_io, each workitem computes its own
-/// coordinates based on the layout and lane ID.
-static SmallVector<Value> computeDistributedCoordsForMatrixOp(
-    ConversionPatternRewriter &rewriter, Location loc,
-    xegpu::DistributeLayoutAttr layout, ArrayRef<int64_t> payloadShape,
-    ValueRange origOffsets) {
-  Value laneId = gpu::LaneIdOp::create(rewriter, loc, rewriter.getIndexType(),
-                                       /*upperBound=*/mlir::IntegerAttr());
-  auto maybeCoords =
-      layout.computeDistributedCoords(rewriter, loc, laneId, payloadShape);
-  if (failed(maybeCoords))
-    return {};
-  assert(maybeCoords.value().size() == 1 &&
-         "Expected one set of distributed offsets");
-  SmallVector<OpFoldResult> ofrVec = xegpu::addWithRightAligned(
-      rewriter, loc, getAsOpFoldResult(maybeCoords.value()[0]),
-      getAsOpFoldResult(origOffsets));
-  return llvm::map_to_vector(ofrVec, llvm::CastTo<Value>);
-}
-
-/// This pattern distributes a subgroup-level LoadMatrix op to workitem-level.
-struct SgToWiLoadMatrix : public OpConversionPattern<xegpu::LoadMatrixOp> {
-  using OpConversionPattern<xegpu::LoadMatrixOp>::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(xegpu::LoadMatrixOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    auto layout = op.getLayoutAttr();
-    // If no layout, nothing to do.
-    if (!layout)
-      return failure();
-
-    VectorType sgPayloadTy = dyn_cast<VectorType>(op.getResult().getType());
-    if (!sgPayloadTy)
-      return rewriter.notifyMatchFailure(
-          op, "the matrix op payload must be a vector type");
-
-    auto loc = op.getLoc();
-    auto offsets = op.getMixedOffsets();
-    if (offsets.empty())
-      return rewriter.notifyMatchFailure(op, "the load op must have offsets");
-
-    FailureOr<VectorType> distPayloadTyOrFailure =
-        getDistVecTypeBasedOnLaneLayout(layout, sgPayloadTy);
-    if (failed(distPayloadTyOrFailure))
-      return rewriter.notifyMatchFailure(
-          op, "Failed to distribute matrix op payload based on layout.");
-
-    SmallVector<Value> offsetsAsValues =
-        vector::getAsValues(rewriter, loc, offsets);
-
-    SmallVector<Value> newCoords = offsetsAsValues;
-    if (!op.getSubgroupBlockIoAttr()) {
-      newCoords = computeDistributedCoordsForMatrixOp(
-          rewriter, loc, layout, sgPayloadTy.getShape(), offsetsAsValues);
-      if (newCoords.empty())
-        return rewriter.notifyMatchFailure(
-            op, "Failed to compute distributed coordinates.");
-    }
-
-    SmallVector<int64_t> newConstOffsets(op.getConstOffsets().size(),
-                                         ShapedType::kDynamic);
-    DenseI64ArrayAttr newConstOffsetsAttr =
-        rewriter.getDenseI64ArrayAttr(newConstOffsets);
-
-    auto newOp = xegpu::LoadMatrixOp::create(
-        rewriter, loc, *distPayloadTyOrFailure, adaptor.getMemDesc(),
-        ValueRange(newCoords), newConstOffsetsAttr, op.getSubgroupBlockIoAttr(),
-        xegpu::DistributeLayoutAttr{});
-    rewriter.replaceOp(op, newOp.getResult());
-    return success();
-  }
-};
-
-/// Distributes a subgroup-level vector.transpose op to workitem-level.
-struct SgToWiVectorTranspose : public OpConversionPattern<vector::TransposeOp> {
-  using OpConversionPattern<vector::TransposeOp>::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(vector::TransposeOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    xegpu::DistributeLayoutAttr sourceLayout =
-        xegpu::getTemporaryLayout(op->getOpOperand(0));
-    xegpu::DistributeLayoutAttr resultLayout =
-        xegpu::getTemporaryLayout(op->getOpResult(0));
-    if (!sourceLayout || !resultLayout)
-      return rewriter.notifyMatchFailure(
-          op, "the source or result vector of the transpose op lacks layout "
-              "attribute");
-    ArrayRef<int64_t> perm = op.getPermutation();
-    // Result layout must be a transpose of source layout.
-    if (!resultLayout.isTransposeOf(sourceLayout, perm,
-                                    xegpu::LayoutKind::Lane))
-      return rewriter.notifyMatchFailure(
-          op, "the source or result vector layouts must be transposes of "
-              "each other");
-    FailureOr<VectorType> distributedResultTypeOrFailure =
-        getDistVecTypeBasedOnLaneLayout(resultLayout, op.getResultVectorType());
-    if (failed(distributedResultTypeOrFailure))
-      return rewriter.notifyMatchFailure(
-          op, "Failed to distribute the result vector type in "
-              "vector::Transpose op");
-    auto newOp = vector::TransposeOp::create(rewriter, op.getLoc(),
-                                             adaptor.getVector(), perm);
-    rewriter.replaceOp(op, castValueTo(rewriter, newOp.getResult(),
-                                       distributedResultTypeOrFailure.value()));
-    return success();
-  }
-};
-
-/// Distributes a subgroup-level vector.bitcast op to workitem-level.
-/// Bitcast only impacts the innermost dimension of the source/result vectors.
-struct SgToWiVectorBitcast : public OpConversionPattern<vector::BitCastOp> {
-  using OpConversionPattern<vector::BitCastOp>::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(vector::BitCastOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    xegpu::DistributeLayoutAttr resultLayout =
-        xegpu::getTemporaryLayout(op->getOpResult(0));
-    if (!resultLayout)
-      return rewriter.notifyMatchFailure(
-          op, "result vector of the bitcast op lacks layout attribute");
-    FailureOr<VectorType> distributedResultTypeOrFailure =
-        getDistVecTypeBasedOnLaneLayout(resultLayout, op.getResultVectorType());
-    if (failed(distributedResultTypeOrFailure))
-      return rewriter.notifyMatchFailure(
-          op, "Failed to distribute the result vector type in "
-              "vector::BitCast op");
-    auto newOp = vector::BitCastOp::create(
-        rewriter, op.getLoc(), distributedResultTypeOrFailure.value(),
-        adaptor.getSource());
-    rewriter.replaceOp(op, newOp.getResult());
-    return success();
-  }
-};
-
-/// Distributes a subgroup-level vector.create_mask or vector.constant_mask op
-/// to workitem-level. Uses `computeDistributedCoords()` to obtain the
-/// coordinates each workitem owns, then compares each coordinate against the
-/// original mask bounds using `arith.cmpi slt`. The per-element boolean
-/// results are assembled into the distributed mask vector.
-///
-/// For multi-dimensional masks, the element is in-bounds when ALL dimensions
-/// satisfy `coord[i] < bound[i]`.
-///
-/// Example (1D):
-///   layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>
-///   %mask = vector.create_mask %m0 : vector<16xi1>
-/// For lane k, computeDistributedCoords gives coord = [k], so:
-///   %in_bounds = arith.cmpi slt, %coord, %m0  →  i1
-///   %mask = vector.broadcast %in_bounds : i1 to vector<1xi1>
-///
-/// Example (2D):
-///   layout = #xegpu.layout<lane_layout = [8, 2], lane_data = [1, 1]>
-///   %mask = vector.create_mask %m0, %m1 : vector<8x4xi1>
-/// Each WI owns a 1x2 slice. computeDistributedCoords returns 2 coords:
-///   [[r0, c0], [r0, c1]]
-/// For each coord: in_bounds = (r < m0) && (c < m1)
-///   %mask = vector.from_elements %bit0, %bit1 : vector<1x2xi1>
-template <typename OpType,
-          typename = std::enable_if_t<llvm::is_one_of<
-              OpType, vector::CreateMaskOp, vector::ConstantMaskOp>::value>>
-struct SgToWiCreateMask : public OpConversionPattern<OpType> {
-  using OpConversionPattern<OpType>::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    xegpu::DistributeLayoutAttr layout =
-        xegpu::getTemporaryLayout(op->getOpResult(0));
-    if (!layout || !layout.isForSubgroup())
-      return rewriter.notifyMatchFailure(
-          op, "operation result does not have subgroup distribute layout");
-
-    VectorType origType = op.getType();
-    FailureOr<VectorType> distTypeOrFailure =
-        getDistVecTypeBasedOnLaneLayout(layout, origType);
-    if (failed(distTypeOrFailure))
-      return rewriter.notifyMatchFailure(
-          op, "unable to compute workitem vector type from the layout");
-
-    VectorType distType = distTypeOrFailure.value();
-    Location loc = op.getLoc();
-
-    // Materialize the original mask bounds as Values.
-    SmallVector<Value> origBounds;
-    if constexpr (std::is_same_v<OpType, vector::CreateMaskOp>) {
-      origBounds.append(op.getOperands().begin(), op.getOperands().end());
-    } else {
-      auto dimSizes = op.getMaskDimSizesAttr().asArrayRef();
-      for (auto dimSize : dimSizes)
-        origBounds.push_back(
-            arith::ConstantIndexOp::create(rewriter, loc, dimSize).getResult());
-    }
-
-    ArrayRef<int64_t> origShape = origType.getShape();
-
-    // Use computeDistributedCoords to get the coordinates each WI owns.
-    Value laneId = gpu::LaneIdOp::create(rewriter, loc, rewriter.getIndexType(),
-                                         /*upperBound=*/mlir::IntegerAttr());
-    auto maybeCoordsVec =
-        layout.computeDistributedCoords(rewriter, loc, laneId, origShape);
-    if (failed(maybeCoordsVec))
-      return rewriter.notifyMatchFailure(
-          op, "failed to compute distributed coordinates from layout");
-
-    SmallVector<SmallVector<Value>> coordsVec = maybeCoordsVec.value();
-    int64_t numElements = distType.getNumElements();
-    assert(static_cast<int64_t>(coordsVec.size()) == numElements &&
-           "number of coordinate sets must match number of distributed "
-           "elements");
-
-    // For each element, compare all coordinates against bounds.
-    Value trueVal =
-        arith::ConstantIntOp::create(rewriter, loc, /*value=*/1, /*width=*/1);
-    SmallVector<Value> maskBits;
-    for (auto &coords : coordsVec) {
-      Value inBounds = trueVal;
-      for (size_t i = 0; i < coords.size(); ++i) {
-        Value cmp = arith::CmpIOp::create(
-            rewriter, loc, arith::CmpIPredicate::slt, coords[i], origBounds[i]);
-        inBounds = arith::AndIOp::create(rewriter, loc, inBounds, cmp);
-      }
-      maskBits.push_back(inBounds);
-    }
-
-    // Build the distributed mask vector.
-    Value result;
-    if (numElements == 1) {
-      result =
-          vector::BroadcastOp::create(rewriter, loc, distType, maskBits[0]);
-    } else {
-      result =
-          vector::FromElementsOp::create(rewriter, loc, distType, maskBits);
-    }
-    rewriter.replaceOp(op, result);
-    return success();
-  }
-};
-
-/// This pattern distributes a subgroup-level StoreMatrix op to workitem-level.
-struct SgToWiStoreMatrix : public OpConversionPattern<xegpu::StoreMatrixOp> {
-  using OpConversionPattern<xegpu::StoreMatrixOp>::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(xegpu::StoreMatrixOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    auto layout = op.getLayoutAttr();
-    // If no layout, nothing to do.
-    if (!layout)
-      return failure();
-
-    VectorType sgPayloadTy = dyn_cast<VectorType>(op.getData().getType());
-    if (!sgPayloadTy)
-      return rewriter.notifyMatchFailure(
-          op, "the matrix op payload must be a vector type");
-
-    auto loc = op.getLoc();
-    auto offsets = op.getMixedOffsets();
-    if (offsets.empty())
-      return rewriter.notifyMatchFailure(op, "the store op must have offsets");
-
-    FailureOr<VectorType> distPayloadTyOrFailure =
-        getDistVecTypeBasedOnLaneLayout(layout, sgPayloadTy);
-    if (failed(distPayloadTyOrFailure))
-      return rewriter.notifyMatchFailure(
-          op, "Failed to distribute matrix op payload based on layout.");
-
-    SmallVector<Value> offsetsAsValues =
-        vector::getAsValues(rewriter, loc, offsets);
-
-    SmallVector<Value> newCoords = offsetsAsValues;
-    if (!op.getSubgroupBlockIoAttr()) {
-      newCoords = computeDistributedCoordsForMatrixOp(
-          rewriter, loc, layout, sgPayloadTy.getShape(), offsetsAsValues);
-      if (newCoords.empty())
-        return rewriter.notifyMatchFailure(
-            op, "Failed to compute distributed coordinates.");
-    }
-
-    SmallVector<int64_t> newConstOffsets(op.getConstOffsets().size(),
-                                         ShapedType::kDynamic);
-    DenseI64ArrayAttr newConstOffsetsAttr =
-        rewriter.getDenseI64ArrayAttr(newConstOffsets);
-
-    xegpu::StoreMatrixOp::create(
-        rewriter, loc, TypeRange{},
-        castValueTo(rewriter, cast<TypedValue<VectorType>>(adaptor.getData()),
-                    distPayloadTyOrFailure.value()),
-        adaptor.getMemDesc(), ValueRange(newCoords), newConstOffsetsAttr,
-        op.getSubgroupBlockIoAttr(), xegpu::DistributeLayoutAttr{});
-    rewriter.eraseOp(op);
-    return success();
-  }
-};
-
-/// Distributes a subgroup-level StoreScatter (xegpu.store) op to
-/// workitem-level.
-///
-/// Example 1 (1D, no chunk size):
-///   layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>
-///   %mask = producer_op : vector<16xi1>
-///   %offset = producer_op : vector<16xindex>
-///   xegpu.store %payload, %src[%offset], %mask : vector<16xf16>,
-///     memref<256xf16>, vector<16xindex>, vector<16xi1>
-/// Distributed to:
-///   %mask = producer_op : vector<1xi1>
-///   %offset = producer_op : vector<1xindex>
-///   xegpu.store %payload, %src[%offset], %mask : vector<1xf16>,
-///     memref<256xf16>, vector<1xindex>, vector<1xi1>
-///
-/// Example 2 (2D with chunk size, same mask & offset):
-///   layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>
-///   xegpu.store %payload, %src[%offset], %mask <{chunk_size=8}> :
-///     vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
-/// Distributed to:
-///   xegpu.store %payload, %src[%offset], %mask <{chunk_size=8}> :
-///     vector<8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
-///
-/// Example 3 (3D with leading unit dims):
-///   layout = #xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>
-///   %mask = producer_op : vector<1x1x16xi1>
-///   %offset = producer_op : vector<1x1x16xindex>
-///   xegpu.store %payload, %src[%offset], %mask : vector<1x1x16xf16>,
-///     memref<256xf16>, vector<1x1x16xindex>, vector<1x1x16xi1>
-/// Distributed to:
-///   %mask = producer_op : vector<1x1x1xi1>
-///   %offset = producer_op : vector<1x1x1xindex>
-///   xegpu.store %payload, %src[%offset], %mask : vector<1xf16>,
-///     memref<256xf16>, vector<1xindex>, vector<1xi1>
-struct SgToWiStoreScatter : public OpConversionPattern<xegpu::StoreScatterOp> {
-  using OpConversionPattern<xegpu::StoreScatterOp>::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(xegpu::StoreScatterOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    xegpu::DistributeLayoutAttr layout = op.getAnchorLayout();
-    if (!layout)
-      return failure();
-
-    VectorType origValueTy = op.getValueType();
-    if (!origValueTy)
-      return failure();
-
-    // Check that all leading dimensions are unit dimensions.
-    int chunkSize = op.getChunkSize().value_or(1);
-    int effectiveVecRank = (chunkSize == 1) ? 1 : 2;
-    ArrayRef<int64_t> shape = origValueTy.getShape();
-    if (llvm::any_of(shape.take_front(origValueTy.getRank() - effectiveVecRank),
-                     [](int64_t d) { return d != 1; }))
-      return rewriter.notifyMatchFailure(
-          op, "Only unit dimensions allowed for the leading "
-              "dimensions of the store vector!");
-
-    auto distValueTyOrFailure =
-        xegpu::getDistVecTypeBasedOnLaneLayout(layout, origValueTy);
-    if (failed(distValueTyOrFailure))
-      return rewriter.notifyMatchFailure(
-          op,
-          "unable to compute expected workitem vector type from lane layout");
-
-    VectorType distValueTy = distValueTyOrFailure.value();
-    VectorType distValueTy1D = VectorType::get({distValueTy.getNumElements()},
-                                               distValueTy.getElementType());
-
-    Value distValue = adaptor.getValue();
-    if (distValue.getType() != distValueTy1D)
-      distValue = castValueTo(rewriter, cast<TypedValue<VectorType>>(distValue),
-                              distValueTy1D);
-
-    // Flatten offsets and mask to 1D to match the 1D value type.
-    Value distOffsets = adaptor.getOffsets();
-    auto distOffsetsTy = cast<VectorType>(distOffsets.getType());
-    VectorType offsetsTy1D = VectorType::get({distOffsetsTy.getNumElements()},
-                                             distOffsetsTy.getElementType());
-    distOffsets = castValueTo(
-        rewriter, cast<TypedValue<VectorType>>(distOffsets), offsetsTy1D);
-
-    Value distMask = adaptor.getMask();
-    auto distMaskTy = cast<VectorType>(distMask.getType());
-    VectorType maskTy1D = VectorType::get({distMaskTy.getNumElements()},
-                                          distMaskTy.getElementType());
-    distMask =
-        castValueTo(rewriter, cast<TypedValue<VectorType>>(distMask), maskTy1D);
-
-    Value distDest = adaptor.getDest();
-    xegpu::StoreScatterOp::create(rewriter, op.getLoc(), distValue, distDest,
-                                  distOffsets, distMask, op.getChunkSizeAttr(),
-                                  op.getL1HintAttr(), op.getL2HintAttr(),
-                                  op.getL3HintAttr(), /*layout=*/nullptr);
-    rewriter.eraseOp(op);
-    return success();
-  }
-};
-
-/// Distribute a vector::StepOp to workitem-level.
-/// The layout must have exactly 1 effective lane dimension.
-/// We completely resolve the vector::StepOp by computing the lane_data-sized
-/// subranges.
-struct SgToWiVectorStep : public OpConversionPattern<vector::StepOp> {
-  using OpConversionPattern<vector::StepOp>::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(vector::StepOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    xegpu::DistributeLayoutAttr resultLayout =
-        xegpu::getTemporaryLayout(op->getResult(0));
-    if (!resultLayout || !resultLayout.isForSubgroup())
-      return rewriter.notifyMatchFailure(
-          op, "the result vector of the step op lacks subgroup layout");
-
-    auto loc = op.getLoc();
-    auto stepResultVecTy = op.getResult().getType();
-    auto wiShapeOrFailure =
-        xegpu::getDistVecTypeBasedOnLaneLayout(resultLayout, stepResultVecTy);
-    if (failed(wiShapeOrFailure))
-      return rewriter.notifyMatchFailure(
-          op, "unable to compute workitem vector type from the layout");
-    VectorType newVecTy = wiShapeOrFailure.value();
-
-    Value laneId = gpu::LaneIdOp::create(rewriter, loc, rewriter.getIndexType(),
-                                         /*upperBound=*/mlir::IntegerAttr());
-    auto laneDataBlockCoords = resultLayout.computeDistributedCoords(
-        rewriter, loc, laneId, stepResultVecTy.getShape());
-    if (failed(laneDataBlockCoords))
-      return rewriter.notifyMatchFailure(
-          op, "failed to compute lane data block coordinates");
-
-    auto laneDataBlockCoordsVec = laneDataBlockCoords.value();
-    auto laneDataBlockLength = resultLayout.getEffectiveLaneDataAsInt()[0];
-    assert(static_cast<int64_t>(laneDataBlockCoordsVec.size()) ==
-           newVecTy.getNumElements() / laneDataBlockLength);
-    SmallVector<Value> stepVals;
-    // For each lane_data block, reconstruct its sub-range
-    // from the range of SG-level vector.step.Example: vector.step
-    // {slice<layout<lane_layout=[2,4,2], lane_data=[1,2,1]>, dims=[0,2]>} :
-    // vector<16xindex>
-    // Each logical lane holds 4 elements as 2 blocks of 2 elements each.
-    // The blocks are round-robin distributed, so logical lane id 0
-    // holds values [0,1, 8,9].
-    for (auto &laneDataBlockCoords : laneDataBlockCoordsVec) {
-      auto laneDataBlockStartCoord = laneDataBlockCoords[0];
-      stepVals.push_back(laneDataBlockStartCoord);
-      for (int i = 1; i < laneDataBlockLength; ++i) {
-        auto offset = arith::ConstantIndexOp::create(rewriter, loc, i);
-        stepVals.push_back(arith::AddIOp::create(
-            rewriter, loc, laneDataBlockStartCoord, offset));
-      }
-    }
-    assert(static_cast<int64_t>(stepVals.size()) == newVecTy.getNumElements() &&
-           "Expecting the number of step values to match the number of "
-           "elements in the vector");
-    auto stepOpVal =
-        vector::FromElementsOp::create(rewriter, loc, newVecTy, stepVals);
-    rewriter.replaceOp(op, stepOpVal);
-    return success();
-  }
-};
-
-/// Distributes a subgroup-level vector.extract op to workitem-level. Only
-/// handles sub-vector extraction (result is VectorType, not scalar).
-struct SgToWiVectorExtract : public OpConversionPattern<vector::ExtractOp> {
-  using OpConversionPattern<vector::ExtractOp>::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(vector::ExtractOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    // Only handle vector results (not scalar extraction).
-    auto resultType = dyn_cast<VectorType>(op.getType());
-    if (!resultType)
-      return rewriter.notifyMatchFailure(op, "scalar extract not supported");
-
-    xegpu::DistributeLayoutAttr layout =
-        xegpu::getTemporaryLayout(op->getOpResult(0));
-    if (!layout || !layout.isForSubgroup())
-      return failure();
-
-    // This implementation assumes distribution only happens on the innermost
-    // dimension. Verify that lane_layout[0...n-2] are all unit.
-    auto laneLayout = layout.getEffectiveLaneLayoutAsInt();
-    if (llvm::any_of(ArrayRef<int64_t>(laneLayout).drop_back(1),
-                     [](int64_t v) { return v != 1; }))
-      return rewriter.notifyMatchFailure(
-          op, "only innermost dimension distribution is supported for "
-              "vector.extract");
-
-    auto newOp = vector::ExtractOp::create(
-        rewriter, op.getLoc(), adaptor.getSource(), op.getMixedPosition());
-    rewriter.replaceOp(op, newOp.getResult());
-    return success();
-  }
-};
-
-/// This pattern distributes a subgroup-level ShapeCast op to workitem-level.
-struct SgToWiVectorShapeCast : public OpConversionPattern<vector::ShapeCastOp> {
-  using OpConversionPattern<vector::ShapeCastOp>::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(vector::ShapeCastOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    xegpu::DistributeLayoutAttr resultLayout =
-        xegpu::getTemporaryLayout(op->getOpResult(0));
-    if (!resultLayout || !resultLayout.isForSubgroup())
-      return rewriter.notifyMatchFailure(
-          op, "the result vector of the shape_cast op lacks subgroup layout");
-
-    auto resultDistTypeOrFailure = xegpu::getDistVecTypeBasedOnLaneLayout(
-        resultLayout, op.getResultVectorType());
-    if (failed(resultDistTypeOrFailure))
-      return rewriter.notifyMatchFailure(
-          op, "failed to get distributed vector type for result");
-
-    Value source = adaptor.getSource();
-    auto newShapeCast = vector::ShapeCastOp::create(
-        rewriter, op.getLoc(), resultDistTypeOrFailure.value(), source);
-    rewriter.replaceOp(op, newShapeCast);
-    return success();
-  }
-};
-
-/// Distributes a subgroup-level vector.extract_strided_slice op to
-/// workitem-level. If the result is distributed, the offsets and sizes are
-/// adjusted to match the distributed types.
-struct SgToWiVectorExtractStridedSlice
-    : public OpConversionPattern<vector::ExtractStridedSliceOp> {
-  using OpConversionPattern<vector::ExtractStridedSliceOp>::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(vector::ExtractStridedSliceOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    xegpu::DistributeLayoutAttr resultLayout =
-        xegpu::getTemporaryLayout(op->getOpResult(0));
-    if (!resultLayout || !resultLayout.isForSubgroup())
-      return failure();
-
-    VectorType resultType = op.getType();
-    auto distResultTyOrFailure =
-        xegpu::getDistVecTypeBasedOnLaneLayout(resultLayout, resultType);
-    if (failed(distResultTyOrFailure))
-      return rewriter.notifyMatchFailure(
-          op, "unable to compute distributed vector type from lane layout");
-    VectorType distResultTy = *distResultTyOrFailure;
-
-    SmallVector<int64_t> distributedDims =
-        getDistributedDims(resultType, distResultTy);
-
-    // Collect updated sizes, offsets, strides. Pad to full source rank.
-    int64_t sourceRank = op.getSourceVectorType().getRank();
-    SmallVector<Attribute> updatedSizes =
-        llvm::map_to_vector(op.getSizes(), [](Attribute attr) { return attr; });
-    SmallVector<Attribute> updatedOffsets = llvm::map_to_vector(
-        op.getOffsets(), [](Attribute attr) { return attr; });
-    SmallVector<Attribute> updatedStrides = llvm::map_to_vector(
-        op.getStrides(), [](Attribute attr) { return attr; });
-    for (int64_t i = op.getSizes().size(); i < sourceRank; ++i) {
-      updatedSizes.push_back(
-          rewriter.getI64IntegerAttr(op.getSourceVectorType().getDimSize(i)));
-      updatedOffsets.push_back(rewriter.getI64IntegerAttr(0));
-      updatedStrides.push_back(rewriter.getI64IntegerAttr(1));
-    }
-
-    // If the result is distributed, adjust offsets and sizes in the
-    // distributed dimension.
-    if (!distributedDims.empty()) {
-      if (distributedDims.size() != 1)
-        return rewriter.notifyMatchFailure(
-            op, "only single dimension distribution is supported");
-      int64_t distDim = distributedDims[0];
-      const uArch *uArch = getUArch(xegpu::getChipStr(op).value_or(""));
-      if (!uArch)
-        return rewriter.notifyMatchFailure(
-            op, "target attribute required to determine subgroup size");
-      int subgroupSize = uArch->getSubgroupSize();
-      auto sourceLayout = xegpu::getTemporaryLayout(op->getOpOperand(0));
-      if (!sourceLayout || sourceLayout.getEffectiveLaneLayoutAsInt().empty())
-        return rewriter.notifyMatchFailure(
-            op, "source of extract_strided_slice lacks distribution layout");
-      int sourceDistrDimSize = op.getSourceVectorType().getShape()[distDim];
-      if (sourceDistrDimSize % subgroupSize != 0)
-        return rewriter.notifyMatchFailure(
-            op, "source size along distributed dim is not a multiple of "
-                "subgroup size");
-      auto sourceLaneData = sourceLayout.getEffectiveLaneDataAsInt();
-      // Only check lane_data for the distributed dimension. Non-distributed
-      // dimensions may have non-unit lane_data (e.g., packed layouts).
-      if (distDim < static_cast<int64_t>(sourceLaneData.size()) &&
-          sourceLaneData[distDim] != 1)
-        return rewriter.notifyMatchFailure(
-            op, "expecting unit lane data along the distributed dimension");
-      int64_t distrDimOffset =
-          cast<IntegerAttr>(updatedOffsets[distDim]).getInt();
-      if (distrDimOffset % subgroupSize != 0)
-        return rewriter.notifyMatchFailure(
-            op, "offset along distributed dim is not a multiple of "
-                "subgroup size");
-      // Adjust sizes and offsets for the distributed dimension.
-      updatedSizes[distDim] =
-          rewriter.getI64IntegerAttr(distResultTy.getDimSize(distDim));
-      updatedOffsets[distDim] =
-          rewriter.getI64IntegerAttr(distrDimOffset / subgroupSize);
-    }
-
-    auto newOp = vector::ExtractStridedSliceOp::create(
-        rewriter, op.getLoc(), distResultTy, adaptor.getSource(),
-        ArrayAttr::get(rewriter.getContext(), updatedOffsets),
-        ArrayAttr::get(rewriter.getContext(), updatedSizes),
-        ArrayAttr::get(rewriter.getContext(), updatedStrides));
-    rewriter.replaceOp(op, newOp.getResult());
-    return success();
-  }
-};
-
-/// This pattern distributes a subgroup-level `vector.broadcast` op to
-/// workitem-level. The pattern supports three cases:
-///
-/// 1) Broadcast a low-rank vector to high-rank vector: The low-rank input
-///    vector must have a slice layout of the result. If the distributed source
-///    and target vector types are identical, this lowers to a no-op; otherwise,
-///    it remains a broadcast but operates on distributed vectors.
-///
-/// 2) Broadcast a same-rank vector with identical layouts for source and
-///    target: The source vector must have unit dimensions, and lane_data must
-///    be unit size for those unit dims. This always lowers to a no-op.
-///
-/// 3) Broadcast a scalar with no layout: This always lowers to a broadcast
-///    from scalar to distributed result type.
-///
-/// Example 1 (low-rank to high-rank broadcast):
-/// ```
-///   %0 = "some_op"() {layout_result_0 =
-///     #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
-///     dims = [0]>} : () -> vector<16xf16>
-///   %1 = vector.broadcast %0 {layout_result_0 =
-///     #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
-///     : vector<16xf16> to vector<16x16xf16>
-/// ```
-/// is distributed to:
-/// ```
-///   %0 = "some_op"() : () -> vector<1xf16>
-///   %1 = vector.broadcast %0 : vector<1xf16> to vector<16x1xf16>
-/// ```
-///
-/// Example 2 (same-rank broadcast, no-op):
-/// ```
-///   %0 = "some_op"() {layout_result_0 =
-///     #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
-///     : () -> vector<16x1xf16>
-///   %1 = vector.broadcast %0 {layout_result_0 =
-///     #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
-///     : vector<16x1xf16> to vector<16x16xf16>
-/// ```
-/// is distributed to (no-op, source already matches distributed result type):
-/// ```
-///   %0 = "some_op"() : () -> vector<16x1xf16>
-///   // broadcast is eliminated, %0 is used directly
-/// ```
-///
-/// Example 3 (scalar to vector broadcast):
-/// ```
-///   %0 = "some_op"() : () -> f16
-///   %1 = vector.broadcast %0 {layout_result_0 =
-///     #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
-///     : f16 to vector<16x16xf16>
-/// ```
-/// is distributed to:
-/// ```
-///   %0 = "some_op"() : f16
-///   %1 = vector.broadcast %0 : f16 to vector<16x1xf16>
-/// ```
-struct SgToWiBroadcast : public OpConversionPattern<vector::BroadcastOp> {
-  using OpConversionPattern<vector::BroadcastOp>::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(vector::BroadcastOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    xegpu::DistributeLayoutAttr resultLayout =
-        xegpu::getTemporaryLayout(cast<OpResult>(op.getResult()));
-    if (!resultLayout || !resultLayout.isForSubgroup())
-      return rewriter.notifyMatchFailure(
-          op, "result does not have subgroup distribute layout");
-
-    VectorType destType = op.getResultVectorType();
-    VectorType sourceType = dyn_cast<VectorType>(op.getSourceType());
-
-    xegpu::DistributeLayoutAttr sourceLayout =
-        xegpu::getTemporaryLayout(op->getOpOperand(0));
-
-    if (sourceType) {
-      int64_t rankDiff = destType.getRank() - sourceType.getRank();
-      if (rankDiff > 0) {
-        // Case 1: Low-rank to high-rank broadcast.
-        if (!sourceLayout || !sourceLayout.isSliceOf(resultLayout))
-          op.emitWarning(
-              "broadcast source layout must be a slice of result layout");
-      } else if (rankDiff == 0) {
-        // Case 2: Same-rank broadcast.
-        auto broadcastUnitDimsSet = op.computeBroadcastedUnitDims();
-        SmallVector<int64_t> broadcastUnitDims(broadcastUnitDimsSet.begin(),
-                                               broadcastUnitDimsSet.end());
-        assert(sourceLayout.isEqualTo(
-                   sourceLayout.setUnitDimData(broadcastUnitDims)) &&
-               "The sg_data for unit dimensions should be set as 1");
-        sourceLayout = sourceLayout.setUnitDimLayout(broadcastUnitDims);
-      }
-    } else {
-      // Case 3: Scalar to vector broadcast.
-      if (sourceLayout)
-        return rewriter.notifyMatchFailure(
-            op, "broadcast from scalar must not have a layout attribute");
-    }
-
-    auto destDistType =
-        xegpu::getDistVecTypeBasedOnLaneLayout(resultLayout, destType);
-    if (failed(destDistType))
-      return rewriter.notifyMatchFailure(
-          op, "failed to distribute the result vector type");
-
-    Value source = adaptor.getSource();
-    // If the adapted source already matches the dest dist type, it's a no-op.
-    if (source.getType() == destDistType.value()) {
-      rewriter.replaceOp(op, source);
-      return success();
-    }
-
-    auto newOp = vector::BroadcastOp::create(rewriter, op.getLoc(),
-                                             destDistType.value(), source);
-    rewriter.replaceOp(op, newOp);
-    return success();
-  }
-};
-
-/// Distributes a subgroup-level vector.insert_strided_slice op to
-/// workitem-level. If the dest is distributed, the offsets are adjusted to
-/// match the distributed types.
-struct SgToWiVectorInsertStridedSlice
-    : public OpConversionPattern<vector::InsertStridedSliceOp> {
-  using OpConversionPattern<vector::InsertStridedSliceOp>::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(vector::InsertStridedSliceOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    xegpu::DistributeLayoutAttr resultLayout =
-        xegpu::getTemporaryLayout(op->getOpResult(0));
-    if (!resultLayout || !resultLayout.isForSubgroup())
-      return failure();
-
-    VectorType destType = op.getDestVectorType();
-    auto distDestTyOrFailure =
-        xegpu::getDistVecTypeBasedOnLaneLayout(resultLayout, destType);
-    if (failed(distDestTyOrFailure))
-      return rewriter.notifyMatchFailure(
-          op, "unable to compute distributed vector type from lane layout");
-    VectorType distDestTy = *distDestTyOrFailure;
-
-    SmallVector<int64_t> destDistributedDims =
-        getDistributedDims(destType, distDestTy);
-
-    SmallVector<Attribute> updatedOffsets = llvm::map_to_vector(
-        op.getOffsets(), [](Attribute attr) { return attr; });
-
-    if (!destDistributedDims.empty()) {
-      if (destDistributedDims.size() != 1)
-        return rewriter.notifyMatchFailure(
-            op, "only single dimension distribution is supported");
-      int64_t destDistDim = destDistributedDims[0];
-
-      const uArch *uArch = getUArch(xegpu::getChipStr(op).value_or(""));
-      if (!uArch)
-        return rewriter.notifyMatchFailure(
-            op, "target attribute required to determine subgroup size");
-      int subgroupSize = uArch->getSubgroupSize();
-
-      VectorType srcType = op.getSourceVectorType();
-      // The distributed dim must be in the last k (source rank) dims of dest.
-      int64_t sourceDistDim =
-          destDistDim - (destType.getRank() - srcType.getRank());
-      if (sourceDistDim < 0)
-        return rewriter.notifyMatchFailure(
-            op, "distributed dimension must be in the last k dims of dest");
-
-      auto destLayout = xegpu::getTemporaryLayout(op->getOpOperand(1));
-      auto sourceLayout = xegpu::getTemporaryLayout(op->getOpOperand(0));
-      if (!destLayout || !sourceLayout ||
-          destLayout.getEffectiveLaneLayoutAsInt().empty() ||
-          sourceLayout.getEffectiveLaneLayoutAsInt().empty())
-        return rewriter.notifyMatchFailure(
-            op, "source or dest of insert_strided_slice lacks distribution "
-                "layout");
-
-      auto destLaneData = destLayout.getEffectiveLaneDataAsInt();
-      auto sourceLaneData = sourceLayout.getEffectiveLaneDataAsInt();
-      // Only check lane_data for the distributed dimension. Non-distributed
-      // dimensions may have non-unit lane_data (e.g., packed layouts).
-      if ((destDistDim < static_cast<int64_t>(destLaneData.size()) &&
-           destLaneData[destDistDim] != 1) ||
-          (sourceDistDim < static_cast<int64_t>(sourceLaneData.size()) &&
-           sourceLaneData[sourceDistDim] != 1))
-        return rewriter.notifyMatchFailure(
-            op, "expecting unit lane data along the distributed dimension");
-
-      int64_t srcDistrDimSize = srcType.getDimSize(sourceDistDim);
-      if (srcDistrDimSize % subgroupSize != 0)
-        return rewriter.notifyMatchFailure(
-            op, "source distributed dim size is not a multiple of "
-                "subgroup size");
-
-      int64_t destDistrDimOffset =
-          cast<IntegerAttr>(op.getOffsets()[destDistDim]).getInt();
-      if (destDistrDimOffset % subgroupSize != 0)
-        return rewriter.notifyMatchFailure(
-            op, "offset along distributed dim is not a multiple of "
-                "subgroup size");
-      // Adjust offset for the distributed dimension.
-      updatedOffsets[destDistDim] =
-          rewriter.getI64IntegerAttr(destDistrDimOffset / subgroupSize);
-    }
-
-    auto newOp = vector::InsertStridedSliceOp::create(
-        rewriter, op.getLoc(), distDestTy, adaptor.getValueToStore(),
-        adaptor.getDest(),
-        ArrayAttr::get(rewriter.getContext(), updatedOffsets), op.getStrides());
-    rewriter.replaceOp(op, newOp.getResult());
-    return success();
-  }
-};
-
-/// Distributes a subgroup-level vector.insert op to workitem-level. Only
-/// handles sub-vector insertion (value to store is VectorType, not scalar).
-struct SgToWiVectorInsert : public OpConversionPattern<vector::InsertOp> {
-  using OpConversionPattern<vector::InsertOp>::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(vector::InsertOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    // Only handle vector value-to-store (not scalar insertion).
-    auto valueType = dyn_cast<VectorType>(op.getValueToStoreType());
-    if (!valueType)
-      return rewriter.notifyMatchFailure(op, "scalar insert not supported");
-
-    xegpu::DistributeLayoutAttr layout =
-        xegpu::getTemporaryLayout(op->getOpResult(0));
-    if (!layout || !layout.isForSubgroup())
-      return failure();
-
-    // verify that the outer k dimensions (for offsets)
-    // don't have non-unit lane_layout.
-    auto laneLayout = layout.getEffectiveLaneLayoutAsInt();
-    if (llvm::any_of(ArrayRef<int64_t>(laneLayout).drop_back(1),
-                     [](int64_t v) { return v != 1; }))
-      return rewriter.notifyMatchFailure(
-          op, "only innermost dimension distribution is supported for "
-              "vector.insert");
-
-    auto newOp = vector::InsertOp::create(
-        rewriter, op.getLoc(), adaptor.getValueToStore(), adaptor.getDest(),
-        op.getMixedPosition());
-    rewriter.replaceOp(op, newOp.getResult());
-    return success();
-  }
-};
-
-/// Folds a subgroup-level ConvertLayout op with compatible lane layouts.
-struct SgToWiConvertLayout
-    : public OpConversionPattern<xegpu::ConvertLayoutOp> {
-  using OpConversionPattern<xegpu::ConvertLayoutOp>::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(xegpu::ConvertLayoutOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    auto inputLayout = op.getInputLayoutAttr();
-    auto targetLayout = op.getTargetLayoutAttr();
-    Type valType = op.getResult().getType();
-
-    if (valType.isIntOrFloat()) {
-      rewriter.replaceOp(op, op.getSource());
-      return success();
-    }
-
-    auto resShape = cast<VectorType>(valType).getShape();
-    SmallVector<int64_t> resShapeVec(resShape.begin(), resShape.end());
-    if (!inputLayout.isCompatibleWith(targetLayout, resShapeVec,
-                                      xegpu::LayoutKind::Lane)) {
-      return rewriter.notifyMatchFailure(
-          op, "lowering incompatible convert_layout not yet supported");
-    }
-
-    rewriter.replaceOp(op, adaptor.getSource());
-    return success();
-  }
-};
-
-struct XeGPUSgToWiDistributeExperimentalPass
-    : public xegpu::impl::XeGPUSgToWiDistributeExperimentalBase<
-          XeGPUSgToWiDistributeExperimentalPass> {
-  void runOnOperation() override;
-};
-
-} // namespace
-
-void XeGPUSgToWiDistributeExperimentalPass::runOnOperation() {
-
-  // Recover temporary operand layouts for usage in patterns.
-  Operation *root = getOperation();
-  if (!xegpu::recoverTemporaryLayouts(root)) {
-    signalPassFailure();
-    return;
-  }
-
-  // Collect existing UnrealizedConversionCastOps. These must be preserved.
-  llvm::SmallSetVector<UnrealizedConversionCastOp, 8> existingCasts;
-  root->walk(
-      [&](UnrealizedConversionCastOp castOp) { existingCasts.insert(castOp); });
-  // Perform a structural type conversion to convert structural ops to have WI
-  // types. This will insert UnrealizedConversionCastOps to make the IR
-  // valid.
-  auto materializeCast = [&](mlir::OpBuilder &builder, mlir::Type type,
-                             mlir::ValueRange inputs,
-                             mlir::Location loc) -> mlir::Value {
-    UnrealizedConversionCastOp castOp =
-        UnrealizedConversionCastOp::create(builder, loc, type, inputs);
-    return castOp.getResult(0);
-  };
-  {
-    ConversionTarget target(getContext());
-    TypeConverter typeConverter;
-    RewritePatternSet patterns(&getContext());
-    typeConverter.addSourceMaterialization(materializeCast);
-    typeConverter.addTargetMaterialization(materializeCast);
-    xegpu::populateXeGPUSgToWiDistributeTypeConversions(typeConverter);
-    scf::populateSCFStructuralTypeConversionsAndLegality(typeConverter,
-                                                         patterns, target);
-    xegpu::populateXeGPUSgToWiDistributeTypeConversionAndLegality(
-        typeConverter, patterns, target);
-    target.addLegalOp<UnrealizedConversionCastOp>();
-    (void)applyPartialConversion(root, target, std::move(patterns));
-  }
-  // Structural type conversion can generate some redundant
-  // UnrealizedConversionCastOps to materialize the SG type from type converted
-  // WI type. These are redundant at this point and can be eliminated by
-  // inserting shape casts instead.
-  // Example:
-  // %1 = UnrealizedConversionCastOp %0 : vector<16x1xf32> to vector<16x16xf32>
-  // %2 = UnrealizedConversionCastOp %1 : vector<16x16xf32> to vector<16xf32>
-  // This can be replaced with:
-  // %2 = vector.shape_cast %0 : vector<16x1xf32> to vector<16xf32>
-  OpBuilder builder(root);
-  root->walk([&](UnrealizedConversionCastOp op) {
-    // If this op existed before, nothing to do.
-    if (existingCasts.contains(op))
-      return;
-    // number of inputs and outputs must be 1.
-    if (op.getNumOperands() != 1 || op.getNumResults() != 1)
-      return;
-    // Both input and output types must be vector types.
-    auto singleInput = op.getInputs()[0];
-    auto inputTy = dyn_cast<VectorType>(singleInput.getType());
-    auto outputTy = dyn_cast<VectorType>(op.getResult(0).getType());
-    if (!inputTy || !outputTy)
-      return;
-
-    // Check if the defining op of the input is also an
-    // UnrealizedConversionCastOp and it has a single user (which is this
-    // op).
-    auto definingOp = singleInput.getDefiningOp<UnrealizedConversionCastOp>();
-    if (!definingOp || !definingOp->hasOneUse())
-      return;
-    auto inputOfDefiningOp = definingOp.getInputs()[0];
-    // If the input of the defining op and output type are both vector types
-    // have same number of elements, insert a shape cast.
-    auto inputOfDefiningOpTy =
-        dyn_cast<VectorType>(inputOfDefiningOp.getType());
-    if (inputOfDefiningOpTy &&
-        inputOfDefiningOpTy.getNumElements() == outputTy.getNumElements()) {
-      builder.setInsertionPoint(op);
-      auto shapeCast = vector::ShapeCastOp::create(builder, op.getLoc(),
-                                                   outputTy, inputOfDefiningOp);
-      op.replaceAllUsesWith(ValueRange{shapeCast.getResult()});
-      return;
-    }
-  });
-  // At this point, we will have some dead UnrealizedConversionCastOps. Just
-  // erase them.
-  bool changed = true;
-  while (changed) {
-    changed = false;
-    root->walk([&](UnrealizedConversionCastOp op) {
-      // Skip existing casts.
-      if (existingCasts.contains(op))
-        return;
-      if (op.use_empty()) {
-        op.erase();
-        changed = true;
-      }
-    });
-  }
-}
-
-void xegpu::populateXeGPUSgToWiDistributeTypeConversions(
-    TypeConverter &typeConverter) {
-  // Any type other than TensorDescType and VectorType are legal as is.
-  typeConverter.addConversion([](Type type) -> std::optional<Type> {
-    if (!isa<TensorDescType, VectorType>(type))
-      return type;
-    return std::nullopt;
-  });
-  // For TensorDescType, drop the layout attribute if any.
-  typeConverter.addConversion([](TensorDescType type) -> Type {
-    if (type.getLayoutAttr()) {
-      return type.dropLayouts();
-    }
-    return type;
-  });
-  // For VectorType, check if there is a distribute layout attribute on the
-  // value. If so, convert to the distributed vector type based on the layout.
-  typeConverter.addConversion([](Value v) -> std::optional<Type> {
-    auto type = v.getType();
-    // If value is not vector type, nothing to do.
-    if (!isa<VectorType>(type))
-      return std::nullopt;
-    auto layout = xegpu::getDistributeLayoutAttr(v);
-    if (!layout || !layout.isForSubgroup())
-      return type;
-    // Vector type is distributed based on lane layout.
-    auto newTyOrFailure =
-        getDistVecTypeBasedOnLaneLayout(layout, cast<VectorType>(type));
-    if (failed(newTyOrFailure))
-      return type;
-    return *newTyOrFailure;
-  });
-}
-
-void xegpu::populateXeGPUSgToWiDistributeTypeConversionAndLegality(
-    TypeConverter &typeConverter, RewritePatternSet &patterns,
-    ConversionTarget &target) {
-  populateXeGPUSgToWiDistributeTypeConversions(typeConverter);
-  // CreateNdDescOp is legal only if its result type has no layout attribute.
-  target.addDynamicallyLegalOp<xegpu::CreateNdDescOp>(
-      [&](xegpu::CreateNdDescOp op) { return !op.getType().getLayoutAttr(); });
-  // Any anchor XeGPU op is legal only if it has no anchor layout.
-  target.addDynamicallyLegalDialect<xegpu::XeGPUDialect>([](Operation *op) {
-    auto anchorOp = dyn_cast<AnchorLayoutInterface>(op);
-    if (!anchorOp)
-      return true;
-    return !anchorOp.getAnchorLayout();
-  });
-  // Arith constants are legal only if they have no temporary layout attribute.
-  target.addDynamicallyLegalOp<arith::ConstantOp>(
-      [=](arith::ConstantOp op) -> bool {
-        // If the result type is not a vector, it's legal.
-        if (!isa<VectorType>(op.getResult().getType()))
-          return true;
-        return !xegpu::getTemporaryLayout(dyn_cast<OpResult>(op.getResult()));
-      });
-  // In math and arith dialects, only handle elementwise ops with a single
-  // result and with a result layout attribute.
-  target.addDynamicallyLegalDialect<math::MathDialect, arith::ArithDialect>(
-      [=](Operation *op) -> std::optional<bool> {
-        // Only handle elementwise mappable ops
-        if (!OpTrait::hasElementwiseMappableTraits(op))
-          return true;
-        // Only handle ops with single vector result
-        if (op->getNumResults() != 1)
-          return true;
-
-        VectorType resultType =
-            dyn_cast<VectorType>(op->getResult(0).getType());
-        if (!resultType)
-          return true;
-
-        // Check if all operands are vectors of the same shape
-        for (Value operand : op->getOperands()) {
-          VectorType operandType = dyn_cast<VectorType>(operand.getType());
-          if (!operandType || operandType.getShape() != resultType.getShape()) {
-            return true;
-          }
-        }
-        return !xegpu::getTemporaryLayout(dyn_cast<OpResult>(op->getResult(0)));
-      });
-  // vector::ReductionOp is legal only if its source has no distribute layout
-  // attribute.
-  target.addDynamicallyLegalOp<vector::ReductionOp>(
-      [=](vector::ReductionOp op) -> bool {
-        auto layout = xegpu::getDistributeLayoutAttr(op.getVector());
-        return !layout;
-      });
-  // vector::MultiDimReductionOp op legality.
-  target.addDynamicallyLegalOp<vector::MultiDimReductionOp>(
-      [=](vector::MultiDimReductionOp op) -> bool {
-        return !isValidSubgroupMultiReductionOp(op);
-      });
-  target.addDynamicallyLegalOp<vector::CreateMaskOp, vector::ConstantMaskOp,
-                               vector::TransposeOp, vector::BitCastOp,
-                               vector::ShapeCastOp, vector::StepOp,
-                               vector::BroadcastOp>([=](Operation *op) -> bool {
-    return !xegpu::getTemporaryLayout(op->getOpResult(0));
-  });
-  target.addDynamicallyLegalOp<vector::ExtractOp>(
-      [=](vector::ExtractOp op) -> bool {
-        if (!isa<VectorType>(op.getType()))
-          return true;
-        return !xegpu::getTemporaryLayout(op->getOpResult(0));
-      });
-  target.addDynamicallyLegalOp<vector::InsertOp>(
-      [=](vector::InsertOp op) -> bool {
-        return !xegpu::getTemporaryLayout(op->getOpResult(0));
-      });
-  target.addDynamicallyLegalOp<vector::ExtractStridedSliceOp>(
-      [=](vector::ExtractStridedSliceOp op) -> bool {
-        return !xegpu::getTemporaryLayout(op->getOpResult(0));
-      });
-  target.addDynamicallyLegalOp<vector::InsertStridedSliceOp>(
-      [=](vector::InsertStridedSliceOp op) -> bool {
-        return !xegpu::getTemporaryLayout(op->getOpResult(0));
-      });
-  target.markUnknownOpDynamicallyLegal([](Operation *op) { return true; });
-  patterns.add<SgToWiCreateNdDesc, SgToWiLoadNd, SgToWiStoreNd, SgToWiDpas,
-               SgToWiElementWise, SgToWiArithConstant, SgToWiPrefetchNd,
-               SgToWiLoadGather, SgToWiStoreScatter, SgToWiVectorReduction,
-               SgToWiMultiDimReduction, SgToWiVectorExtract, SgToWiVectorInsert,
-               SgToWiVectorExtractStridedSlice, SgToWiVectorInsertStridedSlice,
-               SgToWiLoadMatrix, SgToWiStoreMatrix, SgToWiConvertLayout,
-               SgToWiVectorTranspose, SgToWiVectorBitcast, SgToWiVectorStep,
-               SgToWiVectorShapeCast, SgToWiBroadcast,
-               SgToWiCreateMask<vector::CreateMaskOp>,
-               SgToWiCreateMask<vector::ConstantMaskOp>>(typeConverter,
-                                                         patterns.getContext());
-}

>From 6845ef13d5ccfe86004bbe8c00f5e5b9446e9526 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Mon, 13 Apr 2026 23:16:51 +0000
Subject: [PATCH 17/21] remove isEvenlyDistributable

---
 mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp    | 76 -------------------
 mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp        | 11 ---
 .../Transforms/XeGPUWgToSgDistribute.cpp      |  3 -
 3 files changed, 90 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index eaa43c02946d8..cba29b1a926d0 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -121,74 +121,6 @@ static SmallVector<SmallVector<int64_t>> genStaticCoordinates(
   return coordinates;
 }
 
-// Checks if the given shape can be evenly distributed based on the layout
-// and data factors provided by the LayoutAttr.
-bool XeGPUDialect::isEvenlyDistributable(llvm::ArrayRef<int64_t> shape,
-                                         xegpu::DistributeLayoutAttr attr) {
-  assert(attr && "Layout attribute is missing.");
-
-  // Checks whether the given shape can be evenly distributed using the
-  // specified layout and data attributes. If successful, it returns the work
-  // size for each compute unit; otherwise, it returns `std::nullopt`. The work
-  // size per compute unit is calculated as follows:
-  //   - If `data` is null: newShape[i] = shape[i] / layout[i]
-  //   - If `data` is not null: newShape[i] = data[i]
-  // When round-robin distribution (`rr`) is enabled, `shape[i]` can be
-  // smaller than `layout[i] * data[i]`, allowing multiple compute units to
-  // share the data.
-  auto tryDistribute = [&](llvm::ArrayRef<int64_t> shape,
-                           SmallVector<int64_t> layout,
-                           SmallVector<int64_t> data,
-                           bool rr = true) -> optional<SmallVector<int64_t>> {
-    llvm::SmallVector<int64_t> newShape(shape);
-    if (layout.size()) {
-      if (layout.size() != shape.size())
-        return std::nullopt;
-      auto ratio = computeShapeRatio(shape, layout);
-      if (ratio.has_value()) {
-        newShape = ratio.value();
-      } else if (!rr || !computeShapeRatio(layout, shape).has_value()) {
-        return std::nullopt;
-      }
-      // Round-robin case: continue with original newShape
-    }
-
-    if (data.size()) {
-      if (data.size() != shape.size())
-        return std::nullopt;
-      auto ratio = computeShapeRatio(newShape, data);
-      if (!ratio.has_value() && rr)
-        ratio = computeShapeRatio(data, newShape);
-      if (!ratio.has_value())
-        return std::nullopt;
-
-      // if data is not null, we always return it for next phase.
-      newShape = data;
-    }
-    return newShape;
-  };
-
-  // check the sgLayout and sgData
-  auto maybeSgShape = tryDistribute(shape, attr.getEffectiveSgLayoutAsInt(),
-                                    attr.getEffectiveSgDataAsInt());
-  if (!maybeSgShape)
-    return false;
-  auto sgShape = maybeSgShape.value();
-
-  // check InstData, it neither have layout nor need round-robin
-  auto maybeInstShape =
-      tryDistribute(sgShape, {}, attr.getEffectiveInstDataAsInt(), false);
-  if (!maybeInstShape)
-    return false;
-  auto instShape = maybeInstShape.value();
-
-  // check LaneLayout and LaneData
-  auto maybeLaneShape =
-      tryDistribute(instShape, attr.getEffectiveLaneLayoutAsInt(),
-                    attr.getEffectiveLaneDataAsInt());
-  return maybeLaneShape.has_value();
-}
-
 //===----------------------------------------------------------------------===//
 // XeGPU_BlockTensorDescAttr
 //===----------------------------------------------------------------------===//
@@ -1448,14 +1380,6 @@ TensorDescType::verify(llvm::function_ref<InFlightDiagnostic()> emitError,
                << "expected last dim of lane_data to be a multiple of: "
                << chunkAlignmentFactor;
     }
-
-    if (!XeGPUDialect::isEvenlyDistributable(shape, layoutAttr)) {
-      std::string shapeStr;
-      llvm::raw_string_ostream stream(shapeStr);
-      llvm::interleaveComma(shape, stream);
-      return emitError() << "cannot distribute [" << shapeStr << "] using "
-                         << layoutAttr;
-    }
   }
   return success();
 }
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index d6981052a7a5c..d3fdbb81a1cbd 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -1101,17 +1101,6 @@ LogicalResult ConvertLayoutOp::verify() {
     return emitOpError("expected input layout and target layout be WgLayout or "
                        "SgLayout at the same time.");
 
-  // Type srcType = getSource().getType();
-  // if (llvm::isa<VectorType>(srcType)) {
-  //   auto shape = llvm::cast<VectorType>(srcType).getShape();
-  //   if (!XeGPUDialect::isEvenlyDistributable(shape, srcLayout))
-  //     return emitOpError(
-  //         "invalid input layout, data cannot be evenly distributed.");
-
-  //   if (!XeGPUDialect::isEvenlyDistributable(shape, resLayout))
-  //     return emitOpError(
-  //         "invalid target layout, data cannot be evenly distributed.");
-  // }
   return mlir::success();
 }
 
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index dabdcb61f0500..57188fae5e8c8 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -498,9 +498,6 @@ struct WgToSgVectorBroadcastOp
     VectorType newResultType =
         VectorType::get(sgShape, resultType.getElementType());
 
-    if (!xegpu::XeGPUDialect::isEvenlyDistributable(wgShape, layout))
-      return failure();
-
     SmallVector<Value> newBroadcastOps;
     auto distSource = adaptor.getOperands().front();
     int numDistributions = count / distSource.size();

>From 0c7917e246f5f78f38b1872ad97e1cab4fea3e47 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Mon, 13 Apr 2026 23:17:41 +0000
Subject: [PATCH 18/21] remove isEvenlyDistributable

---
 mlir/include/mlir/Dialect/XeGPU/IR/XeGPUDialect.td | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUDialect.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUDialect.td
index c173b93face98..84fd8f9e0060c 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUDialect.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUDialect.td
@@ -38,10 +38,6 @@ def XeGPU_Dialect : Dialect {
     let useDefaultAttributePrinterParser = true;
 
     let extraClassDeclaration = [{
-      /// Checks if the given shape can be evenly distributed based on the layout
-      /// and data factors provided by the LayoutAttr.
-      static bool isEvenlyDistributable(llvm::ArrayRef<int64_t> shape, xegpu::DistributeLayoutAttr attr);
-
       /// drops/slices the shape in the specified dims, and return the rest. e.g.,
       /// for shape = [32, 64, 8], dims = [0, 2], it will return [64]
       template<typename T, typename U>

>From b688a2f15c140d2538cd8c40c97ae1a875500bbb Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Thu, 16 Apr 2026 18:23:40 +0000
Subject: [PATCH 19/21] refactor adding removeTemporaryLayoutAttrs to utility

---
 .../mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h |  5 +++++
 .../Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp    | 12 ++++++++++++
 .../XeGPU/Transforms/XeGPUPeepHoleOptimizer.cpp     | 11 +----------
 .../XeGPU/Transforms/XeGPUPropagateLayout.cpp       | 13 +++----------
 .../XeGPUSgToWiDistributeExperimental.cpp           | 11 ++---------
 .../XeGPU/Transforms/XeGPUSubgroupDistribute.cpp    | 11 ++---------
 .../XeGPU/Transforms/XeGPUWgToSgDistribute.cpp      | 10 +---------
 7 files changed, 26 insertions(+), 47 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h b/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h
index a7ca51bc4fdf7..83eb939cf1bec 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h
@@ -61,6 +61,11 @@ void removeLayoutAttr(const T &operandOrResult);
 /// applied recursively to the contained operations
 void removeLayoutAttrs(Operation *op);
 
+/// Removes the temporary layout attributes for each OpOperand and OpResult of
+/// the given operation. Recursive for contained operations if the given
+/// operation contains regions.
+void removeTemporaryLayoutAttrs(Operation *op);
+
 /// Updates the NamedAttribute sequence by dropping sg-layout and
 /// sg-data information from any DistributeLayoutAttr found.
 SmallVector<NamedAttribute>
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
index 8f319bd161798..0dd9348bc06c2 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
@@ -306,6 +306,18 @@ void xegpu::removeLayoutAttrs(Operation *op) {
   });
 }
 
+void xegpu::removeTemporaryLayoutAttrs(Operation *op) {
+  op->walk([&](Operation *nestOp) { // runs for op and each nested op
+    SmallVector<StringAttr> attrsToRemove; // collect first, erase below
+    for (auto namedAttr : nestOp->getDiscardableAttrs()) {
+      if (isa<xegpu::DistributeLayoutAttr>(namedAttr.getValue()))
+        attrsToRemove.push_back(namedAttr.getName());
+    }
+    for (auto attrName : attrsToRemove) // second pass does the removal
+      nestOp->removeDiscardableAttr(attrName);
+  });
+}
+
 /// Infers the source layout attribute for a broadcast operation given the
 /// result layout attribute, result shape, source shape.
 xegpu::DistributeLayoutAttr
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPeepHoleOptimizer.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPeepHoleOptimizer.cpp
index 3496756e8a6d3..8ade936724480 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPeepHoleOptimizer.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPeepHoleOptimizer.cpp
@@ -598,16 +598,7 @@ struct XeGPUPeepHoleOptimizerPass final
     RewritePatternSet emptyPatterns(ctx);
     (void)applyPatternsGreedily(getOperation(), std::move(emptyPatterns));
 
-    // Remove the temporary layout after all patterns are applied.
-    getOperation()->walk([](Operation *op) {
-      SmallVector<StringAttr> attrsToRemove;
-      for (auto namedAttr : op->getDiscardableAttrs()) {
-        if (isa<xegpu::DistributeLayoutAttr>(namedAttr.getValue()))
-          attrsToRemove.push_back(namedAttr.getName());
-      }
-      for (auto attrName : attrsToRemove)
-        op->removeDiscardableAttr(attrName);
-    });
+    xegpu::removeTemporaryLayoutAttrs(getOperation());
   }
 };
 
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index ff9ff4937c293..630a314b1ce40 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -1694,16 +1694,9 @@ LogicalResult xegpu::resolveLayoutConflicts(Operation *target) {
 }
 
 void XeGPUPropagateLayoutPass::runOnOperation() {
-  // Clean up temporary layout attributes
-  getOperation()->walk([](Operation *op) {
-    SmallVector<StringAttr> attrsToRemove;
-    for (auto namedAttr : op->getDiscardableAttrs()) {
-      if (isa<xegpu::DistributeLayoutAttr>(namedAttr.getValue()))
-        attrsToRemove.push_back(namedAttr.getName());
-    }
-    for (auto attrName : attrsToRemove)
-      op->removeDiscardableAttr(attrName);
-  });
+
+  xegpu::removeTemporaryLayoutAttrs(getOperation());
+
   xegpu::LayoutKind layoutKind;
   if (this->layoutKind == "lane") {
     layoutKind = xegpu::LayoutKind::Lane;
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
index 9cb9d34401216..c153db431c035 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
@@ -1609,15 +1609,8 @@ void XeGPUSgToWiDistributeExperimentalPass::runOnOperation() {
       }
     });
   }
-  getOperation()->walk([](Operation *op) {
-    SmallVector<StringAttr> attrsToRemove;
-    for (auto namedAttr : op->getDiscardableAttrs()) {
-      if (isa<xegpu::DistributeLayoutAttr>(namedAttr.getValue()))
-        attrsToRemove.push_back(namedAttr.getName());
-    }
-    for (auto attrName : attrsToRemove)
-      op->removeDiscardableAttr(attrName);
-  });
+
+  xegpu::removeTemporaryLayoutAttrs(getOperation());
 }
 
 void xegpu::populateXeGPUSgToWiDistributeTypeConversions(
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index c6c48515fcf0c..1dab7d9808756 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -2280,13 +2280,6 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
       op->erase();
     return WalkResult::advance();
   });
-  getOperation()->walk([](Operation *op) {
-    SmallVector<StringAttr> attrsToRemove;
-    for (auto namedAttr : op->getDiscardableAttrs()) {
-      if (isa<xegpu::DistributeLayoutAttr>(namedAttr.getValue()))
-        attrsToRemove.push_back(namedAttr.getName());
-    }
-    for (auto attrName : attrsToRemove)
-      op->removeDiscardableAttr(attrName);
-  });
+
+  xegpu::removeTemporaryLayoutAttrs(getOperation());
 }
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 57188fae5e8c8..3fadc593849bf 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -1781,13 +1781,5 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
           applyPartialConversion(getOperation(), target, std::move(patterns))))
     return signalPassFailure();
 
-  getOperation()->walk([](Operation *op) {
-    SmallVector<StringAttr> attrsToRemove;
-    for (auto namedAttr : op->getDiscardableAttrs()) {
-      if (isa<xegpu::DistributeLayoutAttr>(namedAttr.getValue()))
-        attrsToRemove.push_back(namedAttr.getName());
-    }
-    for (auto attrName : attrsToRemove)
-      op->removeDiscardableAttr(attrName);
-  });
+  xegpu::removeTemporaryLayoutAttrs(getOperation());
 }

>From e58cc8695df8f5adc0ab774cc48cbd0298c95bac Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Tue, 21 Apr 2026 00:21:26 +0000
Subject: [PATCH 20/21] address feedback

---
 .../XeGPU/Transforms/XeGPULayoutImpl.cpp      | 167 +++++++++++-------
 1 file changed, 107 insertions(+), 60 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
index 0dd9348bc06c2..616fee1db2107 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
@@ -26,6 +26,7 @@
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/LoopLikeInterface.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/PostOrderIterator.h"
 #include "llvm/Support/FormatVariadic.h"
 #include <cstdint>
 #include <numeric>
@@ -82,41 +83,12 @@ xegpu::dropInstDataOnAttrs(ArrayRef<NamedAttribute> attrs) {
   return out;
 }
 
-// Prerequisite for Layout Recovery
-// It relies on the following invariant:
-// 1. there is no layout conflict between different uses of the same definition.
-// 2. each definition has a well-defined layout requirement at its use point.
-//     - Every definition must have at least one use that appears after it in
-//     topological order.
-//     - If a definition has no such use (e.g., a loop result or region output),
-//     an explicit convert_layout operation is inserted to create a use.
-//     - Only the result of convert_layout is permitted to have no subsequent
-//     use.
-
-// The recovery proceeds by scanning the operation in reverse topological order
-// as follows:
-//    For regular operations: First the result layouts are propagated from uses.
-//      Then the result layouts are propagated to operands.
-//
-//    For region operations (e.g., loops):
-//       - When backward propagation reaches a region op, it sets the layout of
-//       the region op’s results according to use points like regular ops.
-//       - Then, the result layouts (such as a loop output) are propagated to
-//       their corresponding operands in the yield.
-//       - When backward propagation reaches the first operation inside the
-//       region, the pass examines the region op’s initialization list,
-//       propagating from region arguments to the corresponding initialization
-//       operands.
-//       - This ensures that layouts are consistently propagated
-//       across region boundaries while preserving a single well-defined use for
-//       each definition at the region-op level.
-
-// the inner function for recoverTemporaryLayouts is a recursive function
+// walkRegionBackward() is a recursive helper used by recoverTemporaryLayouts.
 // the input rootOp is the function operation, which is also a region op.
-// it recursivley process the region op in reverse topological order.
-
+// it recursively processes the region op in reverse topological order.
 static void walkRegionBackward(Region &region,
                                llvm::function_ref<void(Operation *)> visit) {
+
   // blocks: back -> front
   for (Block &block : llvm::reverse(region)) {
     // ops: back -> front, early-inc so visit() may erase current op safely
@@ -129,6 +101,42 @@ static void walkRegionBackward(Region &region,
       visit(&op);
     }
   }
+  // // Process blocks in post-order (reverse topological order: dominated
+  // before dominators)
+  // // For single-block regions, this is just that block
+  // // For control flow, this ensures dominated blocks are processed before
+  // dominators
+
+  // // Step 1: Get reverse post-order traversal
+  // llvm::ReversePostOrderTraversal<Region *> rpot(&region);
+  // // Step 2: Collect into vector
+  // llvm::SmallVector<Block *> blocks(rpot.begin(), rpot.end());
+  // // Step 3: Reverse it to get post-order (reverse topological order)
+  // for (Block *block : llvm::reverse(blocks)) {
+  //   // ops: back -> front, with early-inc so visit() may erase current op
+  //   safely
+  //   // We need to collect operations first because nested region walks might
+  //   modify the block llvm::SmallVector<Operation *> ops; for (Operation &op :
+  //   llvm::reverse(*block))
+  //     ops.push_back(&op);
+
+  //   for (Operation *op : ops) {
+  //     // Check if op is still alive (might have been erased by nested walk)
+  //     if (!op->isRegistered())
+  //       continue;
+
+  //     // make sure we first visit inside the region op (so yield op first)
+  //     // and then move to region op itself
+  //     // Note: Region iteration order doesn't affect correctness since each
+  //     // region is processed in reverse topological order independently
+  //     for (Region &nested : op->getRegions())
+  //       walkRegionBackward(nested, visit);
+
+  //     // Check again if op is still alive before visiting
+  //     if (op->isRegistered())
+  //       visit(op);  // Can safely erase op now
+  //   }
+  // }
 }
 
 static xegpu::DistributeLayoutAttr getLayoutFromUsePoints(Value result) {
@@ -137,6 +145,7 @@ static xegpu::DistributeLayoutAttr getLayoutFromUsePoints(Value result) {
     if (auto tmpLayout = xegpu::getDistributeLayoutAttr(use)) {
       if (!layout)
         layout = tmpLayout;
+      break;
     }
   }
   return layout;
@@ -145,7 +154,7 @@ static xegpu::DistributeLayoutAttr getLayoutFromUsePoints(Value result) {
 // For regular operations: First the result layouts are propagated from uses.
 // Then the result layouts are propagated to uses (operands).
 static void propagateResultsToRegularOperands(Operation *op) {
-  if (op->getNumResults() == 0)
+  if (op->getNumResults() == 0 || op->getNumResults() > 1)
     return;
 
   OpResult result = op->getResult(0);
@@ -157,9 +166,6 @@ static void propagateResultsToRegularOperands(Operation *op) {
   // For vector type, we attach the layout as an attribute to op.
   if (auto tensorDescTy = dyn_cast<xegpu::TensorDescType>(resultType)) {
     auto layout = tensorDescTy.getLayoutAttr();
-    // TODO: remove the layout check. The tensorDescType's layout is treated as
-    // temporary layout, which needs to be set by layout recovery.
-    // allow it now to pass some legacy test case
     if (!layout) {
       auto typeWithLayout = xegpu::TensorDescType::get(
           tensorDescTy.getContext(), tensorDescTy.getShape(),
@@ -167,25 +173,21 @@ static void propagateResultsToRegularOperands(Operation *op) {
       result.setType(typeWithLayout);
     }
   }
-
-  xegpu::setTemporaryLayout(result, resLayout);
+  if (isa<VectorType>(resultType) && resLayout)
+    xegpu::setTemporaryLayout(result, resLayout);
 
   for (OpOperand &opr : op->getOpOperands()) {
-    // Layouts are needed for vector type only.
     xegpu::DistributeLayoutAttr operandLayout =
         xegpu::inferSourceLayoutFromResult(opr, resLayout);
-    if (!isa<VectorType>(opr.get().getType()))
-      continue;
-
-    xegpu::setTemporaryLayout(opr, operandLayout);
+    if (isa<VectorType>(opr.get().getType()) && operandLayout)
+      xegpu::setTemporaryLayout(opr, operandLayout);
   }
 }
 
+// propagate layout from region results to yield operands. This set the
+// temproary layout for reguion results and yield operands.
 static void propagateRegionResultsToYieldOperands(
     mlir::RegionBranchTerminatorOpInterface yieldOp) {
-  if (isa<func::FuncOp>(yieldOp->getParentOp()))
-    return;
-
   auto regionBranchOp =
       dyn_cast<RegionBranchOpInterface>(yieldOp->getParentOp());
   if (!regionBranchOp)
@@ -205,10 +207,22 @@ static void propagateRegionResultsToYieldOperands(
       xegpu::setTemporaryLayout(result, resultLayouts[i]);
   }
 
-  // Use getSuccessorOperands to find which operands of the terminator
-  // flow to a successor. This handles index offsets automatically (e.g.,
-  // scf.condition's predicate at operand #0 is excluded).
-  // Pick the first successor to determine the operand range.
+  // We are interested in the parent successor, i.e., the branch that exits
+  // the region and forwards operands to the parent op's results. This handles
+  // index offsets automatically (e.g., scf.condition's predicate at operand #0
+  // is excluded).
+  // SmallVector<RegionSuccessor> successors;
+  // SmallVector<Attribute> operandAttrs(yieldOp->getNumOperands(), nullptr);
+  // yieldOp.getSuccessorRegions(operandAttrs, successors);
+  // auto *parentSuccessor = llvm::find_if(
+  //     successors, [](const RegionSuccessor &s) { return s.isParent(); });
+  // assert(parentSuccessor != successors.end() &&
+  //        "terminator must have the parent op as a successor");
+
+  // OperandRange succOps = yieldOp.getSuccessorOperands(*parentSuccessor);
+  // unsigned beginIdx = succOps.getBeginOperandIndex();
+  // unsigned count = std::min(static_cast<unsigned>(succOps.size()), numResults);
+
   SmallVector<RegionSuccessor> successors;
   SmallVector<Attribute> operandAttrs(yieldOp->getNumOperands(), nullptr);
   yieldOp.getSuccessorRegions(operandAttrs, successors);
@@ -226,6 +240,10 @@ static void propagateRegionResultsToYieldOperands(
   }
 }
 
+// propagate layout from region arguments to region op's init operands. This set
+// the temproary layout for region arguments and init operands.
+// For while op containing multipel regions, different segements of init
+// operands might mapped to diferent region arguments.
 static void propagateRegionArgsToInits(mlir::RegionBranchOpInterface regionOp) {
   // Iterate all regions of the region op. For each block argument that has a
   // layout (determined from its use points), trace back to find the
@@ -253,6 +271,35 @@ static void propagateRegionArgsToInits(mlir::RegionBranchOpInterface regionOp) {
   }
 }
 
+// Prerequisites for layout recovery.
+// Recovery relies on the following invariants:
+// 1. there is no layout conflict between different uses of the same definition.
+// 2. each definition has a well-defined layout requirement at its use point.
+//     - Every definition must have at least one use that appears after it in
+//     topological order.
+//     - TODO: If a definition has no such use (e.g., a loop result or region
+//     output), an explicit convert_layout operation is inserted to create a
+//     use.
+//     - Only the result of convert_layout is permitted to have no subsequent
+//     use.
+//
+// The recovery proceeds by scanning the operation in reverse topological order
+// as follows:
+//    For regular operations: First the result layouts are propagated from uses.
+//      Then the result layouts are propagated to operands.
+//
+//    For region operations (e.g., loops):
+//       - When backward propagation reaches a region op, it sets the layout of
+//       the region op’s results according to use points like regular ops.
+//       - Then, the result layouts (such as a loop output) are propagated to
+//       their corresponding operands in the yield.
+//       - When backward propagation reaches the first operation inside the
+//       region, the pass examines the region op’s initialization list,
+//       propagating from region arguments to the corresponding initialization
+//       operands.
+//       - This ensures that layouts are consistently propagated
+//       across region boundaries while preserving a single well-defined use for
+//       each definition at the region-op level.
 bool xegpu::recoverTemporaryLayouts(Operation *rootOp) {
   auto processFunc = [&](Region &body, StringRef funcName) {
     walkRegionBackward(body, [&](Operation *op) {
@@ -1309,7 +1356,7 @@ xegpu::DistributeLayoutAttr
 xegpu::inferSourceLayoutFromResult(OpOperand &operand,
                                    xegpu::DistributeLayoutAttr resLayout) {
   if (!resLayout)
-    return xegpu::DistributeLayoutAttr();
+    return nullptr;
   Operation *op = operand.getOwner();
   unsigned idx = operand.getOperandNumber();
 
@@ -1317,7 +1364,7 @@ xegpu::inferSourceLayoutFromResult(OpOperand &operand,
   if (auto broadcast = dyn_cast<vector::BroadcastOp>(op)) {
     auto srcTy = dyn_cast<VectorType>(broadcast.getSourceType());
     if (!srcTy)
-      return xegpu::DistributeLayoutAttr();
+      return nullptr;
     return xegpu::inferBroadcastSourceLayout(
         resLayout, broadcast.getResultVectorType().getShape(),
         srcTy.getShape());
@@ -1376,15 +1423,15 @@ xegpu::inferSourceLayoutFromResult(OpOperand &operand,
                                              transpose.getPermutation());
   }
 
-  if (isa<VectorType>(operand.get().getType()) &&
-      !dyn_cast<xegpu::AnchorLayoutInterface>(op)) {
-    // For elementwise operations, all operands must have the same layout as the
-    // result.
-    // if (OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() ==
-    // 1) {
+  // For vector::ExtractStridedSliceOp, simply return result layout
+  if (dyn_cast<vector::ExtractStridedSliceOp>(op))
     return resLayout;
-  }
-  return xegpu::DistributeLayoutAttr();
+  // For elementwise operations, all operands must have the same layout as the
+  // result.
+  if (OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() == 1)
+    return resLayout;
+
+  return nullptr;
 }
 
 xegpu::DistributeLayoutAttr xegpu::getConsumerLayoutAt(OpOperand &operand) {

>From d46b90044e2cf77a7e14012194c802659717149a Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Tue, 21 Apr 2026 03:23:14 +0000
Subject: [PATCH 21/21] adding tests and fix issues

---
 .../XeGPU/Transforms/XeGPULayoutImpl.cpp      | 156 ++++++++----------
 .../Dialect/XeGPU/xegpu-recover-layout.mlir   | 150 +++++++++++++++++
 .../lib/Dialect/XeGPU/TestXeGPUTransforms.cpp |  33 ++++
 3 files changed, 251 insertions(+), 88 deletions(-)
 create mode 100644 mlir/test/Dialect/XeGPU/xegpu-recover-layout.mlir

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
index 616fee1db2107..7d48315eec6ff 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
@@ -89,54 +89,29 @@ xegpu::dropInstDataOnAttrs(ArrayRef<NamedAttribute> attrs) {
 static void walkRegionBackward(Region &region,
                                llvm::function_ref<void(Operation *)> visit) {
 
-  // blocks: back -> front
-  for (Block &block : llvm::reverse(region)) {
-    // ops: back -> front, early-inc so visit() may erase current op safely
-    for (Operation &op : llvm::reverse(block)) {
+  // Use post-order traversal to process blocks in reverse topological order.
+  // This ensures that use blocks are visited before def blocks, which is
+  // required for backward layout propagation.
+  if (region.empty())
+    return;
+  llvm::ReversePostOrderTraversal<Region *> rpot(&region);
+  SmallVector<Block *> blocks(rpot.begin(), rpot.end());
+  for (Block *block : llvm::reverse(blocks)) {
+    // ops: back -> front
+    for (Operation &op : llvm::reverse(*block)) {
       // make sure we first visit inside the region op (so yield op first)
       // and then move to region op itself
-      for (Region &nested : llvm::reverse(op.getRegions()))
+      // Regions are iterated in forward order so that for multi-region ops
+      // like scf.while, earlier regions (e.g., "before/cond") are processed
+      // first. This ensures that when a later region's terminator (e.g., "do"
+      // yield) needs the layout of an earlier region's block args, those
+      // layouts are already available from use points.
+      for (Region &nested : op.getRegions())
         walkRegionBackward(nested, visit);
 
       visit(&op);
     }
   }
-  // // Process blocks in post-order (reverse topological order: dominated
-  // before dominators)
-  // // For single-block regions, this is just that block
-  // // For control flow, this ensures dominated blocks are processed before
-  // dominators
-
-  // // Step 1: Get reverse post-order traversal
-  // llvm::ReversePostOrderTraversal<Region *> rpot(&region);
-  // // Step 2: Collect into vector
-  // llvm::SmallVector<Block *> blocks(rpot.begin(), rpot.end());
-  // // Step 3: Reverse it to get post-order (reverse topological order)
-  // for (Block *block : llvm::reverse(blocks)) {
-  //   // ops: back -> front, with early-inc so visit() may erase current op
-  //   safely
-  //   // We need to collect operations first because nested region walks might
-  //   modify the block llvm::SmallVector<Operation *> ops; for (Operation &op :
-  //   llvm::reverse(*block))
-  //     ops.push_back(&op);
-
-  //   for (Operation *op : ops) {
-  //     // Check if op is still alive (might have been erased by nested walk)
-  //     if (!op->isRegistered())
-  //       continue;
-
-  //     // make sure we first visit inside the region op (so yield op first)
-  //     // and then move to region op itself
-  //     // Note: Region iteration order doesn't affect correctness since each
-  //     // region is processed in reverse topological order independently
-  //     for (Region &nested : op->getRegions())
-  //       walkRegionBackward(nested, visit);
-
-  //     // Check again if op is still alive before visiting
-  //     if (op->isRegistered())
-  //       visit(op);  // Can safely erase op now
-  //   }
-  // }
 }
 
 static xegpu::DistributeLayoutAttr getLayoutFromUsePoints(Value result) {
@@ -184,8 +159,12 @@ static void propagateResultsToRegularOperands(Operation *op) {
   }
 }
 
-// propagate layout from region results to yield operands. This set the
-// temproary layout for reguion results and yield operands.
+// Propagate layout from region op results and sibling region block args
+// to yield/condition operands. For each successor of this terminator:
+// - Parent successor: propagate from parent op's result layouts (use points).
+// - Region successor: propagate from target region's block arg layouts (use
+//   points), e.g., scf.yield in "after/do" region propagates to "before/cond"
+//   block args.
 static void propagateRegionResultsToYieldOperands(
     mlir::RegionBranchTerminatorOpInterface yieldOp) {
   auto regionBranchOp =
@@ -193,57 +172,42 @@ static void propagateRegionResultsToYieldOperands(
   if (!regionBranchOp)
     return;
 
-  // Gather layouts for each result of the parent region op from external
-  // use points.
-  unsigned numResults = regionBranchOp->getNumResults();
-  if (numResults == 0)
-    return;
-
-  SmallVector<xegpu::DistributeLayoutAttr> resultLayouts(numResults, nullptr);
-  for (unsigned i = 0; i < numResults; ++i) {
-    OpResult result = regionBranchOp->getResult(i);
-    resultLayouts[i] = getLayoutFromUsePoints(result);
-    if (resultLayouts[i])
-      xegpu::setTemporaryLayout(result, resultLayouts[i]);
-  }
-
-  // We are interested in the parent successor, i.e., the branch that exits
-  // the region and forwards operands to the parent op's results. This handles
-  // index offsets automatically (e.g., scf.condition's predicate at operand #0
-  // is excluded).
-  // SmallVector<RegionSuccessor> successors;
-  // SmallVector<Attribute> operandAttrs(yieldOp->getNumOperands(), nullptr);
-  // yieldOp.getSuccessorRegions(operandAttrs, successors);
-  // auto *parentSuccessor = llvm::find_if(
-  //     successors, [](const RegionSuccessor &s) { return s.isParent(); });
-  // assert(parentSuccessor != successors.end() &&
-  //        "terminator must have the parent op as a successor");
-
-  // OperandRange succOps = yieldOp.getSuccessorOperands(*parentSuccessor);
-  // unsigned beginIdx = succOps.getBeginOperandIndex();
-  // unsigned count = std::min(static_cast<unsigned>(succOps.size()), numResults);
-
   SmallVector<RegionSuccessor> successors;
   SmallVector<Attribute> operandAttrs(yieldOp->getNumOperands(), nullptr);
   yieldOp.getSuccessorRegions(operandAttrs, successors);
-  assert(!successors.empty() && "terminator must have at least one successor");
-
-  OperandRange succOps = yieldOp.getSuccessorOperands(successors.front());
-  unsigned beginIdx = succOps.getBeginOperandIndex();
-  unsigned count = std::min(static_cast<unsigned>(succOps.size()), numResults);
 
-  for (unsigned i = 0; i < count; ++i) {
-    if (!resultLayouts[i])
+  for (const RegionSuccessor &successor : successors) {
+    OperandRange succOps = yieldOp.getSuccessorOperands(successor);
+    if (succOps.empty())
       continue;
-    xegpu::setTemporaryLayout(yieldOp->getOpOperand(beginIdx + i),
-                              resultLayouts[i]);
+    unsigned beginIdx = succOps.getBeginOperandIndex();
+    ValueRange successorInputs = regionBranchOp.getSuccessorInputs(successor);
+    unsigned count = std::min<unsigned>(succOps.size(), successorInputs.size());
+
+    for (unsigned i = 0; i < count; ++i) {
+      xegpu::DistributeLayoutAttr layout;
+      if (successor.isParent()) {
+        // For parent successor, get layout from external use points of the
+        // parent op's results.
+        layout = getLayoutFromUsePoints(regionBranchOp->getResult(i));
+        if (layout)
+          xegpu::setTemporaryLayout(regionBranchOp->getResult(i), layout);
+      } else {
+        // For region successor, get layout from the target region's block
+        // arg use points (e.g., "before/cond" region args for scf.while
+        // "after/do" yield).
+        layout = getLayoutFromUsePoints(successorInputs[i]);
+      }
+      if (!layout)
+        continue;
+      if (isa<VectorType>(succOps[i].getType()))
+        xegpu::setTemporaryLayout(yieldOp->getOpOperand(beginIdx + i), layout);
+    }
   }
 }
 
-// propagate layout from region arguments to region op's init operands. This set
-// the temproary layout for region arguments and init operands.
-// For while op containing multipel regions, different segements of init
-// operands might mapped to diferent region arguments.
+// Propagate layout from region arguments to region op's init operands. This
+// sets the temporary layout for region arguments and init operands.
 static void propagateRegionArgsToInits(mlir::RegionBranchOpInterface regionOp) {
   // Iterate all regions of the region op. For each block argument that has a
   // layout (determined from its use points), trace back to find the
@@ -252,14 +216,30 @@ static void propagateRegionArgsToInits(mlir::RegionBranchOpInterface regionOp) {
   // RegionBranchOpInterface ops.
   for (Region &region : regionOp->getRegions()) {
     RegionSuccessor regionSuccessor(&region);
-    for (auto [argIdx, regionArg] : llvm::enumerate(region.getArguments())) {
+    // Use getSuccessorInputs to get the block arguments that correspond to
+    // predecessor operands. This correctly handles ops like scf.for where
+    // the induction variable is a block arg but not a successor input.
+    ValueRange successorInputs = regionOp.getSuccessorInputs(regionSuccessor);
+    for (auto [inputIdx, regionArg] : llvm::enumerate(successorInputs)) {
       auto layout = getLayoutFromUsePoints(regionArg);
       if (!layout)
         continue;
 
+      // Recover layout for tensor_desc block args by updating the type.
+      if (auto tensorDescTy =
+              dyn_cast<xegpu::TensorDescType>(regionArg.getType())) {
+        if (!tensorDescTy.getLayoutAttr()) {
+          auto typeWithLayout = xegpu::TensorDescType::get(
+              tensorDescTy.getContext(), tensorDescTy.getShape(),
+              tensorDescTy.getElementType(), tensorDescTy.getEncoding(),
+              layout);
+          regionArg.setType(typeWithLayout);
+        }
+      }
+
       // Find all predecessor values that flow into this block argument.
       SmallVector<Value> predValues;
-      regionOp.getPredecessorValues(regionSuccessor, argIdx - 1, predValues);
+      regionOp.getPredecessorValues(regionSuccessor, inputIdx, predValues);
       for (Value predVal : predValues) {
         // Match predecessor value to an operand of the regionOp.
         for (OpOperand &operand : regionOp->getOpOperands()) {
diff --git a/mlir/test/Dialect/XeGPU/xegpu-recover-layout.mlir b/mlir/test/Dialect/XeGPU/xegpu-recover-layout.mlir
new file mode 100644
index 0000000000000..a00d6d7bb3b14
--- /dev/null
+++ b/mlir/test/Dialect/XeGPU/xegpu-recover-layout.mlir
@@ -0,0 +1,150 @@
+// RUN: mlir-opt -test-xegpu-recover-temporary-layouts -split-input-file %s | FileCheck %s
+
+// -----
+// Test scf.for: Recovery should propagate layout from the store_nd consumer
+// of the loop result back to the scf.for result, scf.yield operands, and
+// the arith.constant init value. Tensor desc types start without layouts
+// and only anchor ops (load_nd, store_nd, dpas) carry layout attributes.
+
+gpu.module @test_for {
+// CHECK-LABEL: gpu.func @for_basic
+gpu.func @for_basic(%arg0: memref<8x128xf16>, %arg1: memref<128x16xf16>, %arg2: memref<8x16xf32>) {
+  %c0 = arith.constant 0 : index
+  %c128 = arith.constant 128 : index
+  %c16 = arith.constant 16 : index
+  // CHECK: xegpu.create_nd_tdesc
+  // CHECK-SAME: -> !xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+  %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x128xf16>
+      -> !xegpu.tensor_desc<8x16xf16>
+  // CHECK: xegpu.create_nd_tdesc
+  // CHECK-SAME: -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>>
+  %1 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<128x16xf16>
+      -> !xegpu.tensor_desc<16x16xf16>
+  // Recovery propagates the accumulator layout (dpas layout_cd / store_nd) back to arith.constant.
+  // CHECK: arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+  // CHECK-SAME: dense<0.000000e+00> : vector<8x16xf32>
+  %cst = arith.constant dense<0.000000e+00> : vector<8x16xf32>
+  // CHECK: scf.for
+  %2 = scf.for %arg3 = %c0 to %c128 step %c16
+      iter_args(%arg6 = %cst) -> (vector<8x16xf32>) {
+    %4 = xegpu.load_nd %0 {layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+        : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+    %5 = xegpu.load_nd %1 {layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>}
+        : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+    %6 = xegpu.dpas %4, %5, %arg6
+        {layout_a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+         layout_b = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>,
+         layout_cd = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+        : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+    // Recovery propagates layout to scf.yield vector operand.
+    // CHECK: scf.yield {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+    scf.yield %6
+        : vector<8x16xf32>
+  // Recovery sets layout_result_0 on the scf.for for the vector result.
+  // CHECK: layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+  }
+  // CHECK: xegpu.create_nd_tdesc
+  // CHECK-SAME: -> !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+  %3 = xegpu.create_nd_tdesc %arg2[%c0, %c0] : memref<8x16xf32>
+      -> !xegpu.tensor_desc<8x16xf32>
+  xegpu.store_nd %2, %3 {layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+      : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+  gpu.return
+}
+}
+
+// -----
+// Test scf.while: Recovery should propagate layout from the store_nd consumer
+// of the while result back through scf.condition (which branches to parent).
+// The scf.yield in the "do" region branches to the "before" region (not parent),
+// so sibling-region propagation, not propagateRegionResultsToYieldOperands, covers it.
+
+gpu.module @test_while {
+// CHECK-LABEL: gpu.func @while_basic
+gpu.func @while_basic(%arg0: memref<1024xf32>, %arg1: memref<1024xf32>) {
+  %c1_i32 = arith.constant 1 : i32
+  %c10_i32 = arith.constant 10 : i32
+  %c0_i32 = arith.constant 0 : i32
+  // CHECK: xegpu.create_nd_tdesc
+  // CHECK-SAME: -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [16], sg_data = [16]>>
+  %0 = xegpu.create_nd_tdesc %arg0[0] : memref<1024xf32>
+      -> !xegpu.tensor_desc<256xf32>
+  %1 = xegpu.load_nd %0 {layout = #xegpu.layout<sg_layout = [16], sg_data = [16]>}
+      : !xegpu.tensor_desc<256xf32> -> vector<256xf32>
+  // CHECK: xegpu.create_nd_tdesc
+  // CHECK-SAME: -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [16], sg_data = [16]>>
+  %2 = xegpu.create_nd_tdesc %arg1[0] : memref<1024xf32>
+      -> !xegpu.tensor_desc<256xf32>
+
+  // CHECK: scf.while
+  %3:2 = scf.while (%arg2 = %1, %arg3 = %c0_i32)
+      : (vector<256xf32>, i32) -> (vector<256xf32>, i32) {
+    %4 = arith.cmpi slt, %arg3, %c10_i32 : i32
+    // Recovery propagates layout to scf.condition vector operand.
+    // CHECK: scf.condition
+    // CHECK-SAME: {layout_operand_1 = #xegpu.layout<sg_layout = [16], sg_data = [16]>}
+    scf.condition(%4) %arg2, %arg3 : vector<256xf32>, i32
+  } do {
+  ^bb0(%arg2: vector<256xf32>, %arg3: i32):
+    xegpu.store_nd %arg2, %2 {layout = #xegpu.layout<sg_layout = [16], sg_data = [16]>}
+        : vector<256xf32>, !xegpu.tensor_desc<256xf32>
+    %4 = arith.addi %arg3, %c1_i32 : i32
+    %5 = xegpu.update_nd_offset %0, [256]
+        : !xegpu.tensor_desc<256xf32>
+    %6 = xegpu.load_nd %5 {layout = #xegpu.layout<sg_layout = [16], sg_data = [16]>}
+        : !xegpu.tensor_desc<256xf32> -> vector<256xf32>
+    // Recovery propagates layout to scf.yield in the "do" region via
+    // sibling region propagation (from "before" region arg back to "do" yield).
+    // CHECK: scf.yield {layout_operand_0 = #xegpu.layout<sg_layout = [16], sg_data = [16]>}
+    scf.yield %6, %4 : vector<256xf32>, i32
+  // Recovery sets layout_operand_0 and layout_result_0 on the scf.while for the vector init/result.
+  // CHECK: } attributes {layout_operand_0 = #xegpu.layout<sg_layout = [16], sg_data = [16]>,
+  // CHECK-SAME: layout_result_0 = #xegpu.layout<sg_layout = [16], sg_data = [16]>}
+  }
+  xegpu.store_nd %3#0, %2 {layout = #xegpu.layout<sg_layout = [16], sg_data = [16]>}
+      : vector<256xf32>, !xegpu.tensor_desc<256xf32>
+  gpu.return
+}
+}
+
+// -----
+// Test scf.if: Recovery should propagate layout from the dpas consumer of the
+// if result back to the scf.if result and both yield operands.
+
+gpu.module @test_if {
+// CHECK-LABEL: gpu.func @if_basic
+gpu.func @if_basic(
+    %arg0: !xegpu.tensor_desc<8x16xf16>,
+    %arg1: !xegpu.tensor_desc<16x16xf16>,
+    %arg2: i1,
+    %arg3: !xegpu.tensor_desc<8x16xf32>) {
+  %0 = xegpu.load_nd %arg0 {layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+      : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+  // CHECK: scf.if
+  %1 = scf.if %arg2 -> (vector<16x16xf16>) {
+    %3 = xegpu.load_nd %arg1 {layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>}
+        : !xegpu.tensor_desc<16x16xf16>
+        -> vector<16x16xf16>
+    // Recovery propagates layout to scf.yield operand in "then" region.
+    // CHECK: scf.yield {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>}
+    scf.yield %3 : vector<16x16xf16>
+  } else {
+    %3 = xegpu.load_nd %arg1 {layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>}
+        : !xegpu.tensor_desc<16x16xf16>
+        -> vector<16x16xf16>
+    // Recovery propagates layout to scf.yield operand in "else" region.
+    // CHECK: scf.yield {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>}
+    scf.yield %3 : vector<16x16xf16>
+  // Recovery sets layout_result_0 on the scf.if for the vector result.
+  // CHECK: } {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>}
+  }
+  %2 = xegpu.dpas %0, %1
+      {layout_a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+       layout_b = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>,
+       layout_cd = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+      : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+  xegpu.store_nd %2, %arg3 {layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+      : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+  gpu.return
+}
+}
diff --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
index 4760016bdcea4..e0c21b76e722d 100644
--- a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
+++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
@@ -225,6 +225,38 @@ class TestStepOpPattern : public OpConversionPattern<vector::StepOp> {
   }
 };
 
+struct TestXeGPURecoverTemporaryLayouts
+    : public PassWrapper<TestXeGPURecoverTemporaryLayouts,
+                         OperationPass<gpu::GPUModuleOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestXeGPURecoverTemporaryLayouts)
+
+  StringRef getArgument() const final {
+    return "test-xegpu-recover-temporary-layouts";
+  }
+
+  StringRef getDescription() const final {
+    return "Test the implementation of XeGPU temporary layout recovery";
+  }
+
+  void getDependentDialects(::mlir::DialectRegistry &registry) const override {
+    registry.insert<arith::ArithDialect>();
+    registry.insert<memref::MemRefDialect>();
+    registry.insert<xegpu::XeGPUDialect>();
+    registry.insert<vector::VectorDialect>();
+    registry.insert<gpu::GPUDialect>();
+  }
+
+  TestXeGPURecoverTemporaryLayouts() = default;
+  TestXeGPURecoverTemporaryLayouts(const TestXeGPURecoverTemporaryLayouts &pass)
+      : PassWrapper(pass) {}
+
+  void runOnOperation() override {
+    Operation *op = getOperation();
+    if (!xegpu::recoverTemporaryLayouts(op))
+      signalPassFailure();
+  }
+};
+
 struct TestXeGPUSGDistribute
     : public PassWrapper<TestXeGPUSGDistribute,
                          OperationPass<gpu::GPUModuleOp>> {
@@ -479,6 +511,7 @@ namespace test {
 void registerTestXeGPULowerings() {
   PassRegistration<TestXeGPUUnrollingPatterns>();
   PassRegistration<TestXeGPULayoutInterface>();
+  PassRegistration<TestXeGPURecoverTemporaryLayouts>();
   PassRegistration<TestXeGPUSGDistribute>();
   PassRegistration<TestXeGPUSgToWiDistributeExperimental>();
   PassRegistration<TestXeGPUMoveFuncBodyToWarpOp>();



More information about the Mlir-commits mailing list