[Mlir-commits] [mlir] [MLIR][XeGPU] Refactor layout propagation utilities (PR #179016)

Jianhui Li llvmlistbot at llvm.org
Tue Feb 3 21:11:49 PST 2026


https://github.com/Jianhui-Li updated https://github.com/llvm/llvm-project/pull/179016

>From 20d44c2ed8d34a6f90426f0b5984732c1912ecf9 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Thu, 18 Dec 2025 17:46:38 +0000
Subject: [PATCH 01/35] add layout utilities interface

---
 .../Dialect/XeGPU/Utils/XeGPULayoutUtils.h    | 131 +++++++
 .../mlir/Dialect/XeGPU/Utils/XeGPUUtils.h     |  64 +---
 .../XeGPU/Transforms/XeGPUPropagateLayout.cpp |  41 +-
 mlir/lib/Dialect/XeGPU/Utils/CMakeLists.txt   |   1 +
 .../Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp  | 350 ++++++++++++++++++
 mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp   | 304 ---------------
 6 files changed, 504 insertions(+), 387 deletions(-)
 create mode 100644 mlir/include/mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h
 create mode 100644 mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp

diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h
new file mode 100644
index 0000000000000..6f528214f538b
--- /dev/null
+++ b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h
@@ -0,0 +1,131 @@
+//===- XeGPULayoutUtils.h - XeGPU layout utilities --------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+// Declares helpers to query, attach, remove and infer DistributeLayoutAttrs.
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_XEGPU_UTILS_XEGPULAYOUTUTILS_H_
+#define MLIR_DIALECT_XEGPU_UTILS_XEGPULAYOUTUTILS_H_
+
+#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/OpDefinition.h"
+namespace mlir {
+
+class VectorType;
+class OpOperand;
+class OpResult;
+class OpBuilder;
+class ValueRange;
+class TypeConverter;
+class OpFoldResult;
+
+namespace xegpu {
+class DistributeLayoutAttr;
+class LayoutAttr;
+class TensorDescType;
+} // namespace xegpu
+
+namespace xegpu {
+
+/// Name of the temporary layout attribute of an operand: "layout_operand_<N>".
+std::string getTemporaryLayoutName(const OpOperand &operand);
+
+/// Name of the temporary layout attribute of a result: "layout_result_<N>".
+std::string getTemporaryLayoutName(const OpResult result);
+
+/// Retrieves the DistributeLayoutAttr associated with a given Value. For
+/// TensorDescType values, the DistributeLayoutAttr is extracted from the
+/// TensorDescType itself. For op results, it comes from the defining op; for
+/// loop block arguments, from the tied loop init. Returns nullptr if no
+/// DistributeLayoutAttr is found.
+DistributeLayoutAttr getDistributeLayoutAttr(const Value value);
+
+/// Retrieves the DistributeLayoutAttr associated with a given OpOperand. It
+/// first checks the owner's anchor layout, then its layout_operand_<id> attr.
+/// If neither is found, it falls back to the layout of the operand's value.
+DistributeLayoutAttr getDistributeLayoutAttr(const OpOperand &opr);
+
+/// [to-be-deprecated] Sets the DistributeLayoutAttr for a given OpResult.
+/// Use setAnchorLayout or setTemporaryLayout instead.
+void setDistributeLayoutAttr(const OpResult &Result,
+                             const DistributeLayoutAttr layout);
+
+/// [to-be-deprecated] Sets the DistributeLayoutAttr for a given OpOperand.
+/// Use setAnchorLayout or setTemporaryLayout instead.
+void setDistributeLayoutAttr(const OpOperand &opr,
+                             const DistributeLayoutAttr layout);
+
+/// Get and set the temporary layout attribute for non-anchor operations
+/// (and offsets/masks of load/store ops before we get rid of their temp attrs)
+template <typename T,
+          typename = std::enable_if_t<std::is_same_v<T, OpOperand> ||
+                                      std::is_same_v<T, OpResult>>>
+DistributeLayoutAttr getTemporaryLayout(const T &operandOrResult);
+
+template <typename T,
+          typename = std::enable_if_t<std::is_same_v<T, OpOperand> ||
+                                      std::is_same_v<T, OpResult>>>
+void setTemporaryLayout(const T &operandOrResult,
+                        const DistributeLayoutAttr layout);
+
+/// [to-be-deprecated] Set the DistributeLayoutAttr for each OpOperand and
+/// OpResult of the given operation. If the operation contains regions, it is
+/// also applied recursively to the contained operations.
+/// TODO: To be replaced by recoverTemporaryLayouts()
+void recoverTemporaryLayoutsDeprecated(Operation *op);
+
+/// Attach layout attributes to all vector-type operands of operations within
+/// the given operation's region. Emits an error and returns false if any
+/// vector operand lacks a layout attribute; returns true otherwise.
+bool recoverTemporaryLayouts(Operation *rootOp);
+
+/// Removes the temporary layout attribute of an OpOperand or OpResult, if any.
+template <typename T,
+          typename = std::enable_if_t<std::is_same_v<T, OpOperand> ||
+                                      std::is_same_v<T, OpResult>>>
+void removeLayoutAttr(const T &operandOrResult);
+
+/// Removes the DistributeLayoutAttr for each OpOperand and OpResult of the
+/// given operation if they exist. If the operation contains regions, it is also
+/// applied recursively to the contained operations.
+void removeLayoutAttrs(Operation *op);
+
+/// Infers the source layout attribute for a broadcast operation given the
+/// result layout attribute, result shape, and source shape.
+DistributeLayoutAttr inferBroadCastSourceLayout(MLIRContext *context,
+                                                DistributeLayoutAttr resLayout,
+                                                ArrayRef<int64_t> resShape,
+                                                ArrayRef<int64_t> srcShape);
+
+/// Infers the source layout attribute for a reduction operation given the
+/// result layout attribute, result shape, source shape, and reduced dims.
+DistributeLayoutAttr
+inferReductionSourceLayout(MLIRContext *context, DistributeLayoutAttr resLayout,
+                           ArrayRef<int64_t> resShape,
+                           ArrayRef<int64_t> srcShape,
+                           SmallVector<int64_t> reduceDims);
+
+/// Infers the source layout attribute for a bitcast operation given the
+/// result layout attribute, result element type bitwidth, and source element
+/// type bitwidth.
+DistributeLayoutAttr inferBitCastSourceLayout(MLIRContext *context,
+                                              DistributeLayoutAttr resLayout,
+                                              int resElemTyBitWidth,
+                                              int srcElemTyBitWidth);
+
+/// Infers the source layout attribute for a shape cast operation given the
+/// result layout attribute, result shape, and source shape.
+DistributeLayoutAttr inferShapeCastSourceLayout(MLIRContext *context,
+                                                DistributeLayoutAttr resLayout,
+                                                ArrayRef<int64_t> resShape,
+                                                ArrayRef<int64_t> srcShape);
+
+} // namespace xegpu
+
+} // namespace mlir
+
+#endif // MLIR_DIALECT_XEGPU_UTILS_XEGPULAYOUTUTILS_H_
diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
index 46d52516cbc15..3dbbe7e4c5dff 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
@@ -9,6 +9,7 @@
 #ifndef MLIR_DIALECT_XEGPU_UTILS_XEGPUUTILS_H_
 #define MLIR_DIALECT_XEGPU_UTILS_XEGPUUTILS_H_
 
+#include "XeGPULayoutUtils.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/OpDefinition.h"
@@ -118,69 +119,6 @@ template <typename T>
 int getLargestDivisor(T dim, ArrayRef<T> candidates,
                       ArrayRef<T> candidateMultiples = {});
 
-/// Return the attribute name for the OpOperand to attach DistributeLayoutAttr
-std::string getTemporaryLayoutName(const OpOperand &operand);
-
-/// Return the attribute name for the OpResult to attach DistributeLayoutAttr
-std::string getTemporaryLayoutName(const OpResult result);
-
-/// Retrieves the DistributeLayoutAttr associated with a given Value. For
-/// TensorDescType values, the DistributeLayoutAttr is extracted from the
-/// TensorDescType itself. For other values, it is obtained from the attributes
-/// of the defining operation. Returns nullptr if no DistributeLayoutAttr is
-/// found.
-DistributeLayoutAttr getDistributeLayoutAttr(const Value value);
-
-/// Retrieves the DistributeLayoutAttr associated with a given OpOperand. It
-/// will first check the operand_layout_{id} of the owner operation. If not
-/// found, it will check the operand itself and its defining op.
-DistributeLayoutAttr getDistributeLayoutAttr(const OpOperand &opr);
-
-/// Removes the LayoutAttr for a given OpOperand or OpResult if it exists.
-template <typename T,
-          typename = std::enable_if_t<std::is_same_v<T, OpOperand> ||
-                                      std::is_same_v<T, OpResult>>>
-void removeLayoutAttr(const T &operandOrResult);
-
-/// Removes the DistributeLayoutAttr for each OpOperand and OpResult of the
-/// given operation if they exist. If the operation contains regions, it is also
-/// applied recursively to the contained operations
-void removeLayoutAttrs(Operation *op);
-
-/// [to-be-deprecated] Sets the DistributeLayoutAttr for a given OpResult
-/// user should use setAnchorLayout instead
-void setDistributeLayoutAttr(const OpResult &Result,
-                             const DistributeLayoutAttr layout);
-
-/// [to-be-deprecated] Sets the DistributeLayoutAttr for a given OpOperand
-/// user should use setAnchorLayout instead
-void setDistributeLayoutAttr(const OpOperand &opr,
-                             const DistributeLayoutAttr layout);
-
-/// get and set distribute layout attribute for non-anchor operations
-/// (and offsets/masks of load/store ops before we get rid of their temp attrs)
-template <typename T,
-          typename = std::enable_if_t<std::is_same_v<T, OpOperand> ||
-                                      std::is_same_v<T, OpResult>>>
-DistributeLayoutAttr getTemporaryLayout(const T &operandOrResult);
-
-template <typename T,
-          typename = std::enable_if_t<std::is_same_v<T, OpOperand> ||
-                                      std::is_same_v<T, OpResult>>>
-void setTemporaryLayout(const T &operandOrResult,
-                        const DistributeLayoutAttr layout);
-
-/// [to-be-deprecated] Set the DistributeLayoutAttr for each OpOperand and
-/// OpResult of of the given operation. If the operation contains regions, it is
-/// also applied recursively to the contained operations operation.
-/// TODO: To be replaced by recoverTemporaryLayouts()
-void recoverTemporaryLayoutsDeprecated(Operation *op);
-
-/// Attach layout attributes to all vector-type operands of operations within
-/// the given operation's region. Reports an error if any vector operand lacks
-/// a layout attribute.
-bool recoverTemporaryLayouts(Operation *rootOp);
-
 } // namespace xegpu
 
 } // namespace mlir
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 7fc75e7294ea3..60dedf373a07b 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -601,8 +601,21 @@ void LayoutInfoPropagation::visitVectorMultiReductionOp(
   LayoutInfo resultLayout = results[0]->getValue();
   if (!resultLayout.isAssigned())
     return;
-  // We only consider 2D -> 1D reductions at this point.
+
   VectorType resultTy = llvm::dyn_cast<VectorType>(reduction.getDestType());
+  VectorType sourceTy =
+      llvm::dyn_cast<VectorType>(reduction.getSourceVectorType());
+  SmallVector<int64_t> reductionDims(reduction.getReductionDims().begin(),
+                                     reduction.getReductionDims().end());
+  // xegpu::DistributeLayoutAttr operandLayout =
+  // xegpu::inferReductionSourceLayout(
+  //       reduction.getContext(),
+  //       dyn_cast<xegpu::DistributeLayoutAttr>(resultLayout.get()),
+  //       resultTy.getShape(),
+  //       sourceTy.getShape(),
+  //       reductionDims);
+  // We only consider 2D -> 1D reductions at this point.
+
   if (!resultTy || resultTy.getRank() != 1) {
     reduction.emitWarning("Expecting output type to be 1D vector.");
     return;
@@ -633,25 +646,13 @@ void LayoutInfoPropagation::visitVectorBroadCastOp(
 
   // Hanlding broadcast from low-rank to high-rank (e.g., 1D to 2D) case.
   if (sourceTy.getRank() != resultTy.getRank()) {
-    auto sourceDims = sourceTy.getShape();
-    auto resultDims = resultTy.getShape();
-    SmallVector<int64_t> bcastDims;
-    auto dimDiff = resultTy.getRank() - sourceTy.getRank();
-    // adding the missing leading dims
-    for (int i = 0; i < dimDiff; i++)
-      bcastDims.push_back(i);
-
-    // for the rest dims in the resultTy, if sourceTy dim is 1, then it's
-    // broadcasted dim
-    for (size_t i = 0; i < sourceDims.size(); i++)
-      if ((sourceDims[i] == 1) && (resultDims[i + dimDiff] != 1))
-        bcastDims.push_back(i + dimDiff);
-
-    // create a slice layout for the source
-    xegpu::SliceAttr sliceLayout = xegpu::SliceAttr::get(
-        broadcast->getContext(),
-        cast<xegpu::DistributeLayoutAttr>(resultLayout.get()),
-        DenseI64ArrayAttr::get(broadcast->getContext(), bcastDims));
+    auto srcShape = sourceTy.getShape();
+    auto resShape = resultTy.getShape();
+    auto resultLayoutAttr =
+        dyn_cast<xegpu::DistributeLayoutAttr>(resultLayout.get());
+
+    xegpu::DistributeLayoutAttr sliceLayout = xegpu::inferBroadCastSourceLayout(
+        broadcast.getContext(), resultLayoutAttr, resShape, srcShape);
 
     propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(sliceLayout)));
     return;
diff --git a/mlir/lib/Dialect/XeGPU/Utils/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/Utils/CMakeLists.txt
index d9bf4a1461c27..bde8324aab5fb 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/CMakeLists.txt
+++ b/mlir/lib/Dialect/XeGPU/Utils/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_dialect_library(MLIRXeGPUUtils
   XeGPUUtils.cpp
+  XeGPULayoutUtils.cpp  
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/XeGPU/Utils
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp
new file mode 100644
index 0000000000000..2f68aa7bda924
--- /dev/null
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp
@@ -0,0 +1,350 @@
+//===- XeGPULayoutUtils.cpp - XeGPU layout utilities ----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements helpers for XeGPU distribute layout attributes.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
+#include "mlir/Dialect/SCF/Transforms/Patterns.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/ValueRange.h"
+#include "mlir/Interfaces/LoopLikeInterface.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/Support/FormatVariadic.h"
+#include <cstdint>
+#include <numeric>
+
+using namespace mlir;
+
+std::string xegpu::getTemporaryLayoutName(const OpOperand &operand) {
+  const StringRef prefix("layout_operand_");
+  unsigned idx = const_cast<OpOperand &>(operand).getOperandNumber();
+  return llvm::formatv("{0}{1}", prefix, idx).str(); // "layout_operand_<idx>"
+}
+
+std::string xegpu::getTemporaryLayoutName(const OpResult result) {
+  const StringRef prefix = "layout_result_"; // followed by the result number
+  return llvm::formatv("{0}{1}", prefix, result.getResultNumber()).str();
+}
+
+xegpu::DistributeLayoutAttr xegpu::getDistributeLayoutAttr(const Value value) {
+  if (!value)
+    return nullptr;
+
+  if (auto tdescTy =
+          dyn_cast_if_present<xegpu::TensorDescType>(value.getType()))
+    return tdescTy.getLayoutAttr(); // the layout is part of the tdesc type
+
+  if (auto result = dyn_cast<OpResult>(value)) {
+    Operation *defOp = result.getDefiningOp();
+    assert(defOp && "result must have a defining op");
+
+    if (auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(defOp)) {
+      auto layout = anchorOp.getAnchorLayout(); // checked before temp attrs
+      return layout;
+    }
+
+    std::string layoutName = getTemporaryLayoutName(result); // temp attr
+    if (defOp->hasAttr(layoutName)) {
+      auto layout =
+          defOp->getAttrOfType<xegpu::DistributeLayoutAttr>(layoutName);
+      return layout;
+    }
+  }
+
+  if (auto arg = dyn_cast<BlockArgument>(value)) {
+    auto *parentOp = arg.getOwner()->getParentOp();
+    if (auto loop = dyn_cast<LoopLikeOpInterface>(parentOp)) {
+      OpOperand *tiedInit = loop.getTiedLoopInit(arg);
+      if (tiedInit)
+        return getDistributeLayoutAttr(tiedInit->get()); // use the loop init
+    }
+  }
+
+  return nullptr; // no layout found
+}
+xegpu::DistributeLayoutAttr
+xegpu::getDistributeLayoutAttr(const OpOperand &opr) {
+  Operation *op = opr.getOwner();
+  unsigned idx = const_cast<OpOperand &>(opr).getOperandNumber();
+
+  if (auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(op)) {
+    if (auto dpasOp = dyn_cast<xegpu::DpasOp>(op)) {
+      if (idx == 0) {
+        return dpasOp.getLayoutAAttr();
+      } else if (idx == 1) {
+        return dpasOp.getLayoutBAttr();
+      } else if (idx == 2) {
+        return dpasOp.getLayoutCdAttr();
+      }
+    }
+    if (auto convertOp = dyn_cast<xegpu::ConvertLayoutOp>(op)) {
+      return convertOp.getInputLayoutAttr(); // for any operand index
+    }
+    auto layout = anchorOp.getAnchorLayout();
+
+    if (idx == 0)
+      return layout;
+
+    // Operand 0 was handled above. Store ops (StoreScatterOp, StoreNdOp,
+    // StoreMatrixOp) also use the anchor layout for operand 1 (memref/tdesc);
+    // remaining operands fall through to the temporary attr / value lookup.
+    if (isa<xegpu::StoreScatterOp, xegpu::StoreNdOp, xegpu::StoreMatrixOp>(
+            op) &&
+        (idx < 2))
+      return layout;
+  }
+
+  std::string layoutName = xegpu::getTemporaryLayoutName(opr);
+  if (op->hasAttr(layoutName)) {
+    auto layout = op->getAttrOfType<xegpu::DistributeLayoutAttr>(layoutName);
+    return layout;
+  }
+
+  auto layout = getDistributeLayoutAttr(opr.get()); // fall back to the value
+  return layout;
+}
+
+// TODO-LayoutRefactor: Remove this function after replacing its uses
+//  with setTemporaryLayout or setAnchorLayout.
+void xegpu::setDistributeLayoutAttr(
+    const mlir::OpResult &result,
+    const mlir::xegpu::DistributeLayoutAttr layout) {
+  Operation *owner = result.getOwner();
+
+  if (auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(owner)) {
+    if (anchorOp.getAnchorLayout() == layout)
+      return;
+    anchorOp.setAnchorLayout(layout); // also runs for a null layout
+    return;
+  }
+
+  std::string name = xegpu::getTemporaryLayoutName(result);
+  if (owner->hasAttrOfType<DistributeLayoutAttr>(name)) {
+    return; // an existing temporary layout is never overwritten
+  }
+  if (layout) {
+    owner->setAttr(name, layout);
+  }
+}
+
+// TODO-LayoutRefactor: Remove this function after replacing its uses
+//  with setTemporaryLayout or setAnchorLayout.
+void xegpu::setDistributeLayoutAttr(const OpOperand &operand,
+                                    const DistributeLayoutAttr layout) {
+  Operation *owner = operand.getOwner();
+  unsigned idx = const_cast<OpOperand &>(operand).getOperandNumber();
+
+  if (!layout) {
+    return;
+  }
+  if (auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(owner)) {
+    if (auto dpasOp = dyn_cast<xegpu::DpasOp>(owner)) {
+      if (idx == 0) {
+        return dpasOp.setLayoutAAttr(layout);
+      } else if (idx == 1) {
+        return dpasOp.setLayoutBAttr(layout);
+      } else if (idx == 2) {
+        return dpasOp.setLayoutCdAttr(layout);
+      }
+    }
+    if (auto convertOp = dyn_cast<xegpu::ConvertLayoutOp>(owner)) {
+      return convertOp.setInputLayoutAttr(layout); // for any operand index
+    }
+
+    // Store ops (StoreScatterOp, StoreNdOp, StoreMatrixOp): operands 0-1
+    // (value, memref/tdesc) get the anchor layout; other ops: operand 0 only.
+    // NOTE(review): execution falls through and also sets the temp attr below.
+    if (isa<xegpu::StoreScatterOp, xegpu::StoreNdOp, xegpu::StoreMatrixOp>(
+            owner)) {
+      if (idx < 2) {
+        anchorOp.setAnchorLayout(layout);
+      }
+    } else {
+      if (idx == 0) {
+        anchorOp.setAnchorLayout(layout);
+      }
+    }
+  }
+
+  std::string name = xegpu::getTemporaryLayoutName(operand);
+  if (owner->hasAttrOfType<DistributeLayoutAttr>(name)) {
+    return;
+  }
+  if (layout) { // always true here: a null layout returned early
+    owner->setAttr(name, layout);
+  }
+}
+
+template <typename T, typename>
+xegpu::DistributeLayoutAttr
+xegpu::getTemporaryLayout(const T &operandOrResult) {
+  Operation *op = operandOrResult.getOwner();
+
+  std::string layoutName = xegpu::getTemporaryLayoutName(operandOrResult);
+  if (op->hasAttr(layoutName)) {
+    auto layout = op->getAttrOfType<xegpu::DistributeLayoutAttr>(layoutName);
+    return layout;
+  }
+
+  return nullptr; // anchor layouts are not consulted here
+}
+
+template xegpu::DistributeLayoutAttr
+xegpu::getTemporaryLayout<mlir::OpResult>(const OpResult &result);
+template xegpu::DistributeLayoutAttr
+xegpu::getTemporaryLayout<mlir::OpOperand>(const OpOperand &operand);
+
+template <typename T, typename>
+void xegpu::setTemporaryLayout(const T &operandOrResult,
+                               const xegpu::DistributeLayoutAttr layout) {
+  Operation *owner = operandOrResult.getOwner();
+  std::string name = xegpu::getTemporaryLayoutName(operandOrResult);
+  if (owner->hasAttrOfType<xegpu::DistributeLayoutAttr>(name)) {
+    return; // an existing temporary layout is never overwritten
+  }
+  if (layout) { // a null layout is a no-op
+    owner->setAttr(name, layout);
+  }
+}
+
+template void xegpu::setTemporaryLayout<mlir::OpResult>(
+    const mlir::OpResult &result,
+    const mlir::xegpu::DistributeLayoutAttr layout);
+
+template void xegpu::setTemporaryLayout<mlir::OpOperand>(
+    const mlir::OpOperand &operand,
+    const mlir::xegpu::DistributeLayoutAttr layout);
+
+void xegpu::recoverTemporaryLayoutsDeprecated(Operation *op) {
+  op->walk([&](Operation *nestOp) {
+    for (OpOperand &opr : nestOp->getOpOperands()) {
+      auto layout = getDistributeLayoutAttr(opr.get()); // from the value
+      setDistributeLayoutAttr(opr, layout);
+    }
+
+    for (OpResult result : nestOp->getOpResults()) {
+      auto layout = getDistributeLayoutAttr(result);
+      setDistributeLayoutAttr(result, layout);
+    }
+  });
+}
+
+/// For every op visited by `rootOp->walk`, attaches to each vector-typed
+/// operand the layout of the value it uses. On the first such operand with no
+/// layout, emits an error and returns false; otherwise returns true.
+bool xegpu::recoverTemporaryLayouts(Operation *rootOp) {
+  auto result = rootOp->walk([&](Operation *op) {
+    for (OpOperand &operand : op->getOpOperands()) {
+      // Layouts are needed for vector type only.
+      if (!isa<VectorType>(operand.get().getType()))
+        continue;
+      auto layout = xegpu::getDistributeLayoutAttr(operand.get());
+      if (!layout) {
+        op->emitError("Could not find layout attribute for operand ")
+            << operand.getOperandNumber() << " of operation " << op->getName();
+        return WalkResult::interrupt();
+      }
+      xegpu::setDistributeLayoutAttr(operand, layout);
+    }
+    return WalkResult::advance();
+  });
+  return !result.wasInterrupted();
+}
+
+template <typename T, typename>
+void xegpu::removeLayoutAttr(const T &operandOrResult) {
+  Operation *owner = operandOrResult.getOwner();
+  std::string name = xegpu::getTemporaryLayoutName(operandOrResult);
+  if (owner->hasAttrOfType<DistributeLayoutAttr>(name))
+    owner->removeAttr(name); // anchor layouts are left untouched
+}
+
+// Explicit instantiation for OpResult.
+template void
+xegpu::removeLayoutAttr<mlir::OpResult>(const mlir::OpResult &result);
+
+// Explicit instantiation for OpOperand.
+template void
+xegpu::removeLayoutAttr<mlir::OpOperand>(const mlir::OpOperand &operand);
+
+void xegpu::removeLayoutAttrs(Operation *op) {
+  // Clean every op visited by the walk through `nestOp`; testing the root `op`
+  // here would only ever strip the root's anchor attributes.
+  op->walk([&](Operation *nestOp) {
+    // Temporary layouts: layout_operand_<i> / layout_result_<i>.
+    for (OpOperand &opr : nestOp->getOpOperands())
+      removeLayoutAttr(opr);
+    for (OpResult result : nestOp->getOpResults())
+      removeLayoutAttr(result);
+    // Anchor layouts stored directly on the visited operation.
+    for (StringRef name : {"layout", "layout_a", "layout_b", "layout_cd"}) {
+      if (nestOp->hasAttrOfType<DistributeLayoutAttr>(name))
+        nestOp->removeAttr(name);
+    }
+  });
+}
+
+/// Infers the source layout of a broadcast as a SliceAttr of `resLayout` that
+/// slices out the leading (resShape.size() - srcShape.size()) result dims.
+xegpu::DistributeLayoutAttr xegpu::inferBroadCastSourceLayout(
+    MLIRContext *context, xegpu::DistributeLayoutAttr resLayout,
+    ArrayRef<int64_t> resShape, ArrayRef<int64_t> srcShape) {
+
+  SmallVector<int64_t> bcastDims;
+  int dimDiff = resShape.size() - srcShape.size();
+  // Slice the leading result dims that the source does not have.
+  for (int i = 0; i < dimDiff; i++)
+    bcastDims.push_back(i);
+
+  // NOTE(review): unlike the inline logic this replaced in
+  // XeGPUPropagateLayout.cpp, size-1 source dims that the result broadcasts
+  // (srcShape[i] == 1 && resShape[i + dimDiff] != 1) are not added to
+  // bcastDims. Presumably intentional (slicing them would change the rank of
+  // the source layout); confirm and drop the old rule for good if so.
+
+  // create a slice layout for the source
+  xegpu::SliceAttr sliceLayout = xegpu::SliceAttr::get(
+      context, resLayout, DenseI64ArrayAttr::get(context, bcastDims));
+
+  return sliceLayout;
+}
+
+// TODO: Define xegpu::inferReductionSourceLayout, declared in
+// XeGPULayoutUtils.h. Sketch of the intended implementation (the signature
+// below matches the header; the earlier draft omitted resShape/srcShape):
+//
+//   xegpu::DistributeLayoutAttr xegpu::inferReductionSourceLayout(
+//       MLIRContext *context, xegpu::DistributeLayoutAttr resLayout,
+//       ArrayRef<int64_t> resShape, ArrayRef<int64_t> srcShape,
+//       SmallVector<int64_t> reduceDims) {
+//     // Flatten the result layout first.
+//     xegpu::DistributeLayoutAttr flatResLayout = resLayout.flatten();
+//     // Then unslice the reduceDims in the flattened layout.
+//     ...
+//   }
+
+// TODO: Define xegpu::inferBitCastSourceLayout (declared in
+// XeGPULayoutUtils.h). It infers the source layout of a bitcast from the
+// result layout and the result/source element type bitwidths.
+//
+// No placeholder declaration is kept here: a qualified redeclaration of a
+// namespace member outside its namespace must be a definition
+// ([dcl.meaning]). Until a definition is added, calls to it will not link.
+
+// TODO: Define xegpu::inferShapeCastSourceLayout (declared in
+// XeGPULayoutUtils.h). It infers the source layout of a shape cast from the
+// result layout, result shape and source shape. Like the bitcast helper it
+// is intentionally not redeclared here; until a definition is added, calls
+// to it will not link.
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index d3906e37ffbf1..181b7e9673fef 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -101,310 +101,6 @@ mlir::xegpu::getDistributedVectorType(VectorType originalType,
   return xegpu::getDistributedVectorType(helperTdescTy);
 }
 
-std::string xegpu::getTemporaryLayoutName(const OpOperand &operand) {
-  const StringRef prefix("layout_operand_");
-  unsigned idx = const_cast<OpOperand &>(operand).getOperandNumber();
-  return llvm::formatv("{0}{1}", prefix, idx).str();
-}
-
-std::string xegpu::getTemporaryLayoutName(const OpResult result) {
-  const StringRef prefix = "layout_result_";
-  return llvm::formatv("{0}{1}", prefix, result.getResultNumber()).str();
-}
-
-xegpu::DistributeLayoutAttr xegpu::getDistributeLayoutAttr(const Value value) {
-  if (!value)
-    return nullptr;
-
-  if (auto tdescTy =
-          dyn_cast_if_present<xegpu::TensorDescType>(value.getType()))
-    return tdescTy.getLayoutAttr();
-
-  if (auto result = dyn_cast<OpResult>(value)) {
-    Operation *defOp = result.getDefiningOp();
-    assert(defOp && "result must have a defining op");
-
-    if (auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(defOp)) {
-      auto layout = anchorOp.getAnchorLayout();
-      return layout;
-    }
-
-    std::string layoutName = getTemporaryLayoutName(result);
-    if (defOp->hasAttr(layoutName)) {
-      auto layout =
-          defOp->getAttrOfType<xegpu::DistributeLayoutAttr>(layoutName);
-      return layout;
-    }
-  }
-
-  if (auto arg = dyn_cast<BlockArgument>(value)) {
-    auto *parentOp = arg.getOwner()->getParentOp();
-    if (auto loop = dyn_cast<LoopLikeOpInterface>(parentOp)) {
-      OpOperand *tiedInit = loop.getTiedLoopInit(arg);
-      if (tiedInit)
-        return getDistributeLayoutAttr(tiedInit->get());
-    }
-  }
-
-  return nullptr;
-}
-xegpu::DistributeLayoutAttr
-xegpu::getDistributeLayoutAttr(const OpOperand &opr) {
-  Operation *op = opr.getOwner();
-  unsigned idx = const_cast<OpOperand &>(opr).getOperandNumber();
-
-  if (auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(op)) {
-    if (auto dpasOp = dyn_cast<xegpu::DpasOp>(op)) {
-      if (idx == 0) {
-        return dpasOp.getLayoutAAttr();
-      } else if (idx == 1) {
-        return dpasOp.getLayoutBAttr();
-      } else if (idx == 2) {
-        return dpasOp.getLayoutCdAttr();
-      }
-    }
-    if (auto convertOp = dyn_cast<xegpu::ConvertLayoutOp>(op)) {
-      return convertOp.getInputLayoutAttr();
-    }
-    auto layout = anchorOp.getAnchorLayout();
-
-    if (idx == 0)
-      return layout;
-
-    // For store operations (StoreScatterOp, StoreNdOp, StoreMatrixOp),
-    // the layout is valid for the first two operands: value and memref/tdesc.
-    // For other operations, the layout applies to the first operand only.
-    if (isa<xegpu::StoreScatterOp, xegpu::StoreNdOp, xegpu::StoreMatrixOp>(
-            op) &&
-        (idx < 2))
-      return layout;
-  }
-
-  std::string layoutName = xegpu::getTemporaryLayoutName(opr);
-  if (op->hasAttr(layoutName)) {
-    auto layout = op->getAttrOfType<xegpu::DistributeLayoutAttr>(layoutName);
-    return layout;
-  }
-
-  auto layout = getDistributeLayoutAttr(opr.get());
-  return layout;
-}
-
-// Returns the permanent layout attribute for the given result if it's
-// available on the defining op. Otherwise returns the provided layout.
-xegpu::DistributeLayoutAttr
-maybePickPermanentLayout(xegpu::DistributeLayoutAttr layout,
-                         const OpResult &result, mlir::Operation *owner,
-                         const std::string &name) {
-  xegpu::DistributeLayoutAttr candidate = layout;
-
-  if (auto loadOp = dyn_cast<xegpu::LoadGatherOp>(owner)) {
-    if (auto perm = loadOp.getLayoutAttr())
-      candidate = perm;
-  }
-
-  return candidate;
-}
-
-// Returns the permanent layout attribute for the given operand if it's
-// available on the defining op. Otherwise returns the provided layout.
-xegpu::DistributeLayoutAttr
-maybePickPermanentLayout(xegpu::DistributeLayoutAttr layout,
-                         const OpOperand &operand, mlir::Operation *owner,
-                         const std::string &name) {
-  xegpu::DistributeLayoutAttr candidate = layout;
-  unsigned idx = const_cast<OpOperand &>(operand).getOperandNumber();
-
-  if (auto storeOp = dyn_cast<xegpu::StoreScatterOp>(owner)) {
-    if (idx == 0) {
-      if (auto perm = storeOp.getLayoutAttr())
-        candidate = perm;
-    }
-  }
-
-  return candidate;
-}
-
-// TODO-LayoutRefactor: Remove this function after replacing use
-//  with setTemporaryLayout or setAnchorLayout
-void xegpu::setDistributeLayoutAttr(
-    const mlir::OpResult &result,
-    const mlir::xegpu::DistributeLayoutAttr layout) {
-  Operation *owner = result.getOwner();
-
-  if (auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(owner)) {
-    if (anchorOp.getAnchorLayout() == layout)
-      return;
-    anchorOp.setAnchorLayout(layout);
-    return;
-  }
-
-  std::string name = xegpu::getTemporaryLayoutName(result);
-  if (owner->hasAttrOfType<DistributeLayoutAttr>(name)) {
-    return;
-  }
-  if (layout) {
-    owner->setAttr(name, layout);
-  }
-}
-
-// TODO-LayoutRefactor: Remove this function after replacing use
-//  with setTemporaryLayout or setAnchorLayout
-void xegpu::setDistributeLayoutAttr(const OpOperand &operand,
-                                    const DistributeLayoutAttr layout) {
-  Operation *owner = operand.getOwner();
-  unsigned idx = const_cast<OpOperand &>(operand).getOperandNumber();
-
-  if (!layout) {
-    return;
-  }
-  if (auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(owner)) {
-    if (auto dpasOp = dyn_cast<xegpu::DpasOp>(owner)) {
-      if (idx == 0) {
-        return dpasOp.setLayoutAAttr(layout);
-      } else if (idx == 1) {
-        return dpasOp.setLayoutBAttr(layout);
-      } else if (idx == 2) {
-        return dpasOp.setLayoutCdAttr(layout);
-      }
-    }
-    if (auto convertOp = dyn_cast<xegpu::ConvertLayoutOp>(owner)) {
-      return convertOp.setInputLayoutAttr(layout);
-    }
-
-    // For store operations (StoreScatterOp, StoreNdOp, StoreMatrixOp),
-    // the layout is valid for the first two operands: value and memref/tdesc.
-    // For other operations, the layout applies to the first operand only.
-    if (isa<xegpu::StoreScatterOp, xegpu::StoreNdOp, xegpu::StoreMatrixOp>(
-            owner)) {
-      if (idx < 2) {
-        anchorOp.setAnchorLayout(layout);
-      }
-    } else {
-      if (idx == 0) {
-        anchorOp.setAnchorLayout(layout);
-      }
-    }
-  }
-
-  std::string name = xegpu::getTemporaryLayoutName(operand);
-  if (owner->hasAttrOfType<DistributeLayoutAttr>(name)) {
-    return;
-  }
-  if (layout) {
-    owner->setAttr(name, layout);
-  }
-}
-
-template <typename T, typename>
-xegpu::DistributeLayoutAttr
-xegpu::getTemporaryLayout(const T &operandOrResult) {
-  Operation *op = operandOrResult.getOwner();
-
-  std::string layoutName = xegpu::getTemporaryLayoutName(operandOrResult);
-  if (op->hasAttr(layoutName)) {
-    auto layout = op->getAttrOfType<xegpu::DistributeLayoutAttr>(layoutName);
-    return layout;
-  }
-
-  return nullptr;
-}
-
-template xegpu::DistributeLayoutAttr
-xegpu::getTemporaryLayout<mlir::OpResult>(const OpResult &result);
-template xegpu::DistributeLayoutAttr
-xegpu::getTemporaryLayout<mlir::OpOperand>(const OpOperand &operand);
-
-template <typename T, typename>
-void xegpu::setTemporaryLayout(const T &operandOrResult,
-                               const xegpu::DistributeLayoutAttr layout) {
-  Operation *owner = operandOrResult.getOwner();
-  std::string name = xegpu::getTemporaryLayoutName(operandOrResult);
-  if (owner->hasAttrOfType<xegpu::DistributeLayoutAttr>(name)) {
-    return;
-  }
-  if (layout) {
-    owner->setAttr(name, layout);
-  }
-}
-
-template void xegpu::setTemporaryLayout<mlir::OpResult>(
-    const mlir::OpResult &result,
-    const mlir::xegpu::DistributeLayoutAttr layout);
-
-template void xegpu::setTemporaryLayout<mlir::OpOperand>(
-    const mlir::OpOperand &operand,
-    const mlir::xegpu::DistributeLayoutAttr layout);
-
-void xegpu::recoverTemporaryLayoutsDeprecated(Operation *op) {
-  op->walk([&](Operation *nestOp) {
-    for (OpOperand &opr : nestOp->getOpOperands()) {
-      auto layout = getDistributeLayoutAttr(opr.get());
-      setDistributeLayoutAttr(opr, layout);
-    }
-
-    for (OpResult result : nestOp->getOpResults()) {
-      auto layout = getDistributeLayoutAttr(result);
-      setDistributeLayoutAttr(result, layout);
-    }
-  });
-}
-
-/// Attach layout attributes to all vector-type operands of operations within
-/// the given operation's region. Reports an error if any vector operand lacks
-/// a layout attribute.
-bool xegpu::recoverTemporaryLayouts(Operation *rootOp) {
-  auto result = rootOp->walk([&](Operation *op) {
-    for (OpOperand &operand : op->getOpOperands()) {
-      // Layouts are needed for vector type only.
-      if (!isa<VectorType>(operand.get().getType()))
-        continue;
-      auto layout = xegpu::getDistributeLayoutAttr(operand.get());
-      if (!layout) {
-        op->emitError("Could not find layout attribute for operand ")
-            << operand.getOperandNumber() << " of operation " << op->getName();
-        return WalkResult::interrupt();
-      }
-      xegpu::setDistributeLayoutAttr(operand, layout);
-    }
-    return WalkResult::advance();
-  });
-  return !result.wasInterrupted();
-}
-
-template <typename T, typename>
-void xegpu::removeLayoutAttr(const T &operandOrResult) {
-  Operation *owner = operandOrResult.getOwner();
-  std::string name = xegpu::getTemporaryLayoutName(operandOrResult);
-  if (owner->hasAttrOfType<DistributeLayoutAttr>(name))
-    owner->removeAttr(name);
-}
-
-// Explicit instantiation for OpResult
-template void
-xegpu::removeLayoutAttr<mlir::OpResult>(const mlir::OpResult &result);
-
-// Explicit instantiation for OpOperand
-template void
-xegpu::removeLayoutAttr<mlir::OpOperand>(const mlir::OpOperand &operand);
-
-void xegpu::removeLayoutAttrs(Operation *op) {
-  op->walk([&](Operation *nestOp) {
-    for (OpOperand &opr : nestOp->getOpOperands())
-      removeLayoutAttr(opr);
-    for (OpResult result : nestOp->getOpResults())
-      removeLayoutAttr(result);
-    if (op->hasAttrOfType<DistributeLayoutAttr>("layout"))
-      op->removeAttr("layout");
-    if (op->hasAttrOfType<DistributeLayoutAttr>("layout_a"))
-      op->removeAttr("layout_a");
-    if (op->hasAttrOfType<DistributeLayoutAttr>("layout_b"))
-      op->removeAttr("layout_b");
-    if (op->hasAttrOfType<DistributeLayoutAttr>("layout_cd"))
-      op->removeAttr("layout_cd");
-  });
-}
-
 SmallVector<Value>
 xegpu::extractVectorsWithShapeFromValue(OpBuilder &builder, Location loc,
                                         Value value, ArrayRef<int64_t> shape) {

>From a45ca448e6900a024e2d22d2724af17c90ce9243 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Tue, 23 Dec 2025 23:56:39 +0000
Subject: [PATCH 02/35] add layout set up rule for reduction

---
 .../mlir/Dialect/XeGPU/IR/XeGPUAttrs.td       |  24 +-
 .../Dialect/XeGPU/Utils/XeGPULayoutUtils.h    |  14 +-
 mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp    | 148 +++++--
 .../XeGPU/Transforms/XeGPUPropagateLayout.cpp |  98 +++--
 .../Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp  | 412 ++++++++++++++++--
 mlir/test/Dialect/XeGPU/propagate-layout.mlir |   4 +-
 6 files changed, 584 insertions(+), 116 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index 446f64fffa468..eae2cbe24a72c 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -236,6 +236,14 @@ def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> {
                     "FailureOr<SmallVector<Value>>",
                     "delinearizeId",
                     (ins "OpBuilder &": $builder, "Location":$loc, "Value":$linearId)>,
+    InterfaceMethod<[{Derive a new layout with sg_data, inst_data and lane_data set to the 
+                      specified values for the given dimension}],
+                    "xegpu::DistributeLayoutAttr",
+                    "setDimData",
+                    (ins "int64_t": $dim,
+                          "int64_t": $sgData,
+                          "int64_t": $instData,
+                          "int64_t": $laneData)>,                                                 
     InterfaceMethod<[{Generates instructions to compute multidimensional coordinates for dist units
                       assigned to a level identified by linearId. The shape parameter
                       represents the higher-level problem size. Each level may access
@@ -501,10 +509,14 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttr]> {
     }
 
     //set the layout for the sepcified unit dims: sg_data, inst_data and lane_data to 1
-    DistributeLayoutAttr setUnitDimData(SetVector<int64_t> unitDims);
+    DistributeLayoutAttr setUnitDimData(SetVector<int64_t> unitDims) const;
 
     //set the layout for the sepcified unit dims: sg_lane and lane_layout to 1
-    DistributeLayoutAttr setUnitDimLayout(SetVector<int64_t> unitDims);
+    DistributeLayoutAttr setUnitDimLayout(SetVector<int64_t> unitDims) const;
+
+    // Derive a new layout with sg_data, inst_data and lane_data set to the 
+    // specified values for the given dimension
+    DistributeLayoutAttr setDimData(int64_t dim, int64_t sgData, int64_t instData, int64_t laneData);
 
     /// Delinearizes a linear ID into its multidimensional indices
     /// based on the effective level of the layout.
@@ -672,10 +684,14 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> {
     }
 
     //set the layout for the sepcified unit dims: sg_data, inst_data and lane_data to 1
-    DistributeLayoutAttr setUnitDimData(SetVector<int64_t> unitDims);
+    DistributeLayoutAttr setUnitDimData(SetVector<int64_t> unitDims) const;
 
     //set the layout for the sepcified unit dims: sg_lane and lane_layout to 1
-    DistributeLayoutAttr setUnitDimLayout(SetVector<int64_t> unitDims);
+    DistributeLayoutAttr setUnitDimLayout(SetVector<int64_t> unitDims) const;
+
+    // Derive a new layout with sg_data, inst_data and lane_data set to the 
+    // specified values for the given dimension
+    DistributeLayoutAttr setDimData(int64_t dim, int64_t sgData, int64_t instData, int64_t laneData);
 
     /// flatten a nested SliceAttr, e.g., for 2-level nested SliceAttr
     /// #xegpu.slice<#xegpu.slice<#xegpu.layout<sg_layout = [4, 8, 12]>, dims = [0]>, dims = [0]>
diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h
index 6f528214f538b..01cb43b73d5ca 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h
@@ -102,11 +102,9 @@ DistributeLayoutAttr inferBroadCastSourceLayout(MLIRContext *context,
                                                 ArrayRef<int64_t> srcShape);
 
 /// Infers the source layout attribute for a reduction operation given the
-/// result layout attribute, result shape, source shape, and reduced dims.
+/// result layout attribute and reduced dims.
 DistributeLayoutAttr
-inferReductionSourceLayout(MLIRContext *context, DistributeLayoutAttr resLayout,
-                           ArrayRef<int64_t> resShape,
-                           ArrayRef<int64_t> srcShape,
+inferReductionSourceLayout(DistributeLayoutAttr resLayout,
                            SmallVector<int64_t> reduceDims);
 
 /// Infers the source layout attribute for a bitcast operation given the
@@ -124,6 +122,14 @@ DistributeLayoutAttr inferShapeCastSourceLayout(MLIRContext *context,
                                                 ArrayRef<int64_t> resShape,
                                                 ArrayRef<int64_t> srcShape);
 
+/// Sets up the layout attribute for the result of a reduction based on the
+/// preferred layout propagated from its consumer.
+/// The returned layout is always a slice attribute.
+SliceAttr
+reductionLayoutSetupRule(ArrayRef<int64_t> srcShape,
+                         SmallVector<int64_t> reductionDims,
+                         DistributeLayoutAttr consumerPreferredLayout);
+
 } // namespace xegpu
 
 } // namespace mlir
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index ccf17da26c942..0ecfe3eac650c 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -400,7 +400,8 @@ bool LayoutAttr::isEqualTo(const xegpu::DistributeLayoutAttr &other) {
 }
 
 // set the layout for unit dims: sg_data, inst_data and lane_data to 1
-DistributeLayoutAttr LayoutAttr::setUnitDimData(SetVector<int64_t> unitDims) {
+DistributeLayoutAttr
+LayoutAttr::setUnitDimData(SetVector<int64_t> unitDims) const {
   auto sgDataOpt = getSgData();
   auto instDataOpt = getInstData();
   auto laneDataOpt = getLaneData();
@@ -409,15 +410,14 @@ DistributeLayoutAttr LayoutAttr::setUnitDimData(SetVector<int64_t> unitDims) {
   SmallVector<int32_t> instData;
   SmallVector<int32_t> laneData;
 
-  if (sgDataOpt) {
+  if (sgDataOpt)
     sgData = llvm::to_vector(sgDataOpt.asArrayRef());
-  }
-  if (instDataOpt) {
+
+  if (instDataOpt)
     instData = llvm::to_vector(instDataOpt.asArrayRef());
-  }
-  if (laneDataOpt) {
+
+  if (laneDataOpt)
     laneData = llvm::to_vector(laneDataOpt.asArrayRef());
-  }
 
   for (auto dim : unitDims) {
     if (dim < static_cast<int64_t>(sgData.size()))
@@ -441,19 +441,18 @@ DistributeLayoutAttr LayoutAttr::setUnitDimData(SetVector<int64_t> unitDims) {
 }
 
 // set the layout for the sepcified unit dims: sg_lane and lane_layout to 1
-DistributeLayoutAttr LayoutAttr::setUnitDimLayout(SetVector<int64_t> unitDims) {
+DistributeLayoutAttr
+LayoutAttr::setUnitDimLayout(SetVector<int64_t> unitDims) const {
   auto sgLayoutOpt = getSgLayout();
   auto laneLayoutOpt = getLaneLayout();
 
   SmallVector<int32_t> sgLayout;
   SmallVector<int32_t> laneLayout;
 
-  if (sgLayoutOpt) {
+  if (sgLayoutOpt)
     sgLayout = llvm::to_vector(sgLayoutOpt.asArrayRef());
-  }
-  if (laneLayoutOpt) {
+  if (laneLayoutOpt)
     laneLayout = llvm::to_vector(laneLayoutOpt.asArrayRef());
-  }
 
   for (auto dim : unitDims) {
     if (dim < static_cast<int64_t>(sgLayout.size()))
@@ -472,6 +471,47 @@ DistributeLayoutAttr LayoutAttr::setUnitDimLayout(SetVector<int64_t> unitDims) {
       getLaneData(), getOrder());
 }
 
+// Derive a new layout with sg_data, inst_data and lane_data of dimension `dim`
+// set to the given values; fields the layout does not define stay undefined.
+DistributeLayoutAttr LayoutAttr::setDimData(int64_t dim, int64_t sgData,
+                                            int64_t instData,
+                                            int64_t laneData) {
+  auto sgDataOpt = getSgData();
+  auto instDataOpt = getInstData();
+  auto laneDataOpt = getLaneData();
+  // Copy the existing per-dimension data so that only `dim` is modified.
+  SmallVector<int32_t> sgDataVec;
+  SmallVector<int32_t> instDataVec;
+  SmallVector<int32_t> laneDataVec;
+
+  if (sgDataOpt)
+    sgDataVec = llvm::to_vector(sgDataOpt.asArrayRef());
+
+  if (instDataOpt)
+    instDataVec = llvm::to_vector(instDataOpt.asArrayRef());
+
+  if (laneDataOpt)
+    laneDataVec = llvm::to_vector(laneDataOpt.asArrayRef());
+  // An out-of-range (including negative) `dim` leaves that field unchanged.
+  if (dim >= 0 && dim < static_cast<int64_t>(sgDataVec.size()))
+    sgDataVec[dim] = static_cast<int32_t>(sgData);
+  if (dim >= 0 && dim < static_cast<int64_t>(instDataVec.size()))
+    instDataVec[dim] = static_cast<int32_t>(instData);
+  if (dim >= 0 && dim < static_cast<int64_t>(laneDataVec.size()))
+    laneDataVec[dim] = static_cast<int32_t>(laneData);
+
+  return LayoutAttr::get(
+      getContext(), getSgLayout(),
+      sgDataVec.empty() ? DenseI32ArrayAttr()
+                        : DenseI32ArrayAttr::get(getContext(), sgDataVec),
+      instDataVec.empty() ? DenseI32ArrayAttr()
+                          : DenseI32ArrayAttr::get(getContext(), instDataVec),
+      getLaneLayout(),
+      laneDataVec.empty() ? DenseI32ArrayAttr()
+                          : DenseI32ArrayAttr::get(getContext(), laneDataVec),
+      getOrder());
+}
+
 //===----------------------------------------------------------------------===//
 // XeGPU_SliceAttr
 //===----------------------------------------------------------------------===//
@@ -604,55 +644,83 @@ bool SliceAttr::isEqualTo(const xegpu::DistributeLayoutAttr &other) {
 }
 
 // Helper function to adjust unit dimensions from sliced space to parent space
+// E.g. for a rank-4 parent sliced with dims [1, 3], the sliced space has rank
+// 2: sliced dim 0 maps to parent dim 0 and sliced dim 1 maps to parent dim 2.
+// Entries of dimsInSlice must be non-negative and below the sliced rank; an
+// empty dimsInSlice (or empty sliceDims) is handled and maps trivially.
 static SetVector<int64_t>
-adjustUnitDimsWithSliceDims(const SetVector<int64_t> &unitDims,
-                            ArrayRef<int64_t> sliceDims) {
-  // Reconstruct parent's non-sliced dimensions
-
-  int64_t parentRank = sliceDims.size() + unitDims.size();
+mapDimsFromSlicedSpace(const SetVector<int64_t> &dimsInSlice,
+                       ArrayRef<int64_t> sliceDims) {
+  // Only parent dims up to the largest requested one are needed. At most
+  // sliceDims.size() of the first (max(dimsInSlice) + 1 + sliceDims.size())
+  // parent dims are sliced away, so scanning that many is always sufficient.
+  // A plain loop avoids dereferencing max_element() of an empty range.
+  int64_t maxDim = -1;
+  for (int64_t dim : dimsInSlice)
+    maxDim = std::max(maxDim, dim);
+  int64_t parentSpaceRank =
+      maxDim + 1 + static_cast<int64_t>(sliceDims.size());
+  // Collect, in increasing order, the parent dims that survive slicing.
   llvm::SmallDenseSet<int64_t> slicedDimsSet(sliceDims.begin(),
                                              sliceDims.end());
-  SmallVector<int64_t> nonSlicedDims;
-  for (int64_t i = 0; i < parentRank; ++i) {
+  SmallVector<int64_t> remainingDims;
+  for (int64_t i = 0; i < parentSpaceRank; ++i) {
     if (!slicedDimsSet.contains(i))
-      nonSlicedDims.push_back(i);
+      remainingDims.push_back(i);
   }
 
   // Map unit dims from sliced space to parent space
-  SetVector<int64_t> adjustUnitDims;
-  for (auto dim : unitDims) {
-    if (dim < static_cast<int64_t>(nonSlicedDims.size())) {
-      adjustUnitDims.insert(nonSlicedDims[dim]);
-    }
+  SetVector<int64_t> dimsInUnSlicedSpace;
+  for (auto dim : dimsInSlice) {
+    int64_t mappedDim = remainingDims[dim];
+    dimsInUnSlicedSpace.insert(mappedDim);
   }
 
-  return adjustUnitDims;
+  return dimsInUnSlicedSpace;
 }
 
 // set the layout for unit dims: sg_data, inst_data and lane_data to 1
-DistributeLayoutAttr SliceAttr::setUnitDimData(SetVector<int64_t> unitDims) {
-  SliceAttr attr = flatten();
-  ArrayRef<int64_t> sliceDims = attr.getDims().asArrayRef();
-  auto parent = dyn_cast<LayoutAttr>(attr.getParent());
+DistributeLayoutAttr
+SliceAttr::setUnitDimData(SetVector<int64_t> unitDims) const {
+  DistributeLayoutAttr parentLayout = getParent();
+
+  ArrayRef<int64_t> sliceDims = getDims().asArrayRef();
 
   SetVector<int64_t> adjustUnitDims =
-      adjustUnitDimsWithSliceDims(unitDims, sliceDims);
+      mapDimsFromSlicedSpace(unitDims, sliceDims);
 
-  return SliceAttr::get(getContext(), parent.setUnitDimData(adjustUnitDims),
-                        attr.getDims());
+  return SliceAttr::get(getContext(),
+                        parentLayout.setUnitDimData(adjustUnitDims), getDims());
 }
 
 // set the layout for the sepcified unit dims: sg_lane and lane_layout to 1
-DistributeLayoutAttr SliceAttr::setUnitDimLayout(SetVector<int64_t> unitDims) {
-  SliceAttr attr = flatten();
-  ArrayRef<int64_t> sliceDims = attr.getDims().asArrayRef();
-  auto parent = dyn_cast<LayoutAttr>(attr.getParent());
+DistributeLayoutAttr
+SliceAttr::setUnitDimLayout(SetVector<int64_t> unitDims) const {
+  DistributeLayoutAttr parentLayout = getParent();
+
+  ArrayRef<int64_t> sliceDims = getDims().asArrayRef();
 
   SetVector<int64_t> adjustUnitDims =
-      adjustUnitDimsWithSliceDims(unitDims, sliceDims);
+      mapDimsFromSlicedSpace(unitDims, sliceDims);
 
-  return SliceAttr::get(getContext(), parent.setUnitDimLayout(adjustUnitDims),
-                        attr.getDims());
+  return SliceAttr::get(
+      getContext(), parentLayout.setUnitDimLayout(adjustUnitDims), getDims());
+}
+
+// Derive a new layout with sg_data, inst_data and lane_data set to the given
+// values for `dim` (indexed in this slice's space; it is mapped to the parent).
+DistributeLayoutAttr SliceAttr::setDimData(int64_t dim, int64_t sgData,
+                                           int64_t instData, int64_t laneData) {
+  ArrayRef<int64_t> sliceDims = getDims().asArrayRef();
+  // The parent may itself be a SliceAttr (nested slice), so use the interface.
+  DistributeLayoutAttr parent = getParent();
+  SetVector<int64_t> dimSet;
+  dimSet.insert(dim);
+  SetVector<int64_t> adjustDims = mapDimsFromSlicedSpace(dimSet, sliceDims);
+
+  return SliceAttr::get(
+      getContext(),
+      parent.setDimData(adjustDims[0], sgData, instData, laneData), getDims());
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 60dedf373a07b..1be44480de01d 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -128,6 +128,7 @@ struct LayoutInfo {
   }
 
   Attribute get() { return storage; }
+  void set(const xegpu::DistributeLayoutAttr &layout) { storage = layout; }
 };
 
 SmallVector<int> LayoutInfo::getLaneLayout() const {
@@ -607,25 +608,58 @@ void LayoutInfoPropagation::visitVectorMultiReductionOp(
       llvm::dyn_cast<VectorType>(reduction.getSourceVectorType());
   SmallVector<int64_t> reductionDims(reduction.getReductionDims().begin(),
                                      reduction.getReductionDims().end());
-  // xegpu::DistributeLayoutAttr operandLayout =
-  // xegpu::inferReductionSourceLayout(
-  //       reduction.getContext(),
-  //       dyn_cast<xegpu::DistributeLayoutAttr>(resultLayout.get()),
-  //       resultTy.getShape(),
-  //       sourceTy.getShape(),
-  //       reductionDims);
-  // We only consider 2D -> 1D reductions at this point.
-
-  if (!resultTy || resultTy.getRank() != 1) {
-    reduction.emitWarning("Expecting output type to be 1D vector.");
-    return;
+
+  auto srcShape = sourceTy.getShape();
+
+  LLVM_DEBUG(DBGS() << "visitVectorMultiReductionOp: srcShape = [";
+             for (auto dim
+                  : srcShape) llvm::dbgs()
+             << dim << " ";
+             llvm::dbgs() << "]\n");
+  LLVM_DEBUG(DBGS() << "visitVectorMultiReductionOp: reductionDims = [";
+             for (auto dim
+                  : reductionDims) llvm::dbgs()
+             << dim << " ";
+             llvm::dbgs() << "]\n");
+
+  auto resultLayoutAttr =
+      dyn_cast<xegpu::DistributeLayoutAttr>(resultLayout.get());
+
+  LLVM_DEBUG(DBGS() << "visitVectorMultiReductionOp: resultLayoutAttr = "
+                    << resultLayoutAttr << "\n");
+
+  // The dominant layout is the result layout required by this operation. It is
+  // recorded as an anchor or temporary layout and must be honored for this op,
+  // even if it conflicts with the layout propagated from the consumer op. Such
+  // a conflict is resolved in a later phase by converting the dominant layout
+  // to the source layout.
+
+  xegpu::DistributeLayoutAttr dominantLayout = xegpu::reductionLayoutSetupRule(
+      srcShape, reductionDims, resultLayoutAttr);
+
+  LLVM_DEBUG(DBGS() << "visitVectorMultiReductionOp: dominantLayout = "
+                    << dominantLayout << "\n");
+
+  if (layoutKind == LayoutKind::Lane) {
+    // only lane layout/data is considered
+    dominantLayout = dominantLayout.dropInstData();
+    dominantLayout = dominantLayout.dropSgLayoutAndData();
+  } else if (layoutKind == LayoutKind::InstData) {
+    dominantLayout = dominantLayout.dropSgLayoutAndData();
   }
-  auto uArch = getUArch(xegpu::getChipStr(reduction).value_or(""));
-  // Given that the result is 1D, the layout of the operand should be 2D with
-  // default layout.
-  LayoutInfo operandLayout = getDefaultSIMTLayoutInfo(
-      reduction->getContext(), 2, uArch->getSubgroupSize());
-  propagateIfChanged(operands[0], operands[0]->meet(operandLayout));
+
+  // record the dominant layout to the reduction op
+  xegpu::setTemporaryLayout(reduction->getResult(0), dominantLayout);
+
+  // derive the source layout from the dominant layout and reduction dims
+  auto srcLayoutAttr =
+      xegpu::inferReductionSourceLayout(dominantLayout, reductionDims);
+  // Propagate the inferred source layout to the reduction's source operand.
+
+  LLVM_DEBUG(DBGS() << "visitVectorMultiReductionOp: srcLayoutAttr = "
+                    << srcLayoutAttr << "\n");
+
+  propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
   // Accumulator should have the same layout as the result.
   propagateIfChanged(operands[1], operands[1]->meet(resultLayout));
 }
@@ -637,6 +671,7 @@ void LayoutInfoPropagation::visitVectorBroadCastOp(
   LayoutInfo resultLayout = results[0]->getValue();
   if (!resultLayout.isAssigned())
     return;
+
   // Only consider vector to vector broadcasts for now.
   VectorType resultTy = broadcast.getResultVectorType();
   VectorType sourceTy = dyn_cast<VectorType>(broadcast.getSourceType());
@@ -644,20 +679,23 @@ void LayoutInfoPropagation::visitVectorBroadCastOp(
   if (!sourceTy)
     return;
 
-  // Hanlding broadcast from low-rank to high-rank (e.g., 1D to 2D) case.
-  if (sourceTy.getRank() != resultTy.getRank()) {
-    auto srcShape = sourceTy.getShape();
-    auto resShape = resultTy.getShape();
-    auto resultLayoutAttr =
-        dyn_cast<xegpu::DistributeLayoutAttr>(resultLayout.get());
+  auto srcShape = sourceTy.getShape();
+  auto resShape = resultTy.getShape();
 
-    xegpu::DistributeLayoutAttr sliceLayout = xegpu::inferBroadCastSourceLayout(
-        broadcast.getContext(), resultLayoutAttr, resShape, srcShape);
+  size_t dimDiff = resultTy.getRank() - sourceTy.getRank();
+  for (size_t i = 0; i < srcShape.size(); i++)
+    if ((srcShape[i] == 1) && (resShape[i + dimDiff] != 1))
+      broadcast.emitWarning("broadcast must either from low-rank or same-rank "
+                            "with unit-dim, mixed scenario is not supported!");
 
-    propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(sliceLayout)));
-    return;
-  }
-  propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
+  auto resultLayoutAttr =
+      dyn_cast<xegpu::DistributeLayoutAttr>(resultLayout.get());
+
+  xegpu::DistributeLayoutAttr resLayout = xegpu::inferBroadCastSourceLayout(
+      broadcast.getContext(), resultLayoutAttr, resShape, srcShape);
+
+  propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(resLayout)));
+  return;
 }
 
 void LayoutInfoPropagation::visitShapeCastOp(
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp
index 2f68aa7bda924..58e222812661f 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp
@@ -263,6 +263,35 @@ bool xegpu::recoverTemporaryLayouts(Operation *rootOp) {
   return !result.wasInterrupted();
 }
 
+// Prerequisites for layout recovery.
+// Recovery relies on the following invariants:
+// 1. There is no layout conflict between different uses of the same definition.
+// 2. Each definition has a well-defined layout requirement at its use point.
+//     - Every definition must have at least one use that appears after it in
+//     topological order.
+//     - If a definition has no such use (e.g., a loop result or region output),
+//     an explicit convert_layout operation is inserted to create a use.
+//     - Only the result of convert_layout is permitted to have no subsequent
+//     use.
+
+// The recovery proceeds by scanning operations in reverse topological order as
+// follows. Across operations: layouts are propagated from uses to definitions.
+// Within an operation: layouts are propagated from definitions (results) to uses
+// (operands).
+//    For region operations (e.g., loops):
+//       - When backward propagation reaches a region op, it sets the layout of
+//       the region op's results according to their uses, as for regular ops.
+//       - Then, the result layouts (such as a loop output) are propagated to
+//       their corresponding operands in the yield.
+//       - When backward propagation reaches the first operation inside the
+//       region, the pass examines the region op’s initialization list,
+//       propagating from region arguments to the corresponding initialization
+//       operands.
+//       - This ensures that layout constraints are consistently propagated
+//       across region boundaries
+//        while preserving a single well-defined use for each definition at the
+//        region-op level.
+
 template <typename T, typename>
 void xegpu::removeLayoutAttr(const T &operandOrResult) {
   Operation *owner = operandOrResult.getOwner();
@@ -297,54 +326,365 @@ void xegpu::removeLayoutAttrs(Operation *op) {
 }
 
 /// Infers the source layout attribute for a broadcast operation given the
-/// result layout attribute, result shape, source shape, and broadcasted dims.
+/// result layout attribute, result shape, and source shape.
 xegpu::DistributeLayoutAttr xegpu::inferBroadCastSourceLayout(
     MLIRContext *context, xegpu::DistributeLayoutAttr resLayout,
     ArrayRef<int64_t> resShape, ArrayRef<int64_t> srcShape) {
 
   SmallVector<int64_t> bcastDims;
+  auto returnLayout = resLayout; // same-rank (unit-dim) case: unchanged
+
+  // Handle broadcast from low-rank to high-rank (e.g., 1D to 2D).
   int dimDiff = resShape.size() - srcShape.size();
-  // adding the missing leading dims
-  for (int i = 0; i < dimDiff; i++)
-    bcastDims.push_back(i);
 
-  // for the rest dims in the resultTy, if sourceTy dim is 1, then it's
-  // broadcasted dim
-  // for (size_t i = 0; i < srcShape.size(); i++)
-  //   if ((srcShape[i] == 1) && (resShape[i + dimDiff] != 1))
-  //     bcastDims.push_back(i + dimDiff);
+  if (dimDiff > 0) {
+    // The missing leading dims of the source are the broadcast dims.
+    for (int i = 0; i < dimDiff; i++)
+      bcastDims.push_back(i);
+
+    // Slice them off the result layout to get a layout of the source's rank.
+    returnLayout = xegpu::SliceAttr::get(
+        context, resLayout, DenseI64ArrayAttr::get(context, bcastDims));
+  }
+  return returnLayout;
+}
+
+/// Infers the source layout attribute for a reduction operation given the
+/// result layout attribute and reduced dims.
+xegpu::DistributeLayoutAttr
+xegpu::inferReductionSourceLayout(xegpu::DistributeLayoutAttr resLayout,
+                                  SmallVector<int64_t> reduceDims) {
+
+  // The result of a reduction must carry a slice layout.
+  assert(isa<xegpu::SliceAttr>(resLayout) &&
+         "reduction result layout must be slice layout");
 
-  // create a slice layout for the source
-  xegpu::SliceAttr sliceLayout = xegpu::SliceAttr::get(
-      context, resLayout, DenseI64ArrayAttr::get(context, bcastDims));
+  // The reduced dims must match the slice dims of resLayout.
+  xegpu::SliceAttr sliceLayout = cast<xegpu::SliceAttr>(resLayout);
+  [[maybe_unused]] auto sliceDims = sliceLayout.getDims().asArrayRef();
+  assert(reduceDims == sliceDims &&
+         "reduction dims must match with slice dims");
 
-  return sliceLayout;
+  // The source layout is the parent that the reduction sliced from.
+  return sliceLayout.getParent();
 }
 
-// /// Infers the source layout attribute for a reduction operation given the
-// /// result layout attribute, result shape, source shape, and reduced dims.
-// xegpu::DistributeLayoutAttr xegpu::inferReductionSourceLayout(
-//         MLIRContext *context,
-//         xegpu::DistributeLayoutAttr resLayout,
-//         SmallVector<int64_t> reduceDims){
+// /// Infers the source layout attribute for a bitcast operation given the
+// /// result layout attribute, result element type bitwidth, and source element
+// /// type bitwidth.
+// xegpu::DistributeLayoutAttr
+// xegpu::inferBitCastSourceLayout(MLIRContext *context,
+//                                 xegpu::DistributeLayoutAttr resLayout,
+//                                 int resElemTyBitWidth, int
+//                                 srcElemTyBitWidth){
+//   // the result and source layout must be the same
+//   // if resLayout is SliceAttr, we need to first get its root layout
+//   xegpu::DistributeLayoutAttr resRootLayout = resLayout;
+//   while (auto sliceLayout = dyn_cast<xegpu::SliceAttr>(resRootLayout)) {
+//     resRootLayout = sliceLayout.getParent();
+//   }
+//   // change the laneData of resRootLayout according to the bitwidth ratio
+//   xegpu::LayoutAttr resRootPlainLayout =
+//   dyn_cast<xegpu::LayoutAttr>(resRootLayout); SmallVector<int64_t> laneData =
+//   resRootPlainLayout.getEffectiveLaneDataAsInt();
+
+//   if (srcElemTyBitWidth >= resElemTyBitWidth) {
+//     int bitWidthRatio = srcElemTyBitWidth / resElemTyBitWidth;
+//     laneData[laneData.size()-1] = laneData[laneData.size()-1] *
+//     bitWidthRatio;
+//   } else {
+//     int bitWidthRatio = resElemTyBitWidth / srcElemTyBitWidth;
+//     assert((laneData[laneData.size()-2] % bitWidthRatio) == 0 &&
+//            "laneData not divisible by bitWidthRatio");
+//     laneData[laneData.size()-1] = laneData[laneData.size()-1] /
+//     bitWidthRatio;
+//   }
+
+//   // now reconstruct the source layout with updated laneData
+//   // by updating the root layout and going throught the slice layers
+//   SmallVector<int32_t> laneData32(laneData.begin(), laneData.end());
+//   xegpu::LayoutAttr proposedSrcLayout = xegpu::LayoutAttr::get(
+//           context,
+//           resRootPlainLayout.getSgLayout(),
+//           resRootPlainLayout.getSgData(),
+//           resRootPlainLayout.getInstData(),
+//           resRootPlainLayout.getLaneLayout(),
+//           DenseI32ArrayAttr::get(context, laneData32),
+//           resRootPlainLayout.getOrder());
+
+//   // reconstruct slice layers if any
+//   // First collect all slice layers from innermost to outermost
+//   SmallVector<DenseI64ArrayAttr> sliceDims;
+//   xegpu::DistributeLayoutAttr currentLayout = resLayout;
+//   while (auto sliceLayout = dyn_cast<xegpu::SliceAttr>(currentLayout)) {
+//     sliceDims.push_back(sliceLayout.getDims());
+//     currentLayout = sliceLayout.getParent();
+//   }
+
+//   // Now rebuild from outermost to innermost (reverse order)
+//   xegpu::DistributeLayoutAttr finalSrcLayout = proposedSrcLayout;
+//   for (auto it = sliceDims.rbegin(); it != sliceDims.rend(); ++it) {
+//     finalSrcLayout = xegpu::SliceAttr::get(context, finalSrcLayout, *it);
+//   }
+//   return finalSrcLayout;
+// }
+
+// /// Infers the source layout attribute for a shape cast operation given the
+// /// result layout attribute, result shape, and source shape.
+// xegpu::DistributeLayoutAttr xegpu::inferShapeCastSourceLayout(
+//     MLIRContext *context, xegpu::DistributeLayoutAttr resLayout,
+//     ArrayRef<int64_t> resShape, ArrayRef<int64_t> srcShape){
+
+// // there are two use cases:
+// // 1. expand dims of low-rank dimensions (e.g., 1D to 2D): to set up the
+// tensor before broadcast
+// // 2. split dim of a high-rank dimension (e.g., 1D to 2D): to setup tensor
+// for multi-stage reduction
+
+//   SmallVector<int64_t> shapeCastDims;
+//   auto returnLayout = resLayout;
+
+//   int resRank = resShape.size();
+//   int srcRank = srcShape.size();
+
+//   if (srcRank < resRank) {
+//     // Case 1: expand dims of low-rank dimensions (e.g., 1D to 2D)
+//     int dimDiff = resRank - srcRank;
+//     // adding the missing leading dims
+//     for (int i = 0; i < dimDiff; i++)
+//       shapeCastDims.push_back(i);
+
+//     // create a slice layout for the source
+//     returnLayout = xegpu::SliceAttr::get(
+//         context, resLayout, DenseI64ArrayAttr::get(context, shapeCastDims));
+//   } else if (srcRank > resRank) {
+//     // Case 2: split dim of a high-rank dimension (e.g., 1D to 2D)
+//     // find the split dims by comparing srcShape and resShape
+//     int srcIdx = 0;
+//     int resIdx = 0;
+//     while (srcIdx < srcRank && resIdx < resRank) {
+//       if (srcShape[srcIdx] == resShape[resIdx]) {
+//         srcIdx++;
+//         resIdx++;
+//       } else if (srcShape[srcIdx] < resShape[resIdx]) {
+//         shapeCastDims.push_back(srcIdx);
+//         srcIdx++;
+//       } else {
+//         // this should not happen in valid shape cast
+//         assert(false && "Invalid shape cast: source shape dimension smaller
+//         than result shape dimension");
+//       }
+//     }
+//     // handle remaining src dims
+//     while (srcIdx < srcRank) {
+//       shapeCastDims.push_back(srcIdx);
+//       srcIdx++;
+//     }
+
+//     // create a slice layout for the source
+//     returnLayout = xegpu::SliceAttr::get(
+//         context, resLayout, DenseI64ArrayAttr::get(context, shapeCastDims));
+//   }
+//   return returnLayout;
+
+// }
+
+xegpu::SliceAttr
+xegpu::reductionLayoutSetupRule(ArrayRef<int64_t> srcShape,
+                                SmallVector<int64_t> reductionDims,
+                                DistributeLayoutAttr consumerPreferredLayout) {
+  // Proposes a source layout for reducing `srcShape` over `reductionDims` and
+  // returns it sliced by `reductionDims`; returns null for rank < 2.
+  xegpu::SliceAttr sliceCPL =
+      dyn_cast<xegpu::SliceAttr>(consumerPreferredLayout);
+  // Try to align with the consumer's preferred layout so that the slice layout
+  // structure is preserved, avoiding potential data movement across subgroups
+  // or lanes.
+
+  const int workgroupSize = 16; // assuming 16 subgroups for now
+  const int subgroupSize = 16;  // assuming 16 lanes per subgroup
+  const int vectorSize = 8;     // assuming 8 elements per vector lane
+  int srcShapeSize = srcShape.size();
+  xegpu::DistributeLayoutAttr proposedSrcLayout;
+  auto context = consumerPreferredLayout.getContext();
+  // if srcShapeSize is less than 2, we cannot proceed
+  if (srcShapeSize < 2)
+    return nullptr;
 
-//    //  flatten the reslayout first
-//    xegpu::DistributeLayoutAttr flatResLayout =
-//        resLayout.flatten();
-//   //  then unslice the reduceDims in the flattened layout
+  // Dims never assigned below keep one subgroup owning the whole dim.
 
-//}
+  SmallVector<int64_t> sgLayout(srcShapeSize, 1);
+  SmallVector<int64_t> sgData(srcShape.begin(), srcShape.end());
 
-/// Infers the source layout attribute for a bitcast operation given the
-/// result layout attribute, result element type bitwidth, and source element
-/// type bitwidth.
-xegpu::DistributeLayoutAttr
-xegpu::inferBitCastSourceLayout(MLIRContext *context,
-                                xegpu::DistributeLayoutAttr resLayout,
-                                int resElemTyBitWidth, int srcElemTyBitWidth);
+  SmallVector<int64_t> instData(srcShapeSize, 1);
+  instData[srcShapeSize - 1] = subgroupSize;
+  instData[srcShapeSize - 2] =
+      vectorSize; // assuming 8 elements per instruction as starting point
+  llvm::errs() << "DEBUG: Initial instData = [";
+  for (size_t i = 0; i < instData.size(); i++) {
+    llvm::errs() << instData[i];
+    if (i < instData.size() - 1)
+      llvm::errs() << ", ";
+  }
+  llvm::errs() << "]\n";
+  // construct a vector layout with lane_layout = [1, ..., 1, subgroupSize]
+  SmallVector<int64_t> laneLayout(srcShapeSize, 1);
+  laneLayout[srcShapeSize - 1] = subgroupSize;
+  llvm::errs() << "DEBUG: laneLayout = [";
+  for (size_t i = 0; i < laneLayout.size(); i++) {
+    llvm::errs() << laneLayout[i];
+    if (i < laneLayout.size() - 1)
+      llvm::errs() << ", ";
+  }
+  llvm::errs() << "]\n";
+  // construct a vector layout with lane_data = [1, ..., 1]
+  SmallVector<int64_t> laneData(srcShapeSize, 1);
+
+  bool failToAlignSliceStruct = false;
+  if (sliceCPL && sliceCPL.getDims().asArrayRef().equals(reductionDims)) {
+
+    xegpu::DistributeLayoutAttr parentCPL = sliceCPL.getParent();
+
+    // for each slice dim in source shape, if the dim size is different than the
+    // result shape, try to adjust the sg_data/inst_data accordingly.
+    SmallVector<int64_t> pcplSgLayout = parentCPL.getEffectiveSgLayoutAsInt();
+    SmallVector<int64_t> pcplLaneLayout =
+        parentCPL.getEffectiveLaneLayoutAsInt();
+    SmallVector<int64_t> pcplLaneData = parentCPL.getEffectiveLaneDataAsInt();
+
+    assert(srcShapeSize == parentCPL.getRank() &&
+           "parent layout rank must match source shape rank");
+
+    proposedSrcLayout = parentCPL;
+
+    llvm::errs() << "DEBUG: srcShapeSize = " << srcShapeSize << "\n";
+    llvm::errs() << "DEBUG: parentCPL rank = " << parentCPL.getRank() << "\n";
+    llvm::errs() << "DEBUG: srcShape = [";
+    for (int i = 0; i < srcShapeSize; i++) {
+      llvm::errs() << srcShape[i];
+      if (i < srcShapeSize - 1)
+        llvm::errs() << ", ";
+    }
+    llvm::errs() << "]\n";
+
+    if (pcplSgLayout.size() == static_cast<size_t>(srcShapeSize)) {
+      for (int i = 0; i < srcShapeSize; i++) {
+        if (srcShape[i] % pcplSgLayout[i] == 0) {
+          sgLayout[i] = pcplSgLayout[i];
+          sgData[i] = srcShape[i] / sgLayout[i];
+          instData[i] = std::min(instData[i], sgData[i]);
+        } else {
+          failToAlignSliceStruct = true;
+          break;
+        }
+      }
+    }
 
-/// Infers the source layout attribute for a shape cast operation given the
-/// result layout attribute, result shape, and source shape.
-xegpu::DistributeLayoutAttr xegpu::inferShapeCastSourceLayout(
-    MLIRContext *context, xegpu::DistributeLayoutAttr resLayout,
-    ArrayRef<int64_t> resShape, ArrayRef<int64_t> srcShape);
+    if (pcplLaneLayout.size() == static_cast<size_t>(srcShapeSize)) {
+      for (int i = 0; i < srcShapeSize; i++) {
+        if (instData[i] % pcplLaneLayout[i] == 0) {
+          laneLayout[i] = pcplLaneLayout[i];
+          laneData[i] = pcplLaneData[i];
+        } else {
+          failToAlignSliceStruct = true;
+          break;
+        }
+      }
+    }
+  } else {
+    failToAlignSliceStruct = true;
+  }
+
+  if (failToAlignSliceStruct) {
+
+    // try to align the sg layout
+    SmallVector<int64_t> cplSgLayout =
+        consumerPreferredLayout.getEffectiveSgLayoutAsInt();
+    llvm::errs() << "DEBUG: cplSgLayout size = " << cplSgLayout.size() << "\n";
+    // if sg layout doesn't cover all the sg ids, distribute rest to
+    // non-reduction dims
+    int remainingSgCount = workgroupSize;
+
+    SmallVector<int64_t> remainingDims;
+    // TODO(review): route these DEBUG prints through LLVM_DEBUG before landing.
+    llvm::errs() << "DEBUG: consumerPreferredLayout sgLayout = [";
+    auto cplSgLayoutFull = consumerPreferredLayout.getEffectiveSgLayoutAsInt();
+    for (size_t i = 0; i < cplSgLayoutFull.size(); i++) {
+      llvm::errs() << cplSgLayoutFull[i];
+      if (i < cplSgLayoutFull.size() - 1)
+        llvm::errs() << ", ";
+    }
+    // if cplSgLayout is not empty, try to align the sg layout first
+    int cplId = cplSgLayout.size() - 1;
+    llvm::errs() << "DEBUG: Starting first loop, cplId = " << cplId << "\n";
+    for (int i = srcShapeSize - 1; i >= 0; i--) {
+      llvm::errs() << "DEBUG: Loop 1, i = " << i << ", is_reduction_dim = "
+                   << llvm::is_contained(reductionDims, i) << "\n";
+      // try to align with cplSgLayout first for non-reduction dims
+      if (!llvm::is_contained(reductionDims, i) && cplId >= 0) {
+        if (srcShape[i] % cplSgLayout[cplId] == 0) {
+          sgLayout[i] = cplSgLayout[cplId];
+          sgData[i] = srcShape[i] / sgLayout[i];
+          instData[i] = std::min(instData[i], sgData[i]);
+          remainingSgCount /= sgLayout[i];
+          llvm::errs() << "DEBUG: Set sgLayout[" << i << "] = " << sgLayout[i]
+                       << ", sgData[" << i << "] = " << sgData[i]
+                       << ", remainingSgCount = " << remainingSgCount << "\n";
+          cplId--;
+          continue;
+        }
+      }
+      remainingDims.push_back(i);
+      llvm::errs() << "DEBUG: Added i = " << i << " to remainingDims\n";
+    }
+
+    llvm::errs() << "DEBUG: Starting second loop\n";
+    for (int i = srcShapeSize - 1; i >= 0; i--) {
+      llvm::errs() << "DEBUG: Loop 2, i = " << i << ", is_remaining_dim = "
+                   << llvm::is_contained(remainingDims, i) << "\n";
+      if (llvm::is_contained(remainingDims, i)) {
+
+        llvm::errs() << "DEBUG: Before Set sgLayout[" << i
+                     << "] = " << sgLayout[i] << ", sgData[" << i
+                     << "] = " << sgData[i]
+                     << ", remainingSgCount = " << remainingSgCount << "\n";
+        int64_t sgCount = std::min<int64_t>(srcShape[i] / laneLayout[i],
+                                            remainingSgCount);
+        sgLayout[i] = std::max<int64_t>(sgCount, 1); // avoid divide by 0
+        sgData[i] = srcShape[i] / sgLayout[i];
+        instData[i] = std::min(instData[i], sgData[i]);
+        remainingSgCount /= sgLayout[i];
+
+        llvm::errs() << "DEBUG: After Set sgLayout[" << i
+                     << "] = " << sgLayout[i] << ", sgData[" << i
+                     << "] = " << sgData[i]
+                     << ", remainingSgCount = " << remainingSgCount << "\n";
+
+        if (remainingSgCount == 1) {
+          llvm::errs() << "DEBUG: Breaking from loop 2, remainingSgCount = 1\n";
+          break;
+        }
+      }
+    }
+  }
+  // Convert int64_t vectors to int32_t for DenseI32ArrayAttr
+  SmallVector<int32_t> sgLayout32(sgLayout.begin(), sgLayout.end());
+  SmallVector<int32_t> sgData32(sgData.begin(), sgData.end());
+  SmallVector<int32_t> instData32(instData.begin(), instData.end());
+  SmallVector<int32_t> laneLayout32(laneLayout.begin(), laneLayout.end());
+  SmallVector<int32_t> laneData32(laneData.begin(), laneData.end());
+  proposedSrcLayout = xegpu::LayoutAttr::get(
+      context, DenseI32ArrayAttr::get(context, sgLayout32),
+      DenseI32ArrayAttr::get(context, sgData32),
+      DenseI32ArrayAttr::get(context, instData32),
+      DenseI32ArrayAttr::get(context, laneLayout32),
+      DenseI32ArrayAttr::get(context, laneData32),
+      consumerPreferredLayout.getOrder());
+
+  // finally, create the slice layout for reduction source
+  xegpu::SliceAttr reductionSrcLayout =
+      xegpu::SliceAttr::get(context, proposedSrcLayout,
+                            DenseI64ArrayAttr::get(context, reductionDims));
+
+  return reductionSrcLayout;
+}
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout.mlir b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
index b88d8e1a78a26..cccc446f2e4c4 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
@@ -483,7 +483,7 @@ func.func @if_multiple_uses(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.t
 gpu.module @test {
 // CHECK-LABEL: func.func @vector_outer_reduction(
 // CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: vector<16x16xf32>, %[[ARG1:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>) {
-// CHECK: %{{.*}} = vector.multi_reduction <add>, %[[ARG0]], %{{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} [0] : vector<16x16xf32> to vector<16xf32>
+// CHECK: %{{.*}} = vector.multi_reduction <add>, %[[ARG0]], %{{.*}} {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>} [0] : vector<16x16xf32> to vector<16xf32>
 func.func @vector_outer_reduction(%arg0: vector<16x16xf32>, %arg1: !xegpu.tensor_desc<16xf32>) {
   %cst = arith.constant dense<0.000000e+00> : vector<16xf32>
   %0 = vector.multi_reduction <add>, %arg0, %cst [0] : vector<16x16xf32> to vector<16xf32>
@@ -495,7 +495,7 @@ func.func @vector_outer_reduction(%arg0: vector<16x16xf32>, %arg1: !xegpu.tensor
 gpu.module @test {
 // CHECK-LABEL: func.func @vector_inner_reduction(
 // CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: vector<16x16xf32>, %[[ARG1:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>) {
-// CHECK: %{{.*}} = vector.multi_reduction <add>, %[[ARG0]], %{{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} [1] : vector<16x16xf32> to vector<16xf32>
+// CHECK: %{{.*}} = vector.multi_reduction <add>, %[[ARG0]], %{{.*}} {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [1]>} [1] : vector<16x16xf32> to vector<16xf32>
 func.func @vector_inner_reduction(%arg0: vector<16x16xf32>, %arg1: !xegpu.tensor_desc<16xf32>) {
   %cst = arith.constant dense<0.000000e+00> : vector<16xf32>
   %0 = vector.multi_reduction <add>, %arg0, %cst [1] : vector<16x16xf32> to vector<16xf32>

>From a347f16fdcf87e69eda55fcfd325fc0594541ae3 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Wed, 24 Dec 2025 06:26:56 +0000
Subject: [PATCH 03/35] add infer rule for bitcast and shapecast

---
 .../mlir/Dialect/XeGPU/IR/XeGPUAttrs.td       |  18 +-
 mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp    | 118 ++++++++
 .../Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp  | 286 ++++++++++--------
 3 files changed, 300 insertions(+), 122 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index eae2cbe24a72c..473e485c1faee 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -243,7 +243,13 @@ def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> {
                     (ins "int64_t": $dim,
                           "int64_t": $sgData,
                           "int64_t": $instData,
-                          "int64_t": $laneData)>,                                                 
+                          "int64_t": $laneData)>,              
+    InterfaceMethod<[{Derive a new layout by collapsing groups of dimensions. Each inner array in
+                      `dimGroups` specifies a group of dimensions that are collapsed into a single
+                      dimension in the derived layout.}],
+                    "xegpu::DistributeLayoutAttr",
+                    "collapseDims",
+                    (ins "ArrayRef<ArrayRef<int64_t>>": $dimGroups)>,                                                   
     InterfaceMethod<[{Generates instructions to compute multidimensional coordinates for dist units
                       assigned to a level identified by linearId. The shape parameter
                       represents the higher-level problem size. Each level may access
@@ -518,6 +524,11 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttr]> {
     // specified values for the given dimension
     DistributeLayoutAttr setDimData(int64_t dim, int64_t sgData, int64_t instData, int64_t laneData);
 
+    // Derive a new layout by collapsing groups of dimensions.
+    // Each inner array in `dimGroups` specifies a set of dimensions
+    // that are collapsed into a single dimension in the derived layout.
+    DistributeLayoutAttr collapseDims(ArrayRef<ArrayRef<int64_t>> dimGroups);
+
     /// Delinearizes a linear ID into its multidimensional indices
     /// based on the effective level of the layout.
     FailureOr<SmallVector<Value>>
@@ -693,6 +704,11 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> {
     // specified values for the given dimension
     DistributeLayoutAttr setDimData(int64_t dim, int64_t sgData, int64_t instData, int64_t laneData);
 
+    // Derive a new layout by collapsing groups of dimensions.
+    // Each inner array in `dimGroups` specifies a set of dimensions
+    // that are collapsed into a single dimension in the derived layout.
+    DistributeLayoutAttr collapseDims(ArrayRef<ArrayRef<int64_t>> dimGroups);
+
     /// flatten a nested SliceAttr, e.g., for 2-level nested SliceAttr
     /// #xegpu.slice<#xegpu.slice<#xegpu.layout<sg_layout = [4, 8, 12]>, dims = [0]>, dims = [0]>
     /// it will coalese two slice operations and return a simplified SliceAttr
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 0ecfe3eac650c..613651ec7f964 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -512,6 +512,76 @@ DistributeLayoutAttr LayoutAttr::setDimData(int64_t dim, int64_t sgData,
       getOrder());
 }
 
+// Derive a new layout by collapsing groups of dimensions.
+// Each inner array in `dimGroups` specifies a set of dimensions
+// that are collapsed into a single dimension in the derived layout.
+//
+// The per-dimension fields (sg_layout, sg_data, inst_data, lane_layout,
+// lane_data) are collapsed by multiplying their entries over each group.
+// A field that is not set on this layout stays unset on the derived one, so
+// partially specified layouts (e.g. lane-only) are never indexed past the end.
+//
+// `order` lists dimensions from fastest- to slowest-changing. A collapsed
+// dimension is ranked by the fastest-changing dimension of its group, and the
+// derived `order` lists the collapsed dimensions by that rank. The derived
+// layout carries an `order` only if this layout has one.
+DistributeLayoutAttr
+LayoutAttr::collapseDims(ArrayRef<ArrayRef<int64_t>> dimGroups) {
+  MLIRContext *ctx = getContext();
+
+  // Multiply `values` over each group. An empty `values` means the field is
+  // not set; return a null attribute so it stays unset in the result.
+  auto collapseField = [&](ArrayRef<int64_t> values) -> DenseI32ArrayAttr {
+    if (values.empty())
+      return nullptr;
+    SmallVector<int32_t> collapsed;
+    collapsed.reserve(dimGroups.size());
+    for (ArrayRef<int64_t> group : dimGroups) {
+      int64_t product = 1;
+      for (int64_t dimIdx : group)
+        product *= values[dimIdx];
+      collapsed.push_back(static_cast<int32_t>(product));
+    }
+    return DenseI32ArrayAttr::get(ctx, collapsed);
+  };
+
+  DenseI32ArrayAttr collapsedOrderAttr;
+  DenseI32ArrayAttr orderAttr = getOrder();
+  if (orderAttr && !orderAttr.empty()) {
+    ArrayRef<int32_t> order = orderAttr.asArrayRef();
+    // position[d] is the index of dimension d in `order` (0 = fastest).
+    SmallVector<int64_t> position(order.size(), 0);
+    for (auto [idx, dim] : llvm::enumerate(order))
+      position[dim] = static_cast<int64_t>(idx);
+
+    // Rank each collapsed dimension by the fastest-changing member of its
+    // group; order.size() is larger than any valid position.
+    SmallVector<int64_t> groupRank;
+    groupRank.reserve(dimGroups.size());
+    for (ArrayRef<int64_t> group : dimGroups) {
+      int64_t rank = static_cast<int64_t>(order.size());
+      for (int64_t dimIdx : group)
+        rank = std::min(rank, position[dimIdx]);
+      groupRank.push_back(rank);
+    }
+
+    // List the collapsed dimensions from fastest- to slowest-changing.
+    SmallVector<int32_t> collapsedOrder = llvm::to_vector(
+        llvm::seq<int32_t>(0, static_cast<int32_t>(dimGroups.size())));
+    llvm::stable_sort(collapsedOrder, [&](int32_t lhs, int32_t rhs) {
+      return groupRank[lhs] < groupRank[rhs];
+    });
+    collapsedOrderAttr = DenseI32ArrayAttr::get(ctx, collapsedOrder);
+  }
+
+  return xegpu::LayoutAttr::get(
+      ctx, collapseField(getEffectiveSgLayoutAsInt()),
+      collapseField(getEffectiveSgDataAsInt()),
+      collapseField(getEffectiveInstDataAsInt()),
+      collapseField(getEffectiveLaneLayoutAsInt()),
+      collapseField(getEffectiveLaneDataAsInt()), collapsedOrderAttr);
+}
+
 //===----------------------------------------------------------------------===//
 // XeGPU_SliceAttr
 //===----------------------------------------------------------------------===//
@@ -723,6 +793,28 @@ DistributeLayoutAttr SliceAttr::setDimData(int64_t dim, int64_t sgData,
       parent.setDimData(adjustDims[0], sgData, instData, laneData), getDims());
 }
 
+// Derive a new layout by collapsing groups of dimensions.
+// Each inner array in `dimGroups` specifies a set of dimensions
+// that are collapsed into a single dimension in the derived layout.
+DistributeLayoutAttr
+SliceAttr::collapseDims(ArrayRef<ArrayRef<int64_t>> dimGroups) {
+  // The slice dims index into the parent layout's dimensions.
+  ArrayRef<int64_t> sliceDims = getDims().asArrayRef();
+  // Map each group from the sliced space to the parent space.
+  SmallVector<SmallVector<int64_t>> adjustedDimGroups;
+  for (const auto &group : dimGroups) {
+    SetVector<int64_t> mappedDims = mapDimsFromSlicedSpace(group, sliceDims);
+    adjustedDimGroups.emplace_back(mappedDims.begin(), mappedDims.end());
+  }
+  // collapseDims takes ArrayRef<ArrayRef<int64_t>>; build views of the groups.
+  SmallVector<ArrayRef<int64_t>> groupRefs(adjustedDimGroups.begin(),
+                                           adjustedDimGroups.end());
+  auto collapsedParent = getParent().collapseDims(groupRefs);
+  // NOTE(review): sliceDims are reused as-is, not remapped to collapsed space.
+  return SliceAttr::get(getContext(), collapsedParent,
+                        DenseI64ArrayAttr::get(getContext(), sliceDims));
+}
+
 //===----------------------------------------------------------------------===//
 // XeGPU_RangeAttr
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp
index 58e222812661f..acb5f07893ee2 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp
@@ -291,6 +291,7 @@ bool xegpu::recoverTemporaryLayouts(Operation *rootOp) {
 //       across region boundaries
 //        while preserving a single well-defined use for each definition at the
 //        region-op level.
+bool xegpu::recoverTemporaryLayouts_first(Operation *rootOp) { return false; }
 
 template <typename T, typename>
 void xegpu::removeLayoutAttr(const T &operandOrResult) {
@@ -369,125 +370,168 @@ xegpu::inferReductionSourceLayout(xegpu::DistributeLayoutAttr resLayout,
   return sliceLayout.getParent();
 }
 
-// /// Infers the source layout attribute for a bitcast operation given the
-// /// result layout attribute, result element type bitwidth, and source element
-// /// type bitwidth.
-// xegpu::DistributeLayoutAttr
-// xegpu::inferBitCastSourceLayout(MLIRContext *context,
-//                                 xegpu::DistributeLayoutAttr resLayout,
-//                                 int resElemTyBitWidth, int
-//                                 srcElemTyBitWidth){
-//   // the result and source layout must be the same
-//   // if resLayout is SliceAttr, we need to first get its root layout
-//   xegpu::DistributeLayoutAttr resRootLayout = resLayout;
-//   while (auto sliceLayout = dyn_cast<xegpu::SliceAttr>(resRootLayout)) {
-//     resRootLayout = sliceLayout.getParent();
-//   }
-//   // change the laneData of resRootLayout according to the bitwidth ratio
-//   xegpu::LayoutAttr resRootPlainLayout =
-//   dyn_cast<xegpu::LayoutAttr>(resRootLayout); SmallVector<int64_t> laneData =
-//   resRootPlainLayout.getEffectiveLaneDataAsInt();
-
-//   if (srcElemTyBitWidth >= resElemTyBitWidth) {
-//     int bitWidthRatio = srcElemTyBitWidth / resElemTyBitWidth;
-//     laneData[laneData.size()-1] = laneData[laneData.size()-1] *
-//     bitWidthRatio;
-//   } else {
-//     int bitWidthRatio = resElemTyBitWidth / srcElemTyBitWidth;
-//     assert((laneData[laneData.size()-2] % bitWidthRatio) == 0 &&
-//            "laneData not divisible by bitWidthRatio");
-//     laneData[laneData.size()-1] = laneData[laneData.size()-1] /
-//     bitWidthRatio;
-//   }
-
-//   // now reconstruct the source layout with updated laneData
-//   // by updating the root layout and going throught the slice layers
-//   SmallVector<int32_t> laneData32(laneData.begin(), laneData.end());
-//   xegpu::LayoutAttr proposedSrcLayout = xegpu::LayoutAttr::get(
-//           context,
-//           resRootPlainLayout.getSgLayout(),
-//           resRootPlainLayout.getSgData(),
-//           resRootPlainLayout.getInstData(),
-//           resRootPlainLayout.getLaneLayout(),
-//           DenseI32ArrayAttr::get(context, laneData32),
-//           resRootPlainLayout.getOrder());
-
-//   // reconstruct slice layers if any
-//   // First collect all slice layers from innermost to outermost
-//   SmallVector<DenseI64ArrayAttr> sliceDims;
-//   xegpu::DistributeLayoutAttr currentLayout = resLayout;
-//   while (auto sliceLayout = dyn_cast<xegpu::SliceAttr>(currentLayout)) {
-//     sliceDims.push_back(sliceLayout.getDims());
-//     currentLayout = sliceLayout.getParent();
-//   }
-
-//   // Now rebuild from outermost to innermost (reverse order)
-//   xegpu::DistributeLayoutAttr finalSrcLayout = proposedSrcLayout;
-//   for (auto it = sliceDims.rbegin(); it != sliceDims.rend(); ++it) {
-//     finalSrcLayout = xegpu::SliceAttr::get(context, finalSrcLayout, *it);
-//   }
-//   return finalSrcLayout;
-// }
-
-// /// Infers the source layout attribute for a shape cast operation given the
-// /// result layout attribute, result shape, and source shape.
-// xegpu::DistributeLayoutAttr xegpu::inferShapeCastSourceLayout(
-//     MLIRContext *context, xegpu::DistributeLayoutAttr resLayout,
-//     ArrayRef<int64_t> resShape, ArrayRef<int64_t> srcShape){
-
-// // there are two use cases:
-// // 1. expand dims of low-rank dimensions (e.g., 1D to 2D): to set up the
-// tensor before broadcast
-// // 2. split dim of a high-rank dimension (e.g., 1D to 2D): to setup tensor
-// for multi-stage reduction
-
-//   SmallVector<int64_t> shapeCastDims;
-//   auto returnLayout = resLayout;
-
-//   int resRank = resShape.size();
-//   int srcRank = srcShape.size();
-
-//   if (srcRank < resRank) {
-//     // Case 1: expand dims of low-rank dimensions (e.g., 1D to 2D)
-//     int dimDiff = resRank - srcRank;
-//     // adding the missing leading dims
-//     for (int i = 0; i < dimDiff; i++)
-//       shapeCastDims.push_back(i);
-
-//     // create a slice layout for the source
-//     returnLayout = xegpu::SliceAttr::get(
-//         context, resLayout, DenseI64ArrayAttr::get(context, shapeCastDims));
-//   } else if (srcRank > resRank) {
-//     // Case 2: split dim of a high-rank dimension (e.g., 1D to 2D)
-//     // find the split dims by comparing srcShape and resShape
-//     int srcIdx = 0;
-//     int resIdx = 0;
-//     while (srcIdx < srcRank && resIdx < resRank) {
-//       if (srcShape[srcIdx] == resShape[resIdx]) {
-//         srcIdx++;
-//         resIdx++;
-//       } else if (srcShape[srcIdx] < resShape[resIdx]) {
-//         shapeCastDims.push_back(srcIdx);
-//         srcIdx++;
-//       } else {
-//         // this should not happen in valid shape cast
-//         assert(false && "Invalid shape cast: source shape dimension smaller
-//         than result shape dimension");
-//       }
-//     }
-//     // handle remaining src dims
-//     while (srcIdx < srcRank) {
-//       shapeCastDims.push_back(srcIdx);
-//       srcIdx++;
-//     }
-
-//     // create a slice layout for the source
-//     returnLayout = xegpu::SliceAttr::get(
-//         context, resLayout, DenseI64ArrayAttr::get(context, shapeCastDims));
-//   }
-//   return returnLayout;
-
-// }
+/// Infers the source layout attribute for a bitcast operation given the
+/// result layout attribute, result element type bitwidth, and source element
+/// type bitwidth.
+xegpu::DistributeLayoutAttr
+xegpu::inferBitCastSourceLayout(MLIRContext *context,
+                                xegpu::DistributeLayoutAttr resLayout,
+                                int resElemTyBitWidth, int srcElemTyBitWidth) {
+  // the result and source layout must be the same
+  // if resLayout is SliceAttr, we need to first get its root layout
+  xegpu::DistributeLayoutAttr resRootLayout = resLayout;
+  while (auto sliceLayout = dyn_cast<xegpu::SliceAttr>(resRootLayout)) {
+    resRootLayout = sliceLayout.getParent();
+  }
+  // change the laneData of resRootLayout according to the bitwidth ratio
+  xegpu::LayoutAttr resRootPlainLayout =
+      dyn_cast<xegpu::LayoutAttr>(resRootLayout);
+  SmallVector<int64_t> laneData =
+      resRootPlainLayout.getEffectiveLaneDataAsInt();
+
+  if (srcElemTyBitWidth >= resElemTyBitWidth) {
+    int bitWidthRatio = srcElemTyBitWidth / resElemTyBitWidth;
+    laneData[laneData.size() - 1] =
+        laneData[laneData.size() - 1] * bitWidthRatio;
+  } else {
+    int bitWidthRatio = resElemTyBitWidth / srcElemTyBitWidth;
+    assert((laneData[laneData.size() - 2] % bitWidthRatio) == 0 &&
+           "laneData not divisible by bitWidthRatio");
+    laneData[laneData.size() - 1] =
+        laneData[laneData.size() - 1] / bitWidthRatio;
+  }
+
+  // now reconstruct the source layout with updated laneData
+  // by updating the root layout and going throught the slice layers
+  SmallVector<int32_t> laneData32(laneData.begin(), laneData.end());
+  xegpu::LayoutAttr proposedSrcLayout = xegpu::LayoutAttr::get(
+      context, resRootPlainLayout.getSgLayout(), resRootPlainLayout.getSgData(),
+      resRootPlainLayout.getInstData(), resRootPlainLayout.getLaneLayout(),
+      DenseI32ArrayAttr::get(context, laneData32),
+      resRootPlainLayout.getOrder());
+
+  // reconstruct slice layers if any
+  // First collect all slice layers from innermost to outermost
+  SmallVector<DenseI64ArrayAttr> sliceDims;
+  xegpu::DistributeLayoutAttr currentLayout = resLayout;
+  while (auto sliceLayout = dyn_cast<xegpu::SliceAttr>(currentLayout)) {
+    sliceDims.push_back(sliceLayout.getDims());
+    currentLayout = sliceLayout.getParent();
+  }
+
+  // Now rebuild from outermost to innermost (reverse order)
+  xegpu::DistributeLayoutAttr finalSrcLayout = proposedSrcLayout;
+  for (auto it = sliceDims.rbegin(); it != sliceDims.rend(); ++it) {
+    finalSrcLayout = xegpu::SliceAttr::get(context, finalSrcLayout, *it);
+  }
+  return finalSrcLayout;
+}
+
+/// Infers the source layout attribute for a shape cast operation given the
+/// result layout attribute, result shape, and source shape.
+xegpu::DistributeLayoutAttr xegpu::inferShapeCastSourceLayout(
+    MLIRContext *context, xegpu::DistributeLayoutAttr resLayout,
+    ArrayRef<int64_t> resShape, ArrayRef<int64_t> srcShape) {
+
+  // there are two use cases:
+  // 1. expand dims of low-rank dimensions (e.g., 1D to 2D): to set up the
+  // tensor before broadcast
+  // 2. split dim of a high-rank dimension (e.g., 1D to 2D): to setup tensor
+  // for multi-stage reduction
+  // 3. combines all dims to a single dim and put in the innermost dim in 2d as
+  // [1, combinedData]. only used after workgroup distribution to save
+  // multidimension data to 1D slm buffer so no need to handle sg_layout and
+  // sg_data.
+
+  // Use case 1: Check if shapes only differ by expanding unit dimensions (like
+  // expand_dims)
+  SmallVector<int64_t> expandedUnitDims;
+  auto checkOnlyExpandUnitDims = [&](ArrayRef<int64_t> src,
+                                     ArrayRef<int64_t> dst) -> bool {
+    // All unit dimensions in dst that don't appear in src are the expanded
+    // unit dimensions
+    size_t srcIdx = 0;
+    for (size_t dstIdx = 0; dstIdx < dst.size(); ++dstIdx)
+      if (srcIdx < src.size() && src[srcIdx] == dst[dstIdx])
+        srcIdx++;
+      else if (dst[dstIdx] == 1)
+        expandedUnitDims.push_back(dstIdx);
+      else
+        return false;
+    return srcIdx == src.size();
+  };
+
+  if (checkOnlyExpandUnitDims(srcShape, resShape)) {
+    // create a slice layout for the source by removing the expanded unit dims
+    auto sliceDimsAttr =
+        DenseI64ArrayAttr::get(context, ArrayRef<int64_t>(expandedUnitDims));
+    auto srcLayout = xegpu::SliceAttr::get(context, resLayout, sliceDimsAttr);
+    return srcLayout;
+  }
+
+  // Maps each source dimension to the range of destination dimensions it splits
+  // into
+  SmallVector<SmallVector<size_t>> splitDimGroups;
+
+  auto checkSplitDims = [&](ArrayRef<int64_t> src,
+                            ArrayRef<int64_t> dst) -> bool {
+    // each dim in src can be mapped to one or more dims in dst whose product
+    // equals to the src dim
+    splitDimGroups.clear();
+    size_t srcIdx = 0;
+    int64_t accumulatedSize = 1;
+    SmallVector<size_t> currentDstDims;
+
+    for (size_t dstIdx = 0; dstIdx < dst.size(); ++dstIdx) {
+      if (srcIdx >= src.size())
+        return false;
+      accumulatedSize *= dst[dstIdx];
+      currentDstDims.push_back(dstIdx);
+
+      if (accumulatedSize == src[srcIdx]) {
+        // Record the mapping: srcIdx -> currentDstDims
+        splitDimGroups.push_back(currentDstDims);
+        // move to next src dim
+        srcIdx++;
+        accumulatedSize = 1;
+        currentDstDims.clear();
+      } else if (accumulatedSize > src[srcIdx]) {
+        return false;
+      }
+    }
+    return srcIdx == src.size();
+  };
+
+  if (checkSplitDims(srcShape, resShape)) {
+    return resLayout.collapseDims(splitDimGroups);
+  }
+
+  auto checkCombineToInnerMostDim = [&](ArrayRef<int64_t> src,
+                                        ArrayRef<int64_t> dst) -> bool {
+    // only one non-unit dim in dst which is the innermost dim
+    assert((dst.size() == 2) && "dst shape must be 2D");
+    int64_t srcSize = std::accumulate(src.begin(), src.end(), 1LL,
+                                      std::multiplies<int64_t>());
+    return (dst[0] == 1) && (dst[1] == srcSize);
+  };
+
+  if (checkCombineToInnerMostDim(srcShape, resShape)) {
+    const int subgroupSize = 16; // assuming 16 lanes per subgroup
+    const int vectorSize = 8;    // assuming 8 elements per vector lane
+    int srcShapeSize = srcShape.size();
+
+    SmallVector<int64_t> instData(srcShapeSize, 1);
+    instData[srcShapeSize - 1] = subgroupSize;
+    instData[srcShapeSize - 2] =
+        vectorSize; // assuming 8 elements per instruction as starting point
+
+    // construct a vector layout with lane_layout = [1, ..., 1, subgroupSize]
+    SmallVector<int64_t> laneLayout(srcShapeSize, 1);
+    laneLayout[srcShapeSize - 1] = subgroupSize;
+    // construct a vector layout with lane_data = [1, ..., 1]
+    SmallVector<int64_t> laneData(srcShapeSize, 1);
+  }
+}
 
 xegpu::SliceAttr
 xegpu::reductionLayoutSetupRule(ArrayRef<int64_t> srcShape,
@@ -530,6 +574,8 @@ xegpu::reductionLayoutSetupRule(ArrayRef<int64_t> srcShape,
   // construct a vector layout with lane_layout = [1, ..., 1, subgroupSize]
   SmallVector<int64_t> laneLayout(srcShapeSize, 1);
   laneLayout[srcShapeSize - 1] = subgroupSize;
+  // construct a vector layout with lane_data = [1, ..., 1]
+  SmallVector<int64_t> laneData(srcShapeSize, 1);
   llvm::errs() << "DEBUG: laneLayout = [";
   for (size_t i = 0; i < laneLayout.size(); i++) {
     llvm::errs() << laneLayout[i];
@@ -537,8 +583,6 @@ xegpu::reductionLayoutSetupRule(ArrayRef<int64_t> srcShape,
       llvm::errs() << ", ";
   }
   llvm::errs() << "]\n";
-  // construct a vector layout with lane_data = [1, ..., 1]
-  SmallVector<int64_t> laneData(srcShapeSize, 1);
 
   bool failToAlignSliceStruct = false;
   if (sliceCPL && sliceCPL.getDims().asArrayRef().equals(reductionDims)) {

>From 7e7ae5d0f59bab3382776463d07f870a45b1c21a Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Wed, 24 Dec 2025 19:21:41 +0000
Subject: [PATCH 04/35] add bitcast set rule

---
 .../mlir/Dialect/XeGPU/IR/XeGPUAttrs.td       |   6 +-
 .../Dialect/XeGPU/Utils/XeGPULayoutUtils.h    |  24 ++-
 mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp    |  47 +++--
 .../XeGPU/Transforms/XeGPUPropagateLayout.cpp |   5 +-
 .../Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp  | 168 ++++++++++++------
 5 files changed, 159 insertions(+), 91 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index 473e485c1faee..0e13d063498b1 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -249,7 +249,7 @@ def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> {
                       dimension in the derived layout.}],
                     "xegpu::DistributeLayoutAttr",
                     "collapseDims",
-                    (ins "ArrayRef<ArrayRef<int64_t>>": $dimGroups)>,                                                   
+                    (ins "SmallVector<SmallVector<int64_t>>": $dimGroups)>,                                                   
     InterfaceMethod<[{Generates instructions to compute multidimensional coordinates for dist units
                       assigned to a level identified by linearId. The shape parameter
                       represents the higher-level problem size. Each level may access
@@ -527,7 +527,7 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttr]> {
     // Derive a new layout by collapsing groups of dimensions.
     // Each inner array in `dimGroups` specifies a set of dimensions
     // that are collapsed into a single dimension in the derived layout.
-    DistributeLayoutAttr collapseDims(ArrayRef<ArrayRef<int64_t>> dimGroups);
+    DistributeLayoutAttr collapseDims(SmallVector<SmallVector<int64_t>> dimGroups) const;
 
     /// Delinearizes a linear ID into its multidimensional indices
     /// based on the effective level of the layout.
@@ -707,7 +707,7 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> {
     // Derive a new layout by collapsing groups of dimensions.
     // Each inner array in `dimGroups` specifies a set of dimensions
     // that are collapsed into a single dimension in the derived layout.
-    DistributeLayoutAttr collapseDims(ArrayRef<ArrayRef<int64_t>> dimGroups);
+    DistributeLayoutAttr collapseDims(SmallVector<SmallVector<int64_t>> dimGroups) const;
 
     /// flatten a nested SliceAttr, e.g., for 2-level nested SliceAttr
     /// #xegpu.slice<#xegpu.slice<#xegpu.layout<sg_layout = [4, 8, 12]>, dims = [0]>, dims = [0]>
diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h
index 01cb43b73d5ca..fca57eb85517b 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h
@@ -122,14 +122,32 @@ DistributeLayoutAttr inferShapeCastSourceLayout(MLIRContext *context,
                                                 ArrayRef<int64_t> resShape,
                                                 ArrayRef<int64_t> srcShape);
 
-/// Sets the the layout attribute for result based on a preferred Layout
-/// propagated from consumer
-/// the ouput must be a slice attribute
+/// Sets up layout for reduction operations by creating a SliceAttr for the
+/// result.
+///
+/// This function first attempts to construct a source layout that, when sliced
+/// along reduction dimensions, produces a result layout compatible with the
+/// consumer's preferred layout. This minimizes data redistribution overhead.
+/// The SliceAttr for the result is then created based on the derived source
+/// layout and the specified reduction dimensions.
 SliceAttr
 reductionLayoutSetupRule(ArrayRef<int64_t> srcShape,
                          SmallVector<int64_t> reductionDims,
                          DistributeLayoutAttr consumerPreferredLayout);
 
+/// Setup the result layout attribute for a bitcast operation based on element
+/// type bitwidths. This ensures the source layout can always be derived from
+/// the result layout.
+///
+/// When casting from a narrower to a wider element type (srcElemTyBitWidth <
+/// resElemTyBitWidth), the result layout's innermost dimension data sizes
+/// (sg_data, inst_data, lane_data) are scaled up by the bitwidth ratio. This
+/// maintains the invariant that the source layout can be recovered by inverse
+/// scaling during layout inference.
+DistributeLayoutAttr bitCastLayoutSetupRule(DistributeLayoutAttr resLayout,
+                                            int resElemTyBitWidth,
+                                            int srcElemTyBitWidth);
+
 } // namespace xegpu
 
 } // namespace mlir
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 613651ec7f964..1b50e6dd33802 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -476,39 +476,31 @@ LayoutAttr::setUnitDimLayout(SetVector<int64_t> unitDims) const {
 DistributeLayoutAttr LayoutAttr::setDimData(int64_t dim, int64_t sgData,
                                             int64_t instData,
                                             int64_t laneData) {
-  auto sgDataOpt = getSgData();
-  auto instDataOpt = getInstData();
-  auto laneDataOpt = getLaneData();
 
-  SmallVector<int32_t> sgDataVec;
-  SmallVector<int32_t> instDataVec;
-  SmallVector<int32_t> laneDataVec;
+  SmallVector<int64_t> sgDataVec = getEffectiveSgDataAsInt();
+  SmallVector<int64_t> instDataVec = getEffectiveInstDataAsInt();
+  SmallVector<int64_t> laneDataVec = getEffectiveLaneDataAsInt();
 
-  if (sgDataOpt)
-    sgDataVec = llvm::to_vector(sgDataOpt.asArrayRef());
-
-  if (instDataOpt)
-    instDataVec = llvm::to_vector(instDataOpt.asArrayRef());
-
-  if (laneDataOpt)
-    laneDataVec = llvm::to_vector(laneDataOpt.asArrayRef());
-
-  if (dim < static_cast<int64_t>(sgDataVec.size()))
+  if (dim < static_cast<int64_t>(sgDataVec.size()) && sgData != -1)
     sgDataVec[dim] = sgData;
-  if (dim < static_cast<int64_t>(instDataVec.size()))
+  if (dim < static_cast<int64_t>(instDataVec.size()) && instData != -1)
     instDataVec[dim] = instData;
-  if (dim < static_cast<int64_t>(laneDataVec.size()))
+  if (dim < static_cast<int64_t>(laneDataVec.size()) && laneData != -1)
     laneDataVec[dim] = laneData;
 
+  SmallVector<int32_t> sgDataVec32(sgDataVec.begin(), sgDataVec.end());
+  SmallVector<int32_t> instDataVec32(instDataVec.begin(), instDataVec.end());
+  SmallVector<int32_t> laneDataVec32(laneDataVec.begin(), laneDataVec.end());
+
   return LayoutAttr::get(
       getContext(), getSgLayout(),
       sgDataVec.empty() ? DenseI32ArrayAttr()
-                        : DenseI32ArrayAttr::get(getContext(), sgDataVec),
+                        : DenseI32ArrayAttr::get(getContext(), sgDataVec32),
       instDataVec.empty() ? DenseI32ArrayAttr()
-                          : DenseI32ArrayAttr::get(getContext(), instDataVec),
+                          : DenseI32ArrayAttr::get(getContext(), instDataVec32),
       getLaneLayout(),
       laneDataVec.empty() ? DenseI32ArrayAttr()
-                          : DenseI32ArrayAttr::get(getContext(), laneDataVec),
+                          : DenseI32ArrayAttr::get(getContext(), laneDataVec32),
       getOrder());
 }
 
@@ -516,9 +508,8 @@ DistributeLayoutAttr LayoutAttr::setDimData(int64_t dim, int64_t sgData,
 // Each inner array in `dimGroups` specifies a set of dimensions
 // that are collapsed into a single dimension in the derived layout.
 DistributeLayoutAttr
-LayoutAttr::collapseDims(ArrayRef<ArrayRef<int64_t>> dimGroups) {
+LayoutAttr::collapseDims(SmallVector<SmallVector<int64_t>> dimGroups) const {
 
-  // Extract layout attributes as vectors
   SmallVector<int64_t> sgLayout = getEffectiveSgLayoutAsInt();
   SmallVector<int64_t> sgData = getEffectiveSgDataAsInt();
   SmallVector<int64_t> instData = getEffectiveInstDataAsInt();
@@ -823,16 +814,18 @@ DistributeLayoutAttr SliceAttr::setDimData(int64_t dim, int64_t sgData,
 // Each inner array in `dimGroups` specifies a set of dimensions
 // that are collapsed into a single dimension in the derived layout.
 DistributeLayoutAttr
-SliceAttr::collapseDims(ArrayRef<ArrayRef<int64_t>> dimGroups) const {
+SliceAttr::collapseDims(SmallVector<SmallVector<int64_t>> dimGroups) const {
 
   // Map the sliced dims from parent space to collapsed space
-  ArrayRef<int64_t> sliceDims = getDims().asArrayRef();
+  SmallVector<int64_t> sliceDims = llvm::to_vector(getDims().asArrayRef());
 
   // go through dimGroups and map each dim from sliced space to parent space
   SmallVector<SmallVector<int64_t>> adjustedDimGroups;
   for (const auto &group : dimGroups) {
-    SetVector<int64_t> mappedDims = mapDimsFromSlicedSpace(group, sliceDims);
-    adjustedDimGroups.push_back(mappedDims.getArrayRef());
+    SetVector<int64_t> groupSet(group.begin(), group.end());
+    SetVector<int64_t> mappedDims = mapDimsFromSlicedSpace(groupSet, sliceDims);
+    adjustedDimGroups.push_back(
+        SmallVector<int64_t>(mappedDims.begin(), mappedDims.end()));
   }
 
   auto collapsedParent = getParent().collapseDims(adjustedDimGroups);
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 1be44480de01d..7c8d0329140da 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -603,7 +603,6 @@ void LayoutInfoPropagation::visitVectorMultiReductionOp(
   if (!resultLayout.isAssigned())
     return;
 
-  VectorType resultTy = llvm::dyn_cast<VectorType>(reduction.getDestType());
   VectorType sourceTy =
       llvm::dyn_cast<VectorType>(reduction.getSourceVectorType());
   SmallVector<int64_t> reductionDims(reduction.getReductionDims().begin(),
@@ -691,10 +690,10 @@ void LayoutInfoPropagation::visitVectorBroadCastOp(
   auto resultLayoutAttr =
       dyn_cast<xegpu::DistributeLayoutAttr>(resultLayout.get());
 
-  xegpu::DistributeLayoutAttr resLayout = xegpu::inferBroadCastSourceLayout(
+  xegpu::DistributeLayoutAttr srcLayout = xegpu::inferBroadCastSourceLayout(
       broadcast.getContext(), resultLayoutAttr, resShape, srcShape);
 
-  propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(resLayout)));
+  propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayout)));
   return;
 }
 
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp
index acb5f07893ee2..4a2b3557f1679 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp
@@ -291,7 +291,7 @@ bool xegpu::recoverTemporaryLayouts(Operation *rootOp) {
 //       across region boundaries
 //        while preserving a single well-defined use for each definition at the
 //        region-op level.
-bool xegpu::recoverTemporaryLayouts_first(Operation *rootOp) {}
+// bool xegpu::recoverTemporaryLayouts_first(Operation *rootOp) {}
 
 template <typename T, typename>
 void xegpu::removeLayoutAttr(const T &operandOrResult) {
@@ -378,52 +378,38 @@ xegpu::inferBitCastSourceLayout(MLIRContext *context,
                                 xegpu::DistributeLayoutAttr resLayout,
                                 int resElemTyBitWidth, int srcElemTyBitWidth) {
   // the result and source layout must be the same
-  // if resLayout is SliceAttr, we need to first get its root layout
-  xegpu::DistributeLayoutAttr resRootLayout = resLayout;
-  while (auto sliceLayout = dyn_cast<xegpu::SliceAttr>(resRootLayout)) {
-    resRootLayout = sliceLayout.getParent();
-  }
-  // change the laneData of resRootLayout according to the bitwidth ratio
-  xegpu::LayoutAttr resRootPlainLayout =
-      dyn_cast<xegpu::LayoutAttr>(resRootLayout);
-  SmallVector<int64_t> laneData =
-      resRootPlainLayout.getEffectiveLaneDataAsInt();
+  // only adjust the sg_data, inst_data, lane_data accordingly
+  // based on the bitwidth ratio between source and result element type
+
+  SmallVector<int64_t> sgData = resLayout.getEffectiveSgDataAsInt();
+  SmallVector<int64_t> instData = resLayout.getEffectiveInstDataAsInt();
+  SmallVector<int64_t> laneData = resLayout.getEffectiveLaneDataAsInt();
+  size_t dim = sgData.size() - 1;
+  int64_t sgDataValue, instDataValue, laneDataValue;
 
   if (srcElemTyBitWidth >= resElemTyBitWidth) {
     int bitWidthRatio = srcElemTyBitWidth / resElemTyBitWidth;
-    laneData[laneData.size() - 1] =
-        laneData[laneData.size() - 1] * bitWidthRatio;
+    sgDataValue = (dim < sgData.size()) ? sgData[dim] * bitWidthRatio : -1;
+    instDataValue =
+        (dim < instData.size()) ? instData[dim] * bitWidthRatio : -1;
+    laneDataValue =
+        (dim < laneData.size()) ? laneData[dim] * bitWidthRatio : -1;
   } else {
     int bitWidthRatio = resElemTyBitWidth / srcElemTyBitWidth;
-    assert((laneData[laneData.size() - 2] % bitWidthRatio) == 0 &&
+    assert((laneData[dim] % bitWidthRatio) == 0 &&
            "laneData not divisible by bitWidthRatio");
-    laneData[laneData.size() - 1] =
-        laneData[laneData.size() - 1] / bitWidthRatio;
+    sgDataValue = (dim < sgData.size()) ? sgData[dim] / bitWidthRatio : -1;
+    instDataValue =
+        (dim < instData.size()) ? instData[dim] / bitWidthRatio : -1;
+    laneDataValue =
+        (dim < laneData.size()) ? laneData[dim] / bitWidthRatio : -1;
   }
 
-  // now reconstruct the source layout with updated laneData
-  // by updating the root layout and going throught the slice layers
-  SmallVector<int32_t> laneData32(laneData.begin(), laneData.end());
-  xegpu::LayoutAttr proposedSrcLayout = xegpu::LayoutAttr::get(
-      context, resRootPlainLayout.getSgLayout(), resRootPlainLayout.getSgData(),
-      resRootPlainLayout.getInstData(), resRootPlainLayout.getLaneLayout(),
-      DenseI32ArrayAttr::get(context, laneData32),
-      resRootPlainLayout.getOrder());
-
-  // reconstruct slice layers if any
-  // First collect all slice layers from innermost to outermost
-  SmallVector<DenseI64ArrayAttr> sliceDims;
-  xegpu::DistributeLayoutAttr currentLayout = resLayout;
-  while (auto sliceLayout = dyn_cast<xegpu::SliceAttr>(currentLayout)) {
-    sliceDims.push_back(sliceLayout.getDims());
-    currentLayout = sliceLayout.getParent();
-  }
+  // Now set only instData and laneData, preserving sgData
+  xegpu::DistributeLayoutAttr finalSrcLayout;
+  finalSrcLayout =
+      resLayout.setDimData(dim, sgDataValue, instDataValue, laneDataValue);
 
-  // Now rebuild from outermost to innermost (reverse order)
-  xegpu::DistributeLayoutAttr finalSrcLayout = proposedSrcLayout;
-  for (auto it = sliceDims.rbegin(); it != sliceDims.rend(); ++it) {
-    finalSrcLayout = xegpu::SliceAttr::get(context, finalSrcLayout, *it);
-  }
   return finalSrcLayout;
 }
 
@@ -471,7 +457,7 @@ xegpu::DistributeLayoutAttr xegpu::inferShapeCastSourceLayout(
 
   // Maps each source dimension to the range of destination dimensions it splits
   // into
-  SmallVector<SmallVector<size_t>> splitDimGroups;
+  SmallVector<SmallVector<int64_t>> splitDimGroups;
 
   auto checkSplitDims = [&](ArrayRef<int64_t> src,
                             ArrayRef<int64_t> dst) -> bool {
@@ -480,7 +466,7 @@ xegpu::DistributeLayoutAttr xegpu::inferShapeCastSourceLayout(
     splitDimGroups.clear();
     size_t srcIdx = 0;
     int64_t accumulatedSize = 1;
-    SmallVector<size_t> currentDstDims;
+    SmallVector<int64_t> currentDstDims;
 
     for (size_t dstIdx = 0; dstIdx < dst.size(); ++dstIdx) {
       if (srcIdx >= src.size())
@@ -531,8 +517,30 @@ xegpu::DistributeLayoutAttr xegpu::inferShapeCastSourceLayout(
     // construct a vector layout with lane_data = [1, ..., 1]
     SmallVector<int64_t> laneData(srcShapeSize, 1);
   }
+
+  // TODO: Complete implementation for other shape cast scenarios
+  return nullptr;
 }
 
+/// Sets up layout for reduction operations by creating a SliceAttr for the
+/// result.
+///
+/// Algorithm Overview:
+/// This function attempts to construct a source layout that, when sliced along
+/// reduction dimensions, produces a result layout compatible with the
+/// consumer's preferred layout. This minimizes data redistribution overhead.
+/// The SliceAttr for the result is created based on the derived source layout
+/// and the specified reduction dimensions.
+///
+/// Strategy:
+/// 1. First, check if the consumer's preferred layout is already a SliceAttr
+///    with matching reduction dimensions. If so, use its parent layout directly
+///    and adjust the sg_data/inst_data acccording to source shape.
+/// 2. If step 1 fails, construct a new layout by distributing
+/// workgroup/subgroup
+///    resources across dimensions, prioritizing alignment with the consumer's
+///    sg_layout for non-reduction dimensions.
+
 xegpu::SliceAttr
 xegpu::reductionLayoutSetupRule(ArrayRef<int64_t> srcShape,
                                 SmallVector<int64_t> reductionDims,
@@ -541,29 +549,40 @@ xegpu::reductionLayoutSetupRule(ArrayRef<int64_t> srcShape,
   xegpu::SliceAttr sliceCPL =
       dyn_cast<xegpu::SliceAttr>(consumerPreferredLayout);
 
-  // try to align wiht customer's preferred layout so that the slice layout
-  // structure is preserved, and thus avoid potential data movement acorss sg or
+  // Strategy 1: Try to preserve the consumer's slice layout structure
+  // If the consumer already expects a slice layout with the same reduction
+  // dims, we can directly use its parent layout as our source layout. This
+  // ensures perfect alignment and avoids any data movement across subgroups or
   // lanes.
 
-  const int workgroupSize = 16; // assuming 16 subgroups for now
-  const int subgroupSize = 16;  // assuming 16 lanes per subgroup
-  const int vectorSize = 8;     // assuming 8 elements per vector lane
+  // Hardware constraints (these should ideally be queried from device
+  // capabilities)
+  const int workgroupSize = 16; // Total number of subgroups in a workgroup
+  const int subgroupSize = 16;  // Number of SIMD lanes per subgroup
+  const int vectorSize = 8;     // Elements processed per vector instruction
   int srcShapeSize = srcShape.size();
   xegpu::DistributeLayoutAttr proposedSrcLayout;
   auto context = consumerPreferredLayout.getContext();
-  // if srcShapeSize is less than 2, we cannot proceed
+  // Reduction layout requires at least 2D tensors
   if (srcShapeSize < 2)
     return nullptr;
 
   llvm::errs() << "DEBUG: Entering \n";
 
+  // Initialize layout components:
+  // - sgLayout[i]: Number of subgroups covering dimension i
+  // - sgData[i]: Data elements per subgroup in dimension i (srcShape[i] /
+  // sgLayout[i])
   SmallVector<int64_t> sgLayout(srcShapeSize);
   SmallVector<int64_t> sgData(srcShapeSize);
 
+  // Initialize instruction-level parallelism with SIMD-friendly defaults:
+  // - Last dimension gets subgroupSize (16) to match lane width
+  // - Second-to-last dimension gets vectorSize (8) as starting point
   SmallVector<int64_t> instData(srcShapeSize, 1);
   instData[srcShapeSize - 1] = subgroupSize;
   instData[srcShapeSize - 2] =
-      vectorSize; // assuming 8 elements per instruction as starting point
+      vectorSize; // This will be adjusted based on actual data distribution
   llvm::errs() << "DEBUG: Initial instData = [";
   for (size_t i = 0; i < instData.size(); i++) {
     llvm::errs() << instData[i];
@@ -571,7 +590,10 @@ xegpu::reductionLayoutSetupRule(ArrayRef<int64_t> srcShape,
       llvm::errs() << ", ";
   }
   llvm::errs() << "]\n";
-  // construct a vector layout with lane_layout = [1, ..., 1, subgroupSize]
+  // Initialize lane-level distribution:
+  // - laneLayout[i]: How lanes are distributed across dimension i
+  //   (innermost dimension gets all subgroupSize lanes)
+  // - laneData[i]: Data elements per lane in dimension i (starts at 1 per lane)
   SmallVector<int64_t> laneLayout(srcShapeSize, 1);
   laneLayout[srcShapeSize - 1] = subgroupSize;
   // construct a vector layout with lane_data = [1, ..., 1]
@@ -582,13 +604,16 @@ xegpu::reductionLayoutSetupRule(ArrayRef<int64_t> srcShape,
     if (i < laneLayout.size() - 1)
       llvm::errs() << ", ";
   }
-  llvm::errs() << "]\n";
-
+  // Attempt Strategy 1: Align with consumer's slice structure
   bool failToAlignSliceStruct = false;
   if (sliceCPL && sliceCPL.getDims().asArrayRef().equals(reductionDims)) {
-
+    // The consumer is already expecting a slice along our reduction dimensions!
+    // Extract the parent layout (the layout before slicing) as our candidate.
     xegpu::DistributeLayoutAttr parentCPL = sliceCPL.getParent();
 
+    // Verify that the parent layout can be adapted to our source shape:
+    // For each dimension, check if srcShape[i] is divisible by the parent's
+    // sg_layout[i]. If so, we can reuse the subgroup distribution pattern
     // for each slice dim in source shape, if the dim size is differnt than the
     // result shape, try to adjust the sg_data/inst_data accordingly.
     SmallVector<int64_t> pcplSgLayout = parentCPL.getEffectiveSgLayoutAsInt();
@@ -664,7 +689,9 @@ xegpu::reductionLayoutSetupRule(ArrayRef<int64_t> srcShape,
     for (int i = srcShapeSize - 1; i >= 0; i--) {
       llvm::errs() << "DEBUG: Loop 1, i = " << i << ", is_reduction_dim = "
                    << llvm::is_contained(reductionDims, i) << "\n";
-      // try to align with cplSgLayout first for non-reduction dims
+
+      // For non-reduction dimensions, try to match consumer's sg_layout
+      // This ensures the result after reduction has the expected distribution
       if (!llvm::is_contained(reductionDims, i) && cplId >= 0) {
         if (srcShape[i] % cplSgLayout[cplId] == 0) {
           sgLayout[i] = cplSgLayout[cplId];
@@ -678,10 +705,14 @@ xegpu::reductionLayoutSetupRule(ArrayRef<int64_t> srcShape,
           continue;
         }
       }
+      // Dimension couldn't be aligned; defer to second pass
       remainingDims.push_back(i);
       llvm::errs() << "DEBUG: Added i = " << i << " to remainingDims\n";
     }
 
+    // Second pass: Distribute remaining subgroups across unhandled dimensions
+    // This handles reduction dimensions and dimensions that didn't align with
+    // consumer
     llvm::errs() << "DEBUG: Starting second loop\n";
     for (int i = srcShapeSize - 1; i >= 0; i--) {
       llvm::errs() << "DEBUG: Loop 2, i = " << i << ", is_remaining_dim = "
@@ -725,10 +756,37 @@ xegpu::reductionLayoutSetupRule(ArrayRef<int64_t> srcShape,
       DenseI32ArrayAttr::get(context, laneData32),
       consumerPreferredLayout.getOrder());
 
-  // finally, create the slice layout for reduction source
-  xegpu::SliceAttr reductionSrcLayout =
+  // finally, create the slice layout for reduction result
+  xegpu::SliceAttr reductionResLayout =
       xegpu::SliceAttr::get(context, proposedSrcLayout,
                             DenseI64ArrayAttr::get(context, reductionDims));
 
-  return reductionSrcLayout;
+  return reductionResLayout;
 }
+
+/// Derives the source layout of a bitcast from its result layout.
+///
+/// When the source element type is narrower than the result element type,
+/// every result element spans `resElemTyBitWidth / srcElemTyBitWidth` source
+/// elements along the innermost dimension, so sg_data, inst_data and
+/// lane_data of that dimension are scaled up by that ratio. A data field the
+/// result layout does not define is passed to `setDimData` as -1, which keeps
+/// it unchanged. In every other case the result layout is returned as is.
+xegpu::DistributeLayoutAttr
+xegpu::bitCastLayoutSetupRule(xegpu::DistributeLayoutAttr resLayout,
+                              int resElemTyBitWidth, int srcElemTyBitWidth) {
+  // Only a narrow-to-wide bitcast changes the per-dimension data sizes here.
+  // Previously the data values were left uninitialized in the other cases.
+  // TODO: handle srcElemTyBitWidth > resElemTyBitWidth (data must shrink).
+  if (srcElemTyBitWidth >= resElemTyBitWidth)
+    return resLayout;
+
+  int64_t rank = resLayout.getRank();
+  if (rank == 0)
+    return resLayout;
+
+  assert(resElemTyBitWidth % srcElemTyBitWidth == 0 &&
+         "result bit width must be a multiple of the source bit width");
+  int64_t bitWidthRatio = resElemTyBitWidth / srcElemTyBitWidth;
+
+  // A bitcast only reinterprets the innermost dimension. Derive it from the
+  // rank rather than from sg_data, which may be empty.
+  int64_t dim = rank - 1;
+
+  // Scales the innermost entry of `data`, or yields -1 ("keep as is") when
+  // the layout does not specify that field.
+  auto scaleInnermost = [&](ArrayRef<int64_t> data) -> int64_t {
+    return dim < static_cast<int64_t>(data.size()) ? data[dim] * bitWidthRatio
+                                                   : -1;
+  };
+
+  SmallVector<int64_t> sgData = resLayout.getEffectiveSgDataAsInt();
+  SmallVector<int64_t> instData = resLayout.getEffectiveInstDataAsInt();
+  SmallVector<int64_t> laneData = resLayout.getEffectiveLaneDataAsInt();
+
+  // sg_data, inst_data and lane_data are all scaled by the same ratio.
+  return resLayout.setDimData(dim, scaleInnermost(sgData),
+                              scaleInnermost(instData),
+                              scaleInnermost(laneData));
+}
\ No newline at end of file

>From 1e0542258a34ba8d073a4e5b103221cfcf1e1733 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Wed, 24 Dec 2025 22:16:43 +0000
Subject: [PATCH 05/35] add recover temporary layout implementation, not
 compiled yet

---
 .../mlir/Dialect/XeGPU/IR/XeGPUAttrs.td       |   9 +-
 .../Dialect/XeGPU/Utils/XeGPULayoutUtils.h    |   4 +-
 .../Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp  | 234 +++++++++++++++---
 3 files changed, 209 insertions(+), 38 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index 0e13d063498b1..14dbcc8d94630 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -237,7 +237,8 @@ def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> {
                     "delinearizeId",
                     (ins "OpBuilder &": $builder, "Location":$loc, "Value":$linearId)>,
     InterfaceMethod<[{Derive a new layout with sg_data, inst_data and lane_data set to the 
-                      specified values for the given dimension}],
+                      specified values for the given dimension. Passing -1 for any parameter 
+                      preserves its original value.}],
                     "xegpu::DistributeLayoutAttr",
                     "setDimData",
                     (ins "int64_t": $dim,
@@ -521,7 +522,8 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttr]> {
     DistributeLayoutAttr setUnitDimLayout(SetVector<int64_t> unitDims) const;
 
     // Derive a new layout with sg_data, inst_data and lane_data set to the 
-    // specified values for the given dimension
+    // specified values for the given dimension. Passing -1 for any parameter 
+    // preserves its original value.
     DistributeLayoutAttr setDimData(int64_t dim, int64_t sgData, int64_t instData, int64_t laneData);
 
     // Derive a new layout by collapsing groups of dimensions.
@@ -701,7 +703,8 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> {
     DistributeLayoutAttr setUnitDimLayout(SetVector<int64_t> unitDims) const;
 
     // Derive a new layout with sg_data, inst_data and lane_data set to the 
-    // specified values for the given dimension
+    // specified values for the given dimension. Passing -1 for any parameter 
+    // preserves its original value.
     DistributeLayoutAttr setDimData(int64_t dim, int64_t sgData, int64_t instData, int64_t laneData);
 
     // Derive a new layout by collapsing groups of dimensions.
diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h
index fca57eb85517b..bff005e3701df 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h
@@ -96,7 +96,7 @@ void removeLayoutAttrs(Operation *op);
 
 /// Infers the source layout attribute for a broadcast operation given the
 /// result layout attribute, result shape, source shape, and broadcasted dims.
-DistributeLayoutAttr inferBroadCastSourceLayout(MLIRContext *context,
+DistributeLayoutAttr inferBroadcastSourceLayout(MLIRContext *context,
                                                 DistributeLayoutAttr resLayout,
                                                 ArrayRef<int64_t> resShape,
                                                 ArrayRef<int64_t> srcShape);
@@ -117,7 +117,7 @@ DistributeLayoutAttr inferBitCastSourceLayout(MLIRContext *context,
 
 /// Infers the source layout attribute for a shape cast operation given the
 /// result layout attribute, result shape, and source shape.
-DistributeLayoutAttr inferShapeCastSourceLayout(MLIRContext *context,
+DistributeLayoutAttr inferShapecastSourceLayout(MLIRContext *context,
                                                 DistributeLayoutAttr resLayout,
                                                 ArrayRef<int64_t> resShape,
                                                 ArrayRef<int64_t> srcShape);
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp
index 4a2b3557f1679..ccda369accdd6 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp
@@ -10,6 +10,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/LLVMIR/XeVMDialect.h"
 #include "mlir/Dialect/SCF/Transforms/Patterns.h"
@@ -244,24 +245,25 @@ void xegpu::recoverTemporaryLayoutsDeprecated(Operation *op) {
 /// Attach layout attributes to all vector-type operands of operations within
 /// the given operation's region. Reports an error if any vector operand lacks
 /// a layout attribute.
-bool xegpu::recoverTemporaryLayouts(Operation *rootOp) {
-  auto result = rootOp->walk([&](Operation *op) {
-    for (OpOperand &operand : op->getOpOperands()) {
-      // Layouts are needed for vector type only.
-      if (!isa<VectorType>(operand.get().getType()))
-        continue;
-      auto layout = xegpu::getDistributeLayoutAttr(operand.get());
-      if (!layout) {
-        op->emitError("Could not find layout attribute for operand ")
-            << operand.getOperandNumber() << " of operation " << op->getName();
-        return WalkResult::interrupt();
-      }
-      xegpu::setDistributeLayoutAttr(operand, layout);
-    }
-    return WalkResult::advance();
-  });
-  return !result.wasInterrupted();
-}
+// bool xegpu::recoverTemporaryLayouts(Operation *rootOp) {
+//   auto result = rootOp->walk([&](Operation *op) {
+//     for (OpOperand &operand : op->getOpOperands()) {
+//       // Layouts are needed for vector type only.
+//       if (!isa<VectorType>(operand.get().getType()))
+//         continue;
+//       auto layout = xegpu::getDistributeLayoutAttr(operand.get());
+//       if (!layout) {
+//         op->emitError("Could not find layout attribute for operand ")
+//             << operand.getOperandNumber() << " of operation " <<
+//             op->getName();
+//         return WalkResult::interrupt();
+//       }
+//       xegpu::setDistributeLayoutAttr(operand, layout);
+//     }
+//     return WalkResult::advance();
+//   });
+//   return !result.wasInterrupted();
+// }
 
 // Prerequisite for Layout Recovery
 // It relies on the following invariant:
@@ -274,10 +276,11 @@ bool xegpu::recoverTemporaryLayouts(Operation *rootOp) {
 //     - Only the result of convert_layout is permitted to have no subsequent
 //     use.
 
-// The recover proceeds by scanning the operation in reverse topological orderas
-// follows: Across operations: layouts are propagated from uses to definitions.
-// Within an operation: layouts are propagated from definitions (result) to uses
-// (operands).
+// The recovery proceeds by scanning the operations in reverse topological
+// order as follows:
+//    For regular operations: the result layouts are first obtained from the
+//      uses of the results, then propagated to the op's operands.
+//
 //    For region operations (e.g., loops):
 //       - When backward propagation reaches a region op, it sets the layout of
 //       the region op’s results according to use points like regular ops.
@@ -291,7 +294,168 @@ bool xegpu::recoverTemporaryLayouts(Operation *rootOp) {
 //       across region boundaries
 //        while preserving a single well-defined use for each definition at the
 //        region-op level.
-// bool xegpu::recoverTemporaryLayouts_first(Operation *rootOp) {}
+
+// Forward declarations of the helpers used by recoverTemporaryLayouts; they
+// are defined further below.
+static void walkRegionBackward(Region &region,
+                               llvm::function_ref<void(Operation *)> visit);
+static void propagateResultsToRegularOperands(Operation *op);
+static void propagateRegionResultsToYieldOperands(
+    mlir::RegionBranchTerminatorOpInterface yieldOp);
+void propagateRegionArgsToInits(mlir::RegionBranchOpInterface *regionOp);
+
+/// Recovers temporary layouts for every func.func nested in (or equal to)
+/// `rootOp`. Each function body is walked backward (see walkRegionBackward),
+/// so the regions of an op are processed before the op itself and later ops
+/// in a block before earlier ones. Per visited op:
+///   - region-branch ops are handled by propagateRegionArgsToInits,
+///   - terminators whose parent is a region-branch op are handled by
+///     propagateRegionResultsToYieldOperands,
+///   - every other op with results is handled by
+///     propagateResultsToRegularOperands.
+/// The helpers assert when an expected layout is missing, so this returns
+/// true whenever it returns at all.
+bool xegpu::recoverTemporaryLayouts(Operation *rootOp) {
+  rootOp->walk([&](func::FuncOp func) {
+    walkRegionBackward(func.getBody(), [&](Operation *op) {
+      if (auto regionOp = dyn_cast<mlir::RegionBranchOpInterface>(op)) {
+        // The op's regions were walked first, so their arguments have
+        // already been processed.
+        propagateRegionArgsToInits(&regionOp);
+        return;
+      }
+      // Only terminators nested directly in a region-branch op (e.g.
+      // scf.yield) have region successors to query; a terminator of any
+      // other parent (e.g. func.return) is treated like a regular op.
+      auto terminator = dyn_cast<mlir::RegionBranchTerminatorOpInterface>(op);
+      if (terminator &&
+          isa_and_nonnull<mlir::RegionBranchOpInterface>(op->getParentOp())) {
+        propagateRegionResultsToYieldOperands(terminator);
+        return;
+      }
+      // Ops without results (e.g. func.return or stores) have no result
+      // layout to propagate.
+      if (op->getNumResults() == 0)
+        return;
+      propagateResultsToRegularOperands(op);
+    });
+  });
+  return true;
+}
+
+/// Visits every operation in `region` in reverse order: blocks from last to
+/// first, and operations within a block from last to first. The regions
+/// nested in an operation are walked (also backward) before `visit` is called
+/// on the operation itself, so e.g. a loop's terminator is visited before the
+/// loop op.
+/// NOTE(review): for multi-block regions, reverse block order is not
+/// necessarily reverse topological order.
+static void walkRegionBackward(Region &region,
+                               llvm::function_ref<void(Operation *)> visit) {
+  for (Block &block : llvm::reverse(region)) {
+    // Early-increment so that `visit` may erase the current op without
+    // invalidating the iteration.
+    for (Operation &op : llvm::make_early_inc_range(llvm::reverse(block))) {
+      for (Region &nested : llvm::reverse(op.getRegions()))
+        walkRegionBackward(nested, visit);
+      visit(&op);
+    }
+  }
+}
+
+/// Propagates the layout of `op`'s first result to its vector operands.
+///
+/// The result layout must already be known (the backward walk visits later
+/// ops first). vector.multi_reduction, vector.broadcast, vector.bitcast and
+/// vector.shape_cast derive their source layout from the result layout with
+/// dedicated inference rules; anchor ops are left untouched; any other op
+/// copies the result layout to every vector operand.
+static void propagateResultsToRegularOperands(Operation *op) {
+  // Anchor ops define their own layouts; nothing to propagate. Checked first
+  // so that anchor ops without results (e.g. stores) are not inspected.
+  if (isa<xegpu::AnchorLayoutInterface>(op))
+    return;
+
+  // Layouts are tracked for vector values only: an op with no results or a
+  // non-vector first result has no layout to propagate from.
+  if (op->getNumResults() == 0)
+    return;
+  OpResult result = op->getOpResult(0);
+  if (!isa<VectorType>(result.getType()))
+    return;
+  // TODO: ops that reduce a vector to a scalar (e.g. vector.reduction) still
+  // need a rule for their vector operands.
+
+  auto resLayout = xegpu::getDistributeLayoutAttr(result);
+  assert(resLayout &&
+         "result layout must be defined before propagating to uses");
+
+  // Reduction: the source layout is inferred from the result layout and the
+  // reduced dims; the accumulator has the result's shape and layout.
+  if (auto reduceOp = dyn_cast<vector::MultiDimReductionOp>(op)) {
+    SmallVector<int64_t> reduceDims =
+        llvm::to_vector(reduceOp.getReductionDims());
+    auto srcLayout = xegpu::inferReductionSourceLayout(resLayout, reduceDims);
+    xegpu::setTemporaryLayout(reduceOp.getSource(), srcLayout);
+    xegpu::setTemporaryLayout(reduceOp.getAcc(), resLayout);
+    return;
+  }
+
+  // Broadcast: the source layout depends on which dims were broadcast.
+  if (auto broadcastOp = dyn_cast<vector::BroadcastOp>(op)) {
+    ArrayRef<int64_t> resShape =
+        llvm::cast<VectorType>(broadcastOp.getResult().getType()).getShape();
+    ArrayRef<int64_t> srcShape =
+        llvm::cast<VectorType>(broadcastOp.getSource().getType()).getShape();
+    auto srcLayout =
+        xegpu::inferBroadcastSourceLayout(resLayout, resShape, srcShape);
+    xegpu::setTemporaryLayout(broadcastOp.getSource(), srcLayout);
+    return;
+  }
+
+  // Bitcast: the source layout depends on the element bit widths.
+  if (auto bitcastOp = dyn_cast<vector::BitCastOp>(op)) {
+    int resElemTyBitWidth =
+        llvm::cast<VectorType>(bitcastOp.getResult().getType())
+            .getElementTypeBitWidth();
+    int srcElemTyBitWidth =
+        llvm::cast<VectorType>(bitcastOp.getSource().getType())
+            .getElementTypeBitWidth();
+    auto srcLayout = xegpu::inferBitCastSourceLayout(
+        op->getContext(), resLayout, resElemTyBitWidth, srcElemTyBitWidth);
+    xegpu::setTemporaryLayout(bitcastOp.getSource(), srcLayout);
+    return;
+  }
+
+  // Shape cast: the source layout depends on both shapes.
+  if (auto shapeCastOp = dyn_cast<vector::ShapeCastOp>(op)) {
+    ArrayRef<int64_t> resShape =
+        llvm::cast<VectorType>(shapeCastOp.getResult().getType()).getShape();
+    ArrayRef<int64_t> srcShape =
+        llvm::cast<VectorType>(shapeCastOp.getSource().getType()).getShape();
+    auto srcLayout =
+        xegpu::inferShapecastSourceLayout(resLayout, resShape, srcShape);
+    xegpu::setTemporaryLayout(shapeCastOp.getSource(), srcLayout);
+    return;
+  }
+
+  // Any other op: every vector operand takes the result layout.
+  for (OpOperand &opr : op->getOpOperands()) {
+    if (!isa<VectorType>(opr.get().getType()))
+      continue;
+    xegpu::setTemporaryLayout(opr, resLayout);
+  }
+}
+
+/// Propagates the layouts of a region-branch op's results to the operands of
+/// a terminator in one of its regions (e.g. scf.yield) that are forwarded to
+/// those results.
+///
+/// Only the successor that exits to the parent op is used. The parent op's
+/// results are used after the op, so the backward walk has already visited
+/// their users. Region arguments are used inside the region, after the
+/// terminator in walk order, and may not carry layouts yet.
+static void propagateRegionResultsToYieldOperands(
+    mlir::RegionBranchTerminatorOpInterface yieldOp) {
+  llvm::SmallVector<mlir::RegionSuccessor> successors;
+  llvm::SmallVector<mlir::Attribute> operands(yieldOp->getNumOperands(),
+                                              nullptr);
+  yieldOp.getSuccessorRegions(operands, successors);
+
+  for (mlir::RegionSuccessor &successor : successors) {
+    if (!successor.isParent())
+      continue;
+    // The forwarded operands need not start at operand 0 (e.g. scf.condition
+    // forwards everything after its condition), so ask the interface which
+    // operands flow to this successor.
+    mlir::OperandRange forwarded = yieldOp.getSuccessorOperands(successor);
+    for (auto [regionResult, yieldOperand] :
+         llvm::zip_equal(successor.getSuccessorInputs(), forwarded)) {
+      // Layouts are needed for vector type only.
+      if (!isa<VectorType>(regionResult.getType()))
+        continue;
+      auto layout = xegpu::getDistributeLayoutAttr(regionResult);
+      assert(
+          layout &&
+          "region result layout must be defined before propagating to yield");
+      xegpu::setTemporaryLayout(yieldOperand, layout);
+    }
+  }
+}
+
+/// Propagates the layouts of the region arguments that `regionOp` initializes
+/// on entry (e.g. scf.for iter_args) to the operands that initialize them.
+///
+/// Called after the op's regions have been walked, so those arguments are
+/// expected to carry layouts already.
+void propagateRegionArgsToInits(mlir::RegionBranchOpInterface *regionOp) {
+  // One (unknown) constant per operand: some implementations index this array
+  // by operand position (e.g. through a FoldAdaptor), so it must not be empty.
+  SmallVector<Attribute> constOperands(
+      regionOp->getOperation()->getNumOperands(), nullptr);
+  SmallVector<RegionSuccessor> successors;
+  regionOp->getEntrySuccessorRegions(constOperands, successors);
+
+  for (RegionSuccessor &successor : successors) {
+    // A parent successor (e.g. a loop that may run zero times) has no region
+    // and thus no region arguments to take a layout from.
+    if (successor.isParent())
+      continue;
+    OperandRange initOperands = regionOp->getEntrySuccessorOperands(successor);
+    // Successor inputs exclude region arguments that are not fed by an init
+    // operand (e.g. the scf.for induction variable), so they line up with the
+    // init operands one-to-one.
+    for (auto [regionArg, initOperand] :
+         llvm::zip_equal(successor.getSuccessorInputs(), initOperands)) {
+      // Layouts are needed for vector type only.
+      if (!isa<VectorType>(regionArg.getType()))
+        continue;
+      auto layout = xegpu::getDistributeLayoutAttr(regionArg);
+      assert(
+          layout &&
+          "region argument layout must be defined before propagating to init");
+      xegpu::setTemporaryLayout(initOperand, layout);
+    }
+  }
+}
 
 template <typename T, typename>
 void xegpu::removeLayoutAttr(const T &operandOrResult) {
@@ -328,9 +492,10 @@ void xegpu::removeLayoutAttrs(Operation *op) {
 
 /// Infers the source layout attribute for a broadcast operation given the
 /// result layout attribute, result shape, source shape.
-xegpu::DistributeLayoutAttr xegpu::inferBroadCastSourceLayout(
-    MLIRContext *context, xegpu::DistributeLayoutAttr resLayout,
-    ArrayRef<int64_t> resShape, ArrayRef<int64_t> srcShape) {
+xegpu::DistributeLayoutAttr
+xegpu::inferBroadcastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
+                                  ArrayRef<int64_t> resShape,
+                                  ArrayRef<int64_t> srcShape) {
 
   SmallVector<int64_t> bcastDims;
   auto returnLayout = resLayout;
@@ -345,7 +510,8 @@ xegpu::DistributeLayoutAttr xegpu::inferBroadCastSourceLayout(
 
     // create a slice layout for the source
     returnLayout = xegpu::SliceAttr::get(
-        context, resLayout, DenseI64ArrayAttr::get(context, bcastDims));
+        resLayout.getContext(), resLayout,
+        DenseI64ArrayAttr::get(resLayout.getContext(), bcastDims));
   }
   return returnLayout;
 }
@@ -415,9 +581,10 @@ xegpu::inferBitCastSourceLayout(MLIRContext *context,
 
 /// Infers the source layout attribute for a shape cast operation given the
 /// result layout attribute, result shape, and source shape.
-xegpu::DistributeLayoutAttr xegpu::inferShapeCastSourceLayout(
-    MLIRContext *context, xegpu::DistributeLayoutAttr resLayout,
-    ArrayRef<int64_t> resShape, ArrayRef<int64_t> srcShape) {
+xegpu::DistributeLayoutAttr
+xegpu::inferShapecastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
+                                  ArrayRef<int64_t> resShape,
+                                  ArrayRef<int64_t> srcShape) {
 
   // there are two use cases:
   // 1. expand dims of low-rank dimensions (e.g., 1D to 2D): to set up the
@@ -449,9 +616,10 @@ xegpu::DistributeLayoutAttr xegpu::inferShapeCastSourceLayout(
 
   if (checkOnlyExpandUnitDims(srcShape, resShape)) {
     // create a slice layout for the source by removing the expanded unit dims
-    auto sliceDimsAttr =
-        DenseI64ArrayAttr::get(context, ArrayRef<int64_t>(expandedUnitDims));
-    auto srcLayout = xegpu::SliceAttr::get(context, resLayout, sliceDimsAttr);
+    auto sliceDimsAttr = DenseI64ArrayAttr::get(
+        resLayout.getContext(), ArrayRef<int64_t>(expandedUnitDims));
+    auto srcLayout =
+        xegpu::SliceAttr::get(resLayout.getContext(), resLayout, sliceDimsAttr);
     return srcLayout;
   }
 

>From e73664ce582e072f3444cf50e71ca6d73a7bbe3c Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Wed, 24 Dec 2025 22:25:07 +0000
Subject: [PATCH 06/35] remove debug print

---
 .../Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp  | 150 +++++-------------
 1 file changed, 38 insertions(+), 112 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp
index ccda369accdd6..a995249d87e95 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp
@@ -717,13 +717,7 @@ xegpu::reductionLayoutSetupRule(ArrayRef<int64_t> srcShape,
   xegpu::SliceAttr sliceCPL =
       dyn_cast<xegpu::SliceAttr>(consumerPreferredLayout);
 
-  // Strategy 1: Try to preserve the consumer's slice layout structure
-  // If the consumer already expects a slice layout with the same reduction
-  // dims, we can directly use its parent layout as our source layout. This
-  // ensures perfect alignment and avoids any data movement across subgroups or
-  // lanes.
-
-  // Hardware constraints (these should ideally be queried from device
+  // Hardware constraints (TODO: these should ideally be queried from device
   // capabilities)
   const int workgroupSize = 16; // Total number of subgroups in a workgroup
   const int subgroupSize = 16;  // Number of SIMD lanes per subgroup
@@ -735,44 +729,23 @@ xegpu::reductionLayoutSetupRule(ArrayRef<int64_t> srcShape,
   if (srcShapeSize < 2)
     return nullptr;
 
-  llvm::errs() << "DEBUG: Entering \n";
-
-  // Initialize layout components:
-  // - sgLayout[i]: Number of subgroups covering dimension i
-  // - sgData[i]: Data elements per subgroup in dimension i (srcShape[i] /
-  // sgLayout[i])
   SmallVector<int64_t> sgLayout(srcShapeSize);
   SmallVector<int64_t> sgData(srcShapeSize);
 
-  // Initialize instruction-level parallelism with SIMD-friendly defaults:
-  // - Last dimension gets subgroupSize (16) to match lane width
-  // - Second-to-last dimension gets vectorSize (8) as starting point
   SmallVector<int64_t> instData(srcShapeSize, 1);
   instData[srcShapeSize - 1] = subgroupSize;
   instData[srcShapeSize - 2] =
       vectorSize; // This will be adjusted based on actual data distribution
-  llvm::errs() << "DEBUG: Initial instData = [";
-  for (size_t i = 0; i < instData.size(); i++) {
-    llvm::errs() << instData[i];
-    if (i < instData.size() - 1)
-      llvm::errs() << ", ";
-  }
-  llvm::errs() << "]\n";
-  // Initialize lane-level distribution:
-  // - laneLayout[i]: How lanes are distributed across dimension i
-  //   (innermost dimension gets all subgroupSize lanes)
-  // - laneData[i]: Data elements per lane in dimension i (starts at 1 per lane)
+
   SmallVector<int64_t> laneLayout(srcShapeSize, 1);
   laneLayout[srcShapeSize - 1] = subgroupSize;
-  // construct a vector layout with lane_data = [1, ..., 1]
   SmallVector<int64_t> laneData(srcShapeSize, 1);
-  llvm::errs() << "DEBUG: laneLayout = [";
-  for (size_t i = 0; i < laneLayout.size(); i++) {
-    llvm::errs() << laneLayout[i];
-    if (i < laneLayout.size() - 1)
-      llvm::errs() << ", ";
-  }
-  // Attempt Strategy 1: Align with consumer's slice structure
+
+  // Strategy 1: Try to preserve the consumer's slice layout structure
+  // If the consumer already expects a slice layout with the same reduction
+  // dims, we can directly use its parent layout as our source layout. This
+  // ensures perfect alignment and avoids any data movement across subgroups or
+  // lanes.
   bool failToAlignSliceStruct = false;
   if (sliceCPL && sliceCPL.getDims().asArrayRef().equals(reductionDims)) {
     // The consumer is already expecting a slice along our reduction dimensions!
@@ -782,7 +755,7 @@ xegpu::reductionLayoutSetupRule(ArrayRef<int64_t> srcShape,
     // Verify that the parent layout can be adapted to our source shape:
     // For each dimension, check if srcShape[i] is divisible by the parent's
     // sg_layout[i]. If so, we can reuse the subgroup distribution pattern
-    // for each slice dim in source shape, if the dim size is differnt than the
+    // for each slice dim in source shape, if the dim size is different than the
     // result shape, try to adjust the sg_data/inst_data accordingly.
     SmallVector<int64_t> pcplSgLayout = parentCPL.getEffectiveSgLayoutAsInt();
     SmallVector<int64_t> pcplLaneLayout =
@@ -794,16 +767,6 @@ xegpu::reductionLayoutSetupRule(ArrayRef<int64_t> srcShape,
 
     proposedSrcLayout = parentCPL;
 
-    llvm::errs() << "DEBUG: srcShapeSize = " << srcShapeSize << "\n";
-    llvm::errs() << "DEBUG: parentCPL rank = " << parentCPL.getRank() << "\n";
-    llvm::errs() << "DEBUG: srcShape = [";
-    for (int i = 0; i < srcShapeSize; i++) {
-      llvm::errs() << srcShape[i];
-      if (i < srcShapeSize - 1)
-        llvm::errs() << ", ";
-    }
-    llvm::errs() << "]\n";
-
     if (pcplSgLayout.size() == static_cast<size_t>(srcShapeSize)) {
       for (int i = 0; i < srcShapeSize; i++) {
         if (srcShape[i] % pcplSgLayout[i] == 0) {
@@ -832,85 +795,49 @@ xegpu::reductionLayoutSetupRule(ArrayRef<int64_t> srcShape,
     failToAlignSliceStruct = true;
   }
 
+  // Strategy 2: Construct a new layout aligned with consumer's sg_layout for
+  // the result (non-reduction dims) then distribute remaining subgroups across
+  // other dimensions
   if (failToAlignSliceStruct) {
-
-    // try to align the sg layout
     SmallVector<int64_t> cplSgLayout =
         consumerPreferredLayout.getEffectiveSgLayoutAsInt();
-    llvm::errs() << "DEBUG: cplSgLayout size = " << cplSgLayout.size() << "\n";
-    // if sg layout doesn't cover all the sg ids, distribute rest to
-    // non-reduction dims
     int remainingSgCount = workgroupSize;
-
     SmallVector<int64_t> remainingDims;
-    // print debug info for consumerPreferredLayout and cplSgLayout
-    llvm::errs() << "DEBUG: consumerPreferredLayout sgLayout = [";
-    auto cplSgLayoutFull = consumerPreferredLayout.getEffectiveSgLayoutAsInt();
-    for (size_t i = 0; i < cplSgLayoutFull.size(); i++) {
-      llvm::errs() << cplSgLayoutFull[i];
-      if (i < cplSgLayoutFull.size() - 1)
-        llvm::errs() << ", ";
-    }
-    // if cplSgLayout is not empty, try to align the sg layout first
     int cplId = cplSgLayout.size() - 1;
-    llvm::errs() << "DEBUG: Starting first loop, cplId = " << cplId << "\n";
-    for (int i = srcShapeSize - 1; i >= 0; i--) {
-      llvm::errs() << "DEBUG: Loop 1, i = " << i << ", is_reduction_dim = "
-                   << llvm::is_contained(reductionDims, i) << "\n";
-
-      // For non-reduction dimensions, try to match consumer's sg_layout
-      // This ensures the result after reduction has the expected distribution
-      if (!llvm::is_contained(reductionDims, i) && cplId >= 0) {
-        if (srcShape[i] % cplSgLayout[cplId] == 0) {
-          sgLayout[i] = cplSgLayout[cplId];
-          sgData[i] = srcShape[i] / sgLayout[i];
-          instData[i] = std::min(instData[i], sgData[i]);
-          remainingSgCount /= sgLayout[i];
-          llvm::errs() << "DEBUG: Set sgLayout[" << i << "] = " << sgLayout[i]
-                       << ", sgData[" << i << "] = " << sgData[i]
-                       << ", remainingSgCount = " << remainingSgCount << "\n";
-          cplId--;
-          continue;
-        }
-      }
-      // Dimension couldn't be aligned; defer to second pass
-      remainingDims.push_back(i);
-      llvm::errs() << "DEBUG: Added i = " << i << " to remainingDims\n";
-    }
 
-    // Second pass: Distribute remaining subgroups across unhandled dimensions
-    // This handles reduction dimensions and dimensions that didn't align with
-    // consumer
-    llvm::errs() << "DEBUG: Starting second loop\n";
-    for (int i = srcShapeSize - 1; i >= 0; i--) {
-      llvm::errs() << "DEBUG: Loop 2, i = " << i << ", is_remaining_dim = "
-                   << llvm::is_contained(remainingDims, i) << "\n";
-      if (llvm::is_contained(remainingDims, i)) {
-
-        llvm::errs() << "DEBUG: Before Set sgLayout[" << i
-                     << "] = " << sgLayout[i] << ", sgData[" << i
-                     << "] = " << sgData[i]
-                     << ", remainingSgCount = " << remainingSgCount << "\n";
-
-        sgLayout[i] = std::min((srcShape[i] / laneLayout[i]),
-                               static_cast<int64_t>(remainingSgCount));
+    // For non-reduction dimensions, try to match consumer's sg_layout
+    // This ensures the result after reduction has the expected distribution
+    if (!llvm::is_contained(reductionDims, i) && cplId >= 0) {
+      if (srcShape[i] % cplSgLayout[cplId] == 0) {
+        sgLayout[i] = cplSgLayout[cplId];
         sgData[i] = srcShape[i] / sgLayout[i];
         instData[i] = std::min(instData[i], sgData[i]);
         remainingSgCount /= sgLayout[i];
+        cplId--;
+        continue;
+      }
+    }
+    remainingDims.push_back(i);
+  }
 
-        llvm::errs() << "DEBUG: After Set sgLayout[" << i
-                     << "] = " << sgLayout[i] << ", sgData[" << i
-                     << "] = " << sgData[i]
-                     << ", remainingSgCount = " << remainingSgCount << "\n";
+  // Second pass: Distribute remaining subgroups across unhandled dimensions
+  // This handles reduction dimensions and dimensions that didn't align with
+  // consumer
+  for (int i = srcShapeSize - 1; i >= 0; i--) {
 
-        if (remainingSgCount == 1) {
-          llvm::errs() << "DEBUG: Breaking from loop 2, remainingSgCount = 1\n";
-          break;
-        }
-      }
+    sgLayout[i] = std::min((srcShape[i] / laneLayout[i]),
+                           static_cast<int64_t>(remainingSgCount));
+    sgData[i] = srcShape[i] / sgLayout[i];
+    instData[i] = std::min(instData[i], sgData[i]);
+    remainingSgCount /= sgLayout[i];
+
+    if (remainingSgCount == 1) {
+      break;
+    }
+  }
     }
   }
-  // Convert int64_t vectors to int32_t for DenseI32ArrayAttr
+
   SmallVector<int32_t> sgLayout32(sgLayout.begin(), sgLayout.end());
   SmallVector<int32_t> sgData32(sgData.begin(), sgData.end());
   SmallVector<int32_t> instData32(instData.begin(), instData.end());
@@ -924,7 +851,6 @@ xegpu::reductionLayoutSetupRule(ArrayRef<int64_t> srcShape,
       DenseI32ArrayAttr::get(context, laneData32),
       consumerPreferredLayout.getOrder());
 
-  // finally, create the slice layout for reduction result
   xegpu::SliceAttr reductionResLayout =
       xegpu::SliceAttr::get(context, proposedSrcLayout,
                             DenseI64ArrayAttr::get(context, reductionDims));

>From 5565f60ff77e6d87bf780fab11d370054a3c59f0 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Fri, 23 Jan 2026 19:29:08 +0000
Subject: [PATCH 07/35] improve reductionSetupRule

---
 .../Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp  | 138 +++++++++---------
 1 file changed, 73 insertions(+), 65 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp
index 327a1901a1bea..fb416b4c81f15 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp
@@ -637,7 +637,7 @@ xegpu::inferShapecastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
                                   ArrayRef<int64_t> resShape,
                                   ArrayRef<int64_t> srcShape) {
 
-  // there are two use cases:
+  // there are three use cases:
   // 1. expand dims of low-rank dimensions (e.g., 1D to 2D): to set up the
   // tensor before broadcast
   // 2. split dim of a high-rank dimension (e.g., 1D to 2D): to setup tensor
@@ -763,10 +763,10 @@ xegpu::inferShapecastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
 xegpu::SliceAttr
 xegpu::reductionLayoutSetupRule(ArrayRef<int64_t> srcShape,
                                 SmallVector<int64_t> reductionDims,
-                                DistributeLayoutAttr consumerPreferredLayout) {
+                                DistributeLayoutAttr consumerLayout) {
 
-  xegpu::SliceAttr sliceCPL =
-      dyn_cast<xegpu::SliceAttr>(consumerPreferredLayout);
+  xegpu::SliceAttr consumerSliceLayout =
+      dyn_cast<xegpu::SliceAttr>(consumerLayout);
 
   // Hardware constraints (TODO: these should ideally be queried from device
   // capabilities)
@@ -775,7 +775,7 @@ xegpu::reductionLayoutSetupRule(ArrayRef<int64_t> srcShape,
   const int vectorSize = 8;     // Elements processed per vector instruction
   int srcShapeSize = srcShape.size();
   xegpu::DistributeLayoutAttr proposedSrcLayout;
-  auto context = consumerPreferredLayout.getContext();
+  auto context = consumerLayout.getContext();
   // Reduction layout requires at least 2D tensors
   if (srcShapeSize < 2)
     return nullptr;
@@ -797,58 +797,69 @@ xegpu::reductionLayoutSetupRule(ArrayRef<int64_t> srcShape,
   // dims, we can directly use its parent layout as our source layout. This
   // ensures best alignment and avoids any data movement across subgroups
   // and lanes.
-  bool failToAlignSliceStruct = false;
-  if (sliceCPL && sliceCPL.getDims().asArrayRef().equals(reductionDims) &&
-      sliceCPL.getParent().getRank() == srcShapeSize) {
-    // The consumer is expecting a slice along the current reduction dimensions
-    xegpu::DistributeLayoutAttr parentCPL = sliceCPL.getParent();
 
+  auto canPreserveSliceLayout =
+      [&](ArrayRef<int64_t> srcShape, SmallVector<int64_t> reductionDims,
+          DistributeLayoutAttr consumerLayout) -> bool {
     // Verify that the consumer layout can be adapted to the source shape:
     // For each dimension, check if srcShape[i] is divisible by the parent's
-    // sg_layout[i]. If so, we can reuse the subgroup distribution pattern
+    // sg_layout[i], and instData[i] by the parent's lane_layout[i]. If so,
+    // sg_layout and lane_layout can be reused.
+
+    if (!consumerSliceLayout)
+      return false;
+    if (!consumerSliceLayout.getDims().asArrayRef().equals(reductionDims))
+      return false;
+    xegpu::DistributeLayoutAttr parentLayout = consumerSliceLayout.getParent();
+    if (parentLayout.getRank() != srcShapeSize)
+      return false;
+
+    SmallVector<int64_t> parentSgLayout =
+        parentLayout.getEffectiveSgLayoutAsInt();
+    SmallVector<int64_t> parentLaneLayout =
+        parentLayout.getEffectiveLaneLayoutAsInt();
+
+    if (parentSgLayout.size() != static_cast<size_t>(srcShapeSize))
+      return false;
+    if (parentLaneLayout.size() != static_cast<size_t>(srcShapeSize))
+      return false;
+    for (int i = 0; i < srcShapeSize; i++) {
+      if (srcShape[i] % parentSgLayout[i] != 0)
+        return false;
+      if (instData[i] % parentLaneLayout[i] != 0)
+        return false;
+    }
+    return true;
+  };
+
+  if (canPreserveSliceLayout(srcShape, reductionDims, consumerLayout)) {
     // for each slice dim in source shape, if the dim size is different than the
     // result shape, try to adjust the sg_data/inst_data accordingly.
-    SmallVector<int64_t> pcplSgLayout = parentCPL.getEffectiveSgLayoutAsInt();
-    SmallVector<int64_t> pcplLaneLayout =
-        parentCPL.getEffectiveLaneLayoutAsInt();
-    SmallVector<int64_t> pcplLaneData = parentCPL.getEffectiveLaneDataAsInt();
-
-    proposedSrcLayout = parentCPL;
-
-    if (pcplSgLayout.size() == static_cast<size_t>(srcShapeSize)) {
-      for (int i = 0; i < srcShapeSize; i++) {
-        if (srcShape[i] % pcplSgLayout[i] == 0) {
-          sgLayout[i] = pcplSgLayout[i];
-          sgData[i] = srcShape[i] / sgLayout[i];
-          instData[i] = std::min(instData[i], sgData[i]);
-        } else {
-          failToAlignSliceStruct = true;
-          break;
-        }
-      }
+    // Use the full-rank parent: the slice's own layouts omit reduced dims.
+    xegpu::DistributeLayoutAttr parent = consumerSliceLayout.getParent();
+    SmallVector<int64_t> parentSgLayout = parent.getEffectiveSgLayoutAsInt();
+    SmallVector<int64_t> parentLaneLayout =
+        parent.getEffectiveLaneLayoutAsInt();
+    SmallVector<int64_t> parentLaneData = parent.getEffectiveLaneDataAsInt();
+
+    for (int i = 0; i < srcShapeSize; i++) {
+      sgLayout[i] = parentSgLayout[i];
+      sgData[i] = srcShape[i] / sgLayout[i];
+      instData[i] = std::min(instData[i], sgData[i]);
+      laneLayout[i] = parentLaneLayout[i];
+      laneData[i] = parentLaneData[i];
     }
 
-    if (pcplLaneLayout.size() == static_cast<size_t>(srcShapeSize)) {
-      for (int i = 0; i < srcShapeSize; i++) {
-        if (instData[i] % pcplLaneLayout[i] == 0) {
-          laneLayout[i] = pcplLaneLayout[i];
-          laneData[i] = pcplLaneData[i];
-        } else {
-          failToAlignSliceStruct = true;
-          break;
-        }
-      }
-    }
   } else {
-    failToAlignSliceStruct = true;
-  }
+    // Strategy 2: Construct a new layout aligned with consumer's sg_layout for
+    // the result (non-reduction dims) then distribute remaining subgroups
+    // across reduced dimensions
 
-  // Strategy 2: Construct a new layout aligned with consumer's sg_layout for
-  // the result (non-reduction dims) then distribute remaining subgroups across
-  // other dimensions
-  if (failToAlignSliceStruct) {
     SmallVector<int64_t> cplSgLayout =
-        consumerPreferredLayout.getEffectiveSgLayoutAsInt();
+        consumerLayout.getEffectiveSgLayoutAsInt();
+    SmallVector<int64_t> cplSgData = consumerLayout.getEffectiveSgDataAsInt();
+    SmallVector<int64_t> cplInstData =
+        consumerLayout.getEffectiveInstDataAsInt();
     int remainingSgCount = workgroupSize;
     SmallVector<int64_t> remainingDims;
     int cplId = cplSgLayout.size() - 1;
@@ -856,31 +867,29 @@ xegpu::reductionLayoutSetupRule(ArrayRef<int64_t> srcShape,
     // This ensures the result after reduction has the expected distribution
     for (int i = srcShapeSize - 1; i >= 0; i--) {
       if (!llvm::is_contained(reductionDims, i) && cplId >= 0) {
-        if (srcShape[i] % cplSgLayout[cplId] == 0) {
-          sgLayout[i] = cplSgLayout[cplId];
-          sgData[i] = srcShape[i] / sgLayout[i];
-          instData[i] = std::min(instData[i], sgData[i]);
-          remainingSgCount /= sgLayout[i];
-          cplId--;
-          continue;
-        }
+        assert((srcShape[i] % cplSgLayout[cplId] == 0) &&
+               "source shape not divisible by consumer sg_layout");
+        sgLayout[i] = cplSgLayout[cplId];
+        sgData[i] = srcShape[i] / sgLayout[i];
+        instData[i] = std::min(cplInstData[cplId], sgData[i]);
+        remainingSgCount /= sgLayout[i];
+        cplId--;
       }
-      remainingDims.push_back(i);
     }
 
     // Second pass: Distribute remaining subgroups across unhandled dimensions
     // This handles reduction dimensions and dimensions that didn't align with
     // consumer
     for (int i = srcShapeSize - 1; i >= 0; i--) {
-      sgLayout[i] = std::min((srcShape[i] / laneLayout[i]),
-                             static_cast<int64_t>(remainingSgCount));
-      sgData[i] = srcShape[i] / sgLayout[i];
-      instData[i] = std::min(instData[i], sgData[i]);
-      remainingSgCount /= sgLayout[i];
-      if (remainingSgCount == 1) {
-        break;
+      if (llvm::is_contained(reductionDims, i)) {
+        sgLayout[i] = std::min((srcShape[i] / laneLayout[i]),
+                               static_cast<int64_t>(remainingSgCount));
+        sgData[i] = srcShape[i] / sgLayout[i];
+        instData[i] = std::min(instData[i], sgData[i]);
+        remainingSgCount /= sgLayout[i];
       }
     }
+    assert(remainingSgCount == 1 && "not all subgroups have been distributed");
   }
 
   SmallVector<int32_t> sgLayout32(sgLayout.begin(), sgLayout.end());
@@ -893,8 +902,7 @@ xegpu::reductionLayoutSetupRule(ArrayRef<int64_t> srcShape,
       DenseI32ArrayAttr::get(context, sgData32),
       DenseI32ArrayAttr::get(context, instData32),
       DenseI32ArrayAttr::get(context, laneLayout32),
-      DenseI32ArrayAttr::get(context, laneData32),
-      consumerPreferredLayout.getOrder());
+      DenseI32ArrayAttr::get(context, laneData32), consumerLayout.getOrder());
   xegpu::SliceAttr reductionResLayout =
       xegpu::SliceAttr::get(context, proposedSrcLayout,
                             DenseI64ArrayAttr::get(context, reductionDims));

>From 540d1e0af9a057e574c2969aff770478738ec1b1 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Sun, 25 Jan 2026 04:51:39 +0000
Subject: [PATCH 08/35] add layoutKind parameter for SetupResultLayout
 functions

---
 .../Dialect/XeGPU/Utils/XeGPULayoutUtils.h    |  63 +--
 .../mlir/Dialect/XeGPU/Utils/XeGPUUtils.h     |  44 +-
 .../XeGPU/Transforms/XeGPUBlocking.cpp        |   1 +
 .../Transforms/XeGPUPeepHoleOptimizer.cpp     |   1 +
 .../XeGPU/Transforms/XeGPUPropagateLayout.cpp | 184 +++----
 .../Transforms/XeGPUSubgroupDistribute.cpp    |   1 +
 .../Dialect/XeGPU/Transforms/XeGPUUnroll.cpp  |   1 +
 .../Transforms/XeGPUWgToSgDistribute.cpp      |   1 +
 .../Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp  | 473 +++++++-----------
 9 files changed, 303 insertions(+), 466 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h
index 7ea949ceaaa94..206b3d85df30b 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h
@@ -11,8 +11,10 @@
 #define MLIR_DIALECT_XEGPU_UTILS_XEGPULAYOUTUTILS_H_
 
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/OpDefinition.h"
+
 namespace mlir {
 
 class VectorType;
@@ -31,47 +33,6 @@ class TensorDescType;
 
 namespace xegpu {
 
-/// Return the attribute name for the OpOperand to attach DistributeLayoutAttr
-std::string getTemporaryLayoutName(const OpOperand &operand);
-
-/// Return the attribute name for the OpResult to attach DistributeLayoutAttr
-std::string getTemporaryLayoutName(const OpResult result);
-
-/// Retrieves the DistributeLayoutAttr associated with a given Value. For
-/// TensorDescType values, the DistributeLayoutAttr is extracted from the
-/// TensorDescType itself. For other values, it is obtained from the attributes
-/// of the defining operation. Returns nullptr if no DistributeLayoutAttr is
-/// found.
-DistributeLayoutAttr getDistributeLayoutAttr(const Value value);
-
-/// Retrieves the DistributeLayoutAttr associated with a given OpOperand. It
-/// will first check the operand_layout_{id} of the owner operation. If not
-/// found, it will check the operand itself and its defining op.
-DistributeLayoutAttr getDistributeLayoutAttr(const OpOperand &opr);
-
-/// [to-be-deprecated] Sets the DistributeLayoutAttr for a given OpResult
-/// user should use setAnchorLayout instead
-void setDistributeLayoutAttr(const OpResult &Result,
-                             const DistributeLayoutAttr layout);
-
-/// [to-be-deprecated] Sets the DistributeLayoutAttr for a given OpOperand
-/// user should use setAnchorLayout instead
-void setDistributeLayoutAttr(const OpOperand &opr,
-                             const DistributeLayoutAttr layout);
-
-/// get and set distribute layout attribute for non-anchor operations
-/// (and offsets/masks of load/store ops before we get rid of their temp attrs)
-template <typename T,
-          typename = std::enable_if_t<std::is_same_v<T, OpOperand> ||
-                                      std::is_same_v<T, OpResult>>>
-DistributeLayoutAttr getTemporaryLayout(const T &operandOrResult);
-
-template <typename T,
-          typename = std::enable_if_t<std::is_same_v<T, OpOperand> ||
-                                      std::is_same_v<T, OpResult>>>
-void setTemporaryLayout(const T &operandOrResult,
-                        const DistributeLayoutAttr layout);
-
 /// [to-be-deprecated] Set the DistributeLayoutAttr for each OpOperand and
 /// OpResult of of the given operation. If the operation contains regions, it is
 /// also applied recursively to the contained operations operation.
@@ -124,7 +85,7 @@ DistributeLayoutAttr inferBitCastSourceLayout(DistributeLayoutAttr resLayout,
 
 /// Infers the source layout attribute for a shape cast operation given the
 /// result layout attribute, result shape, and source shape.
-DistributeLayoutAttr inferShapecastSourceLayout(DistributeLayoutAttr resLayout,
+DistributeLayoutAttr inferShapeCastSourceLayout(DistributeLayoutAttr resLayout,
                                                 ArrayRef<int64_t> resShape,
                                                 ArrayRef<int64_t> srcShape);
 
@@ -136,10 +97,10 @@ DistributeLayoutAttr inferShapecastSourceLayout(DistributeLayoutAttr resLayout,
 /// consumer's preferred layout. This minimizes data redistribution overhead.
 /// The SliceAttr for the result is then created based on the derived source
 /// layout and the specified reduction dimensions.
-SliceAttr
-reductionLayoutSetupRule(ArrayRef<int64_t> srcShape,
-                         SmallVector<int64_t> reductionDims,
-                         DistributeLayoutAttr consumerPreferredLayout);
+SliceAttr reductionSetupResultLayout(xegpu::LayoutKind layoutKind,
+                                     ArrayRef<int64_t> srcShape,
+                                     DistributeLayoutAttr consumerLayout,
+                                     SmallVector<int64_t> reductionDims);
 
 /// Setup the result layout attribute for a bitcast operation based on element
 /// type bitwidths. This ensures the source layout can always be derived from
@@ -147,12 +108,14 @@ reductionLayoutSetupRule(ArrayRef<int64_t> srcShape,
 ///
 /// When casting from a narrower to a wider element type (srcElemTyBitWidth <
 /// resElemTyBitWidth), the result layout's innermost dimension data sizes
-/// (sg_data, inst_data, lane_data) are scaled up by the bitwidth ratio. This
+/// (inst_data, lane_data) are scaled up by the bitwidth ratio. This
 /// maintains the invariant that the source layout can be recovered by inverse
 /// scaling during layout inference.
-DistributeLayoutAttr bitCastLayoutSetupRule(DistributeLayoutAttr resLayout,
-                                            int resElemTyBitWidth,
-                                            int srcElemTyBitWidth);
+DistributeLayoutAttr
+bitCastSetupResultLayout(xegpu::LayoutKind layoutKind,
+                         ArrayRef<int64_t> srcShape,
+                         DistributeLayoutAttr consumerLayout,
+                         int resElemTyBitWidth, int srcElemTyBitWidth);
 
 } // namespace xegpu
 
diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
index 3dbbe7e4c5dff..0f1ca8e38c873 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
@@ -9,7 +9,6 @@
 #ifndef MLIR_DIALECT_XEGPU_UTILS_XEGPUUTILS_H_
 #define MLIR_DIALECT_XEGPU_UTILS_XEGPUUTILS_H_
 
-#include "XeGPULayoutUtils.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/OpDefinition.h"
@@ -31,6 +30,8 @@ class TensorDescType;
 
 namespace xegpu {
 
+enum class LayoutKind { Lane, InstData, Subgroup };
+
 /// Flatten a set of ValueRange into a single SmallVector<Value>
 SmallVector<Value> flattenValues(ArrayRef<ValueRange> values);
 
@@ -119,6 +120,47 @@ template <typename T>
 int getLargestDivisor(T dim, ArrayRef<T> candidates,
                       ArrayRef<T> candidateMultiples = {});
 
+/// [to-be-deprecated] Sets the DistributeLayoutAttr for a given OpResult
+/// user should use setAnchorLayout instead
+void setDistributeLayoutAttr(const OpResult &Result,
+                             const DistributeLayoutAttr layout);
+
+/// [to-be-deprecated] Sets the DistributeLayoutAttr for a given OpOperand
+/// user should use setAnchorLayout instead
+void setDistributeLayoutAttr(const OpOperand &opr,
+                             const DistributeLayoutAttr layout);
+
+/// Retrieves the DistributeLayoutAttr associated with a given Value. For
+/// TensorDescType values, the DistributeLayoutAttr is extracted from the
+/// TensorDescType itself. For other values, it is obtained from the attributes
+/// of the defining operation. Returns nullptr if no DistributeLayoutAttr is
+/// found.
+DistributeLayoutAttr getDistributeLayoutAttr(const Value value);
+
+/// Retrieves the DistributeLayoutAttr associated with a given OpOperand. It
+/// will first check the operand_layout_{id} of the owner operation. If not
+/// found, it will check the operand itself and its defining op.
+DistributeLayoutAttr getDistributeLayoutAttr(const OpOperand &opr);
+
+/// Return the attribute name for the OpOperand to attach DistributeLayoutAttr
+std::string getTemporaryLayoutName(const OpOperand &operand);
+
+/// Return the attribute name for the OpResult to attach DistributeLayoutAttr
+std::string getTemporaryLayoutName(const OpResult result);
+
+/// get and set distribute layout attribute for non-anchor operations
+/// (and offsets/masks of load/store ops before we get rid of their temp attrs)
+template <typename T,
+          typename = std::enable_if_t<std::is_same_v<T, OpOperand> ||
+                                      std::is_same_v<T, OpResult>>>
+DistributeLayoutAttr getTemporaryLayout(const T &operandOrResult);
+
+template <typename T,
+          typename = std::enable_if_t<std::is_same_v<T, OpOperand> ||
+                                      std::is_same_v<T, OpResult>>>
+void setTemporaryLayout(const T &operandOrResult,
+                        const DistributeLayoutAttr layout);
+
 } // namespace xegpu
 
 } // namespace mlir
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
index 931834ba16d9a..a50c955ea83b5 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
@@ -12,6 +12,7 @@
 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
 #include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
+#include "mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h"
 #include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
 #include "mlir/Interfaces/LoopLikeInterface.h"
 #include "mlir/Pass/PassManager.h"
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPeepHoleOptimizer.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPeepHoleOptimizer.cpp
index 6a3e533fb2df4..9a8925d357d25 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPeepHoleOptimizer.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPeepHoleOptimizer.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
 #include "mlir/Dialect/XeGPU/Transforms/Passes.h"
 #include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
+#include "mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h"
 #include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
 #include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h"
 #include "mlir/Dialect/XeGPU/uArch/uArchBase.h"
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 36fe26f15049f..8f49647b153a4 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
 #include "mlir/Dialect/XeGPU/Transforms/Passes.h"
+#include "mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h"
 #include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Builders.h"
@@ -53,8 +54,6 @@ using namespace mlir::dataflow;
 
 namespace {
 
-enum class LayoutKind { Lane, InstData, Subgroup };
-
 //===----------------------------------------------------------------------===//
 // LayoutInfo
 //===----------------------------------------------------------------------===//
@@ -370,7 +369,7 @@ getSIMTLayoutInfoForDPASOperand(VectorType vectorTy, unsigned operandNum,
 class LayoutInfoPropagation
     : public SparseBackwardDataFlowAnalysis<LayoutInfoLattice> {
 private:
-  LayoutKind layoutKind;
+  xegpu::LayoutKind layoutKind;
   void visitDpasOp(xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands,
                    ArrayRef<const LayoutInfoLattice *> results);
 
@@ -426,7 +425,7 @@ class LayoutInfoPropagation
 public:
   LayoutInfoPropagation(DataFlowSolver &solver,
                         SymbolTableCollection &symbolTable,
-                        LayoutKind layoutKind)
+                        xegpu::LayoutKind layoutKind)
       : SparseBackwardDataFlowAnalysis(solver, symbolTable),
         layoutKind(layoutKind) {}
   using SparseBackwardDataFlowAnalysis::SparseBackwardDataFlowAnalysis;
@@ -520,12 +519,12 @@ bool LayoutInfoPropagation::hasParamsOfLayoutKind(
   if (anchorLayout == nullptr) {
     return false;
   }
-  if (layoutKind == LayoutKind::InstData) {
+  if (layoutKind == xegpu::LayoutKind::InstData) {
     return !(anchorLayout.getEffectiveInstDataAsInt().empty());
-  } else if (layoutKind == LayoutKind::Lane) {
+  } else if (layoutKind == xegpu::LayoutKind::Lane) {
     return !(anchorLayout.getEffectiveLaneLayoutAsInt().empty() ||
              anchorLayout.getEffectiveLaneDataAsInt().empty());
-  } else if (layoutKind == LayoutKind::Subgroup) {
+  } else if (layoutKind == xegpu::LayoutKind::Subgroup) {
     return !(anchorLayout.getEffectiveSgLayoutAsInt().empty() ||
              anchorLayout.getEffectiveSgDataAsInt().empty());
   }
@@ -573,7 +572,7 @@ void LayoutInfoPropagation::visitPrefetchNdOp(
       instData = {instHeight, instWidth};
     }
 
-    if (layoutKind == LayoutKind::InstData)
+    if (layoutKind == xegpu::LayoutKind::InstData)
       prefetchLayout =
           LayoutInfo(xegpu::LayoutAttr::get(tdescTy.getContext(), instData));
     else
@@ -592,8 +591,8 @@ void LayoutInfoPropagation::visitVectorMultiReductionOp(
     ArrayRef<LayoutInfoLattice *> operands,
     ArrayRef<const LayoutInfoLattice *> results) {
   // The layout of the result must be present.
-  LayoutInfo resultLayout = results[0]->getValue();
-  if (!resultLayout.isAssigned())
+  LayoutInfo resLayoutInfo = results[0]->getValue();
+  if (!resLayoutInfo.isAssigned())
     return;
 
   VectorType sourceTy =
@@ -614,54 +613,42 @@ void LayoutInfoPropagation::visitVectorMultiReductionOp(
              << dim << " ";
              llvm::dbgs() << "]\n");
 
-  auto resultLayoutAttr =
-      dyn_cast<xegpu::DistributeLayoutAttr>(resultLayout.get());
-
-  LLVM_DEBUG(DBGS() << "visitVectorMultiReductionOp: resultLayoutAttr = "
-                    << resultLayoutAttr << "\n");
+  auto consumerLayoutAttr =
+      dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
 
+  LLVM_DEBUG(DBGS() << "visitVectorMultiReductionOp: consumerLayoutAttr = "
+                    << consumerLayoutAttr << "\n");
   // An dominant layout is for the result and represents the layout requirements
   // for the operation it is recorded to anchor layout or temporary layout it
   // must be honored for current op and may conflict with the layout propagated
   // from consumer op the conflict is resolved in later phase by converting the
   // dominant layout to the source layout
 
-  xegpu::DistributeLayoutAttr dominantLayout = xegpu::reductionLayoutSetupRule(
-      srcShape, reductionDims, resultLayoutAttr);
+  auto requiredResLayoutAttr = xegpu::reductionSetupResultLayout(
+      layoutKind, srcShape, consumerLayoutAttr, reductionDims);
+  LLVM_DEBUG(DBGS() << "visitVectorMultiReductionOp: requiredResLayoutAttr = "
+                    << requiredResLayoutAttr << "\n");
 
-  LLVM_DEBUG(DBGS() << "visitVectorMultiReductionOp: dominantLayout = "
-                    << dominantLayout << "\n");
-
-  if (layoutKind == LayoutKind::Lane) {
-    // only lane layout/data is considered
-    dominantLayout = dominantLayout.dropInstData();
-    dominantLayout = dominantLayout.dropSgLayoutAndData();
-  } else if (layoutKind == LayoutKind::InstData) {
-    dominantLayout = dominantLayout.dropSgLayoutAndData();
-  }
-
-  // record the dominant layout to the reduction op
-  xegpu::setTemporaryLayout(reduction->getResult(0), dominantLayout);
+  resLayoutInfo.set(requiredResLayoutAttr);
 
   // derive the source layout from the dominant layout and reduction dims
   auto srcLayoutAttr =
-      xegpu::inferReductionSourceLayout(dominantLayout, reductionDims);
-  // void set(const xegpu::DistributeLayoutAttr &layout) { storage = layout; }
+      xegpu::inferReductionSourceLayout(requiredResLayoutAttr, reductionDims);
 
   LLVM_DEBUG(DBGS() << "visitVectorMultiReductionOp: srcLayoutAttr = "
                     << srcLayoutAttr << "\n");
 
   propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
   // Accumulator should have the same layout as the result.
-  propagateIfChanged(operands[1], operands[1]->meet(resultLayout));
+  propagateIfChanged(operands[1], operands[1]->meet(resLayoutInfo));
 }
 
 void LayoutInfoPropagation::visitVectorBroadCastOp(
     vector::BroadcastOp broadcast, ArrayRef<LayoutInfoLattice *> operands,
     ArrayRef<const LayoutInfoLattice *> results) {
   // The layout of the result must be present.
-  LayoutInfo resultLayout = results[0]->getValue();
-  if (!resultLayout.isAssigned())
+  LayoutInfo resLayoutInfo = results[0]->getValue();
+  if (!resLayoutInfo.isAssigned())
     return;
 
   // Only consider vector to vector broadcasts for now.
@@ -681,12 +668,12 @@ void LayoutInfoPropagation::visitVectorBroadCastOp(
                             "with unit-dim, mixed scenario is not supported!");
 
   auto resultLayoutAttr =
-      dyn_cast<xegpu::DistributeLayoutAttr>(resultLayout.get());
+      dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
 
-  xegpu::DistributeLayoutAttr srcLayout =
+  xegpu::DistributeLayoutAttr srcLayoutAttr =
       xegpu::inferBroadcastSourceLayout(resultLayoutAttr, resShape, srcShape);
 
-  propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayout)));
+  propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
   return;
 }
 
@@ -694,23 +681,18 @@ void LayoutInfoPropagation::visitShapeCastOp(
     vector::ShapeCastOp shapeCast, ArrayRef<LayoutInfoLattice *> operands,
     ArrayRef<const LayoutInfoLattice *> results) {
   // The layout of the result must be present.
-  LayoutInfo resultLayout = results[0]->getValue();
-  if (!resultLayout.isAssigned())
-    return;
-  VectorType sourceTy = shapeCast.getSourceVectorType();
-  VectorType resultTy = shapeCast.getResultVectorType();
-  // Shape cast layout propagation only supports 1D -> 2D shape casts.
-  // TODO: Support kD -> nD shape casts (k < n, n >= 2) where expanded dims are
-  // unit dimensions and non-unit dims match.
-  if (sourceTy.getRank() != 1 || resultTy.getRank() != 2) {
-    shapeCast.emitWarning("Expecting shape cast to be 1D -> 2D.");
+  LayoutInfo resLayoutInfo = results[0]->getValue();
+  if (!resLayoutInfo.isAssigned())
     return;
-  }
-  int64_t slicedDim = resultTy.getShape()[0] == 1 ? 0 : 1;
-  xegpu::SliceAttr sliceLayout = xegpu::SliceAttr::get(
-      shapeCast->getContext(), cast<xegpu::LayoutAttr>(resultLayout.get()),
-      DenseI64ArrayAttr::get(shapeCast->getContext(), {slicedDim}));
-  propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(sliceLayout)));
+  ArrayRef<int64_t> resShape = shapeCast.getResultVectorType().getShape();
+  ArrayRef<int64_t> srcShape = shapeCast.getSourceVectorType().getShape();
+  auto resultLayoutAttr =
+      dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
+
+  xegpu::DistributeLayoutAttr srcLayoutAttr =
+      xegpu::inferShapeCastSourceLayout(resultLayoutAttr, resShape, srcShape);
+
+  propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
 }
 
 /// Propagate the layout of the result tensor to the source tensor descriptor
@@ -779,7 +761,7 @@ void LayoutInfoPropagation::visitDpasOp(
     SmallVector<int> instDataA = {maxALen, subgroupSize};
     SmallVector<int> instDataB = {subgroupSize, maxBLen};
 
-    if (layoutKind == LayoutKind::InstData) {
+    if (layoutKind == xegpu::LayoutKind::InstData) {
       dpasALayout =
           LayoutInfo(xegpu::LayoutAttr::get(dpas.getContext(), instDataA));
       dpasBLayout =
@@ -793,7 +775,7 @@ void LayoutInfoPropagation::visitDpasOp(
 
     if (operands.size() > 2) {
       VectorType cTy = dpas.getAccType();
-      if (layoutKind == LayoutKind::InstData) {
+      if (layoutKind == xegpu::LayoutKind::InstData) {
         const unsigned dataCLen = bTy.getShape().back();
         auto supportedCLen =
             uArchInstruction->getSupportedN(bTy.getElementType());
@@ -863,7 +845,7 @@ void LayoutInfoPropagation::visitStoreNdOp(
       instData = {instHeight, instWidth};
     }
 
-    if (layoutKind == LayoutKind::InstData)
+    if (layoutKind == xegpu::LayoutKind::InstData)
       storeLayout =
           LayoutInfo(xegpu::LayoutAttr::get(dataTy.getContext(), instData));
     else
@@ -930,66 +912,29 @@ void LayoutInfoPropagation::visitVectorBitcastOp(
     vector::BitCastOp bitcast, ArrayRef<LayoutInfoLattice *> operands,
     ArrayRef<const LayoutInfoLattice *> results) {
   // Need the layout of bitcast result to propagate to the operands.
-  LayoutInfo resultLayout = results[0]->getValue();
-  if (!resultLayout.isAssigned())
+  LayoutInfo resLayoutInfo = results[0]->getValue();
+  if (!resLayoutInfo.isAssigned())
     return;
   int inElemTyBitWidth =
       bitcast.getSourceVectorType().getElementType().getIntOrFloatBitWidth();
   int outElemTyBitWidth =
       bitcast.getResultVectorType().getElementType().getIntOrFloatBitWidth();
-  // If the element bit widths are the same, then the layout does not change.
-  if (inElemTyBitWidth == outElemTyBitWidth) {
-    propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
-    return;
-  }
-  // Check if the result layout is valid. i.e. result vector can be distributed.
-  auto resultLaneLayout = resultLayout.getLaneLayout();
-  auto resultLaneData = resultLayout.getLaneData();
-  if (failed(xegpu::getDistributedVectorType(
-          bitcast.getResultVectorType(),
-          xegpu::LayoutAttr::get(bitcast->getContext(), resultLaneLayout,
-                                 resultLaneData)))) {
-    bitcast.emitWarning(
-        "Result vector type can not be evenly distributed across lanes.");
-    return;
-  }
-  int64_t rank = bitcast.getSourceVectorType().getRank();
-  // Bitcast is a `narrowing` if the input element type bit width larger than
-  // the output element type bit width. eg. f32 -> f16 is a narrowing bitcast.
-  bool isNarrowing = inElemTyBitWidth > outElemTyBitWidth;
-  int bitCastRatio = isNarrowing ? inElemTyBitWidth / outElemTyBitWidth
-                                 : outElemTyBitWidth / inElemTyBitWidth;
-  SmallVector<int> sourceLaneLayout =
-      resultLayout.getLaneLayout(); // Lane layout does not change for bitcast.
-  SmallVector<int> outData = resultLayout.getLaneData();
-
-  // TODO: Currently we assume that bitcasts does not require cross lane
-  // communication. So each lane must own the required number of elements to
-  // perform the bitcast locally without cross-lane communication.
-  int outInnerBitsPerLane = outData[rank - 1] * outElemTyBitWidth;
-  if (outInnerBitsPerLane < inElemTyBitWidth) {
-    bitcast.emitWarning(
-        "Narrowing bitcast with cross lane communication is not supported.");
-    return;
-  }
-  // Check if each lane owns a single element in all dimensions except the
-  // innermost dimension.
-  SmallVector<int> sourceLaneData(outData.begin(), outData.end() - 1);
-  if (llvm::any_of(sourceLaneData, [](int64_t d) { return d != 1; })) {
-    bitcast.emitWarning("Each lane must not own multiple elements in any "
-                        "dimension other than "
-                        "the innermost dimension.");
-    return;
-  }
-  // Decide lane data based on whether the bitcast is narrowing or widening.
-  int64_t innerMostLaneData = isNarrowing ? outData[rank - 1] / bitCastRatio
-                                          : outData[rank - 1] * bitCastRatio;
-  sourceLaneData.push_back(innerMostLaneData);
-
-  propagateIfChanged(
-      operands[0],
-      operands[0]->meet(LayoutInfo(xegpu::LayoutAttr::get(
-          bitcast->getContext(), sourceLaneLayout, sourceLaneData))));
+
+  ArrayRef<int64_t> srcShape = bitcast.getSourceVectorType().getShape();
+  auto consumerLayoutAttr =
+      dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
+  if (!consumerLayoutAttr)
+    return;
+  // Set up the result layout for this bitcast from the consumer's layout.
+  auto requiredResLayoutAttr = xegpu::bitCastSetupResultLayout(
+      layoutKind, srcShape, consumerLayoutAttr, outElemTyBitWidth,
+      inElemTyBitWidth);
+
+  // Derive the source layout from the required result layout and bit widths.
+  auto srcLayoutAttr = xegpu::inferBitCastSourceLayout(
+      requiredResLayoutAttr, outElemTyBitWidth, inElemTyBitWidth);
+
+  propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
 }
 
 /// Propagate the layout of the result to the tensor descriptor, mask and offset
@@ -1023,7 +968,7 @@ void LayoutInfoPropagation::visitLoadGatherOp(
         instData.push_back(chunkSize);
     }
 
-    if (layoutKind == LayoutKind::InstData)
+    if (layoutKind == xegpu::LayoutKind::InstData)
       loadLayout =
           LayoutInfo(xegpu::LayoutAttr::get(load.getContext(), instData));
     else
@@ -1085,7 +1030,7 @@ void LayoutInfoPropagation::visitStoreScatterOp(
     auto uArch = getUArch(getChipStr(storeScatter).value_or(""));
     const int subgroupSize = uArch->getSubgroupSize();
 
-    if (layoutKind == LayoutKind::InstData) {
+    if (layoutKind == xegpu::LayoutKind::InstData) {
       SmallVector<int> instData{subgroupSize};
       if (auto chunkSize = storeScatter.getChunkSize().value_or(0);
           chunkSize > 1)
@@ -1135,7 +1080,8 @@ class RunLayoutInfoPropagation {
 public:
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(RunLayoutInfoPropagation)
 
-  RunLayoutInfoPropagation(Operation *op, LayoutKind layoutKind) : target(op) {
+  RunLayoutInfoPropagation(Operation *op, xegpu::LayoutKind layoutKind)
+      : target(op) {
     SymbolTableCollection symbolTable;
     loadBaselineAnalyses(solver);
     solver.load<LayoutInfoPropagation>(symbolTable, layoutKind);
@@ -1373,13 +1319,13 @@ struct XeGPUPropagateLayoutPass final
 } // namespace
 
 void XeGPUPropagateLayoutPass::runOnOperation() {
-  LayoutKind layoutKind;
+  xegpu::LayoutKind layoutKind;
   if (this->layoutKind == "lane") {
-    layoutKind = LayoutKind::Lane;
+    layoutKind = xegpu::LayoutKind::Lane;
   } else if (this->layoutKind == "inst") {
-    layoutKind = LayoutKind::InstData;
+    layoutKind = xegpu::LayoutKind::InstData;
   } else if (this->layoutKind == "subgroup") {
-    layoutKind = LayoutKind::Subgroup;
+    layoutKind = xegpu::LayoutKind::Subgroup;
   } else {
     getOperation()->emitError("Unsupported layout kind option: " +
                               this->layoutKind);
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 9113f00ac39f0..8898be5f13dab 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -14,6 +14,7 @@
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
 #include "mlir/Dialect/XeGPU/Transforms/Passes.h"
 #include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
+#include "mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h"
 #include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
 #include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h"
 #include "mlir/IR/AffineMap.h"
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index 8f4e2bb0451d8..6bafdc955c9d3 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
 #include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
+#include "mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h"
 #include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/Support/DebugLog.h"
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 45a002b63abd6..72c8ea3de9976 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -19,6 +19,7 @@
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
 #include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
+#include "mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h"
 #include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include <optional>
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp
index fb416b4c81f15..c21e22c966f3d 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp
@@ -10,13 +10,13 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/LLVMIR/XeVMDialect.h"
 #include "mlir/Dialect/SCF/Transforms/Patterns.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
-#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/ValueRange.h"
@@ -278,235 +278,26 @@ xegpu::dropInstDataOnAttrs(ArrayRef<NamedAttribute> attrs) {
   return out;
 }
 
-/// Attach layout attributes to all vector-type operands of operations within
-/// the given operation's region. Reports an error if any vector operand lacks
-/// a layout attribute.
-// bool xegpu::recoverTemporaryLayouts(Operation *rootOp) {
-//   auto result = rootOp->walk([&](Operation *op) {
-//     for (OpOperand &operand : op->getOpOperands()) {
-//       // Layouts are needed for vector type only.
-//       if (!isa<VectorType>(operand.get().getType()))
-//         continue;
-//       auto layout = xegpu::getDistributeLayoutAttr(operand.get());
-//       if (!layout) {
-//         op->emitError("Could not find layout attribute for operand ")
-//             << operand.getOperandNumber() << " of operation " <<
-//             op->getName();
-//         return WalkResult::interrupt();
-//       }
-//       xegpu::setDistributeLayoutAttr(operand, layout);
-//     }
-//     return WalkResult::advance();
-//   });
-//   return !result.wasInterrupted();
-// }
-
-// Prerequisite for Layout Recovery
-// It relies on the following invariant:
-// 1. there is no layout conflict between different uses of the same definition.
-// 2. each definition has a well-defined layout requirement at its use point.
-//     - Every definition must have at least one use that appears after it in
-//     topological order.
-//     - If a definition has no such use (e.g., a loop result or region output),
-//     an explicit convert_layout operation is inserted to create a use.
-//     - Only the result of convert_layout is permitted to have no subsequent
-//     use.
-
-// The recover proceeds by scanning the operation in reverse topological order
-// as follows:
-//    For regular operations: First the result layouts are propagated from uses.
-//      Then the result layouts are propagated to uses (operands).
-//
-//    For region operations (e.g., loops):
-//       - When backward propagation reaches a region op, it sets the layout of
-//       the region op’s results according to use points like regular ops.
-//       - Then, the result layouts (such as a loop output) are propagated to
-//       thiers corresponding operands in the yield.
-//       - When backward propagation reaches the first operation inside the
-//       region, the pass examines the region op’s initialization list,
-//       propagating from region arguments to the corresponding initialization
-//       operands.
-//       - This ensures that layout constraints are consistently propagated
-//       across region boundaries
-//        while preserving a single well-defined use for each definition at the
-//        region-op level.
-
-// Forward declarations
-static void walkRegionBackward(Region &region,
-                               llvm::function_ref<void(Operation *)> visit);
-static void propagateResultsToRegularOperands(Operation *op);
-static void propagateRegionResultsToYieldOperands(
-    mlir::RegionBranchTerminatorOpInterface yieldOp);
-static void propagateRegionArgsToInits(mlir::RegionBranchOpInterface *regionOp);
-
-// the inner function for recoverTemporaryLayouts is a recursive function
-// the input rootOp is the function operation, which is also a region op.
-// it recursivley process the region op in reverse topological order.
+/// Attaches layout attributes to every vector-typed operand of `rootOp` and
+/// of the operations nested in it. On the first such operand whose layout
+/// cannot be found, emits an error and returns false; otherwise returns true.
 bool xegpu::recoverTemporaryLayouts(Operation *rootOp) {
-  rootOp->walk([&](func::FuncOp func) {
-    walkRegionBackward(func.getBody(), [&](Operation *op) {
-      if (auto regionOp = dyn_cast<mlir::RegionBranchOpInterface>(op)) {
-        // hit the region op after visiting inside region
-        propagateRegionArgsToInits(&regionOp);
-      } else if (auto yieldOp =
-                     dyn_cast<mlir::RegionBranchTerminatorOpInterface>(op)) {
-        // yield op inside region op
-        propagateRegionResultsToYieldOperands(yieldOp);
-      } else {
-        // if the op is regular op, calling propagateResultsToRegularOperands
-        propagateResultsToRegularOperands(op);
-      }
-    });
-  });
-}
-
-static void walkRegionBackward(Region &region,
-                               llvm::function_ref<void(Operation *)> visit) {
-  // blocks: back -> front
-  for (Block &block : llvm::reverse(region)) {
-    // ops: back -> front, early-inc so visit() may erase current op safely
-    for (Operation &op : llvm::reverse(block)) {
-      // make sure we first visit inside the region op (so yield op first)
-      // and then move to region op itself
-      for (Region &nested : llvm::reverse(op.getRegions()))
-        walkRegionBackward(nested, visit);
-
-      visit(&op);
-    }
-  }
-}
-
-// For regular operations: First the result layouts are propagated from uses.
-// Then the result layouts are propagated to uses (operands).
-static void propagateResultsToRegularOperands(Operation *op) {
-  OpResult result = op->getOpResults()[0];
-  auto resLayout = xegpu::getDistributeLayoutAttr(result);
-  assert(resLayout &&
-         "result layout must be defined before propagating to uses");
-
-  // if op is reduction op, call inferReductionSourceLayout
-  if (auto reduceOp = dyn_cast<vector::MultiDimReductionOp>(op)) {
-    SmallVector<int64_t> reduceDims(reduceOp.getReductionDims().begin(),
-                                    reduceOp.getReductionDims().end());
-    auto srcLayout = xegpu::inferReductionSourceLayout(resLayout, reduceDims);
-    // set the layout to the operand
-    xegpu::setTemporaryLayout(reduceOp->getOpOperand(0), srcLayout);
-    xegpu::setTemporaryLayout(reduceOp->getOpOperand(1), resLayout);
-    return;
-  }
-
-  // if op is broadcast op, call inferBroadcastSourceLayout
-  if (auto broadcastOp = dyn_cast<vector::BroadcastOp>(op)) {
-    ArrayRef<int64_t> resShape =
-        llvm::cast<VectorType>(broadcastOp.getResult().getType()).getShape();
-    ArrayRef<int64_t> srcShape =
-        llvm::cast<VectorType>(broadcastOp.getSource().getType()).getShape();
-    auto srcLayout =
-        xegpu::inferBroadcastSourceLayout(resLayout, resShape, srcShape);
-    // set the layout to the operand
-    xegpu::setTemporaryLayout(broadcastOp->getOpOperand(0), srcLayout);
-    return;
-  }
-
-  // if op is bitcast op, call inferBitCastSourceLayout
-  if (auto bitcastOp = dyn_cast<vector::BitCastOp>(op)) {
-    int resElemTyBitWidth =
-        llvm::cast<VectorType>(bitcastOp.getResult().getType())
-            .getElementTypeBitWidth();
-    int srcElemTyBitWidth =
-        llvm::cast<VectorType>(bitcastOp.getSource().getType())
-            .getElementTypeBitWidth();
-    auto srcLayout = xegpu::inferBitCastSourceLayout(
-        resLayout, resElemTyBitWidth, srcElemTyBitWidth);
-    // set the layout to the operand
-    xegpu::setTemporaryLayout(bitcastOp->getOpOperand(0), srcLayout);
-    return;
-  }
-
-  // if op is shape_cast op, call inferShapecastSourceLayout
-  if (auto shapeCastOp = dyn_cast<vector::ShapeCastOp>(op)) {
-    ArrayRef<int64_t> resShape =
-        llvm::cast<VectorType>(shapeCastOp.getResult().getType()).getShape();
-    ArrayRef<int64_t> srcShape =
-        llvm::cast<VectorType>(shapeCastOp.getSource().getType()).getShape();
-    auto srcLayout =
-        xegpu::inferShapecastSourceLayout(resLayout, resShape, srcShape);
-    // set the layout to the operand
-    xegpu::setTemporaryLayout(shapeCastOp->getOpOperand(0), srcLayout);
-    return;
-  }
-
-  // if op is a anchor op, no need to do anything
-  if (isa<xegpu::AnchorLayoutInterface>(op)) {
-    return;
-  }
-
-  // for other regular ops, propagate the result layout to all vector operands
-  for (OpOperand &opr : op->getOpOperands()) {
-    // Layouts are needed for vector type only.
-    if (!isa<VectorType>(opr.get().getType()))
-      continue;
-    xegpu::setTemporaryLayout(opr, resLayout);
-  }
-}
-
-static void propagateRegionResultsToYieldOperands(
-    mlir::RegionBranchTerminatorOpInterface yieldOp) {
-  llvm::SmallVector<mlir::RegionSuccessor> successors;
-  llvm::SmallVector<mlir::Attribute> operands(yieldOp->getNumOperands(),
-                                              nullptr);
-  yieldOp.getSuccessorRegions(operands, successors);
-
-  for (mlir::RegionSuccessor &successor : successors) {
-    // find out the successor which is the parent operation
-    if (!successor.isParent())
-      continue;
-    // For parent successor, the region arguments of the current region
-    // correspond to the results of the parent operation
-    Operation *parentOp = yieldOp->getParentOp();
-    for (unsigned i = 0; i < yieldOp->getNumOperands(); ++i) {
-      Value parentResult = parentOp->getResult(i);
-      auto layout = xegpu::getDistributeLayoutAttr(parentResult);
-      assert(
-          layout &&
-          "region result layout must be defined before propagating to yield");
-      xegpu::setTemporaryLayout(yieldOp->getOpOperand(i), layout);
-    }
-  }
-}
-
-void propagateRegionArgsToInits(mlir::RegionBranchOpInterface *regionOp) {
-
-  // Get entry successors (regions that can be entered initially)
-  SmallVector<RegionSuccessor> successors;
-  regionOp->getEntrySuccessorRegions(/*operands=*/ArrayRef<Attribute>(),
-                                     successors);
-
-  // For each possible entry region, get the operands forwarded to it
-  for (RegionSuccessor &successor : successors) {
-    if (successor.isParent())
-      continue;
-    OperandRange initOperands = regionOp->getEntrySuccessorOperands(successor);
-    // initOperands are the initialization arguments for this successor
-    // iterate the region arguments
-    Region *successorRegion = successor.getSuccessor();
-    for (unsigned i = 0; i < successorRegion->getNumArguments(); ++i) {
-      Value regionArg = successorRegion->getArgument(i);
-      auto layout = xegpu::getDistributeLayoutAttr(regionArg);
-      if (!layout)
+  auto result = rootOp->walk([&](Operation *op) {
+    for (OpOperand &operand : op->getOpOperands()) {
+      // Layouts are needed for vector type only.
+      if (!isa<VectorType>(operand.get().getType()))
         continue;
-      // The initOperands are a subset of the parent operation's operands
-      // We need to find which operand index this corresponds to
-      Value initOperand = initOperands[i];
-      // Find the operand index in the parent operation
-      for (OpOperand &operand : (*regionOp)->getOpOperands()) {
-        if (operand.get() == initOperand) {
-          xegpu::setTemporaryLayout(operand, layout);
-          break;
-        }
+      auto layout = xegpu::getDistributeLayoutAttr(operand.get());
+      if (!layout) {
+        op->emitError("Could not find layout attribute for operand ")
+            << operand.getOperandNumber() << " of operation " << op->getName();
+        return WalkResult::interrupt();
       }
+      xegpu::setDistributeLayoutAttr(operand, layout);
     }
-  }
+    return WalkResult::advance();
+  });
+  return !result.wasInterrupted();
 }
 
 template <typename T, typename>
@@ -633,7 +424,7 @@ xegpu::inferBitCastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
 /// Infers the source layout attribute for a shape cast operation given the
 /// result layout attribute, result shape, and source shape.
 xegpu::DistributeLayoutAttr
-xegpu::inferShapecastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
+xegpu::inferShapeCastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
                                   ArrayRef<int64_t> resShape,
                                   ArrayRef<int64_t> srcShape) {
 
@@ -760,10 +551,9 @@ xegpu::inferShapecastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
 ///    reduced result stays with the same subgroup distribution as expected by
 ///    the consumer.
 
-xegpu::SliceAttr
-xegpu::reductionLayoutSetupRule(ArrayRef<int64_t> srcShape,
-                                SmallVector<int64_t> reductionDims,
-                                DistributeLayoutAttr consumerLayout) {
+xegpu::SliceAttr xegpu::reductionSetupResultLayout(
+    xegpu::LayoutKind layoutKind, ArrayRef<int64_t> srcShape,
+    DistributeLayoutAttr consumerLayout, SmallVector<int64_t> reductionDims) {
 
   xegpu::SliceAttr consumerSliceLayout =
       dyn_cast<xegpu::SliceAttr>(consumerLayout);
@@ -782,15 +572,18 @@ xegpu::reductionLayoutSetupRule(ArrayRef<int64_t> srcShape,
 
   SmallVector<int64_t> sgLayout(srcShapeSize);
   SmallVector<int64_t> sgData(srcShapeSize);
+  SmallVector<int64_t> instData(srcShapeSize);
+  SmallVector<int64_t> laneLayout(srcShapeSize);
+  SmallVector<int64_t> laneData(srcShapeSize);
 
-  SmallVector<int64_t> instData(srcShapeSize, 1);
-  instData[srcShapeSize - 1] = subgroupSize;
-  instData[srcShapeSize - 2] =
+  SmallVector<int64_t> defaultInstData(srcShapeSize, 1);
+  defaultInstData[srcShapeSize - 1] = subgroupSize;
+  defaultInstData[srcShapeSize - 2] =
       vectorSize; // This will be adjusted based on actual data distribution
 
-  SmallVector<int64_t> laneLayout(srcShapeSize, 1);
-  laneLayout[srcShapeSize - 1] = subgroupSize;
-  SmallVector<int64_t> laneData(srcShapeSize, 1);
+  SmallVector<int64_t> defaultLaneLayout(srcShapeSize, 1);
+  defaultLaneLayout[srcShapeSize - 1] = subgroupSize;
+  SmallVector<int64_t> defaultLaneData(srcShapeSize, 1);
 
   // Strategy 1: Try to preserve the consumer's slice layout structure
   // If the consumer already expects a slice layout with the same reduction
@@ -811,7 +604,7 @@ xegpu::reductionLayoutSetupRule(ArrayRef<int64_t> srcShape,
     if (!consumerSliceLayout.getDims().asArrayRef().equals(reductionDims))
       return false;
     xegpu::DistributeLayoutAttr parentLayout = consumerSliceLayout.getParent();
-    if (!parentLayout.getRank() == srcShapeSize)
+    if (parentLayout.getRank() != srcShapeSize)
       return false;
 
     SmallVector<int64_t> parentSgLayout =
@@ -842,54 +635,103 @@ xegpu::reductionLayoutSetupRule(ArrayRef<int64_t> srcShape,
     SmallVector<int64_t> parentLaneData =
         consumerSliceLayout.getEffectiveLaneDataAsInt();
 
-    for (int i = 0; i < srcShapeSize; i++) {
-      sgLayout[i] = parentSgLayout[i];
-      sgData[i] = srcShape[i] / sgLayout[i];
-      instData[i] = std::min(instData[i], sgData[i]);
-      laneLayout[i] = parentLaneLayout[i];
-      laneData[i] = parentLaneData[i];
+    switch (layoutKind) {
+    case xegpu::LayoutKind::Subgroup:
+      // Inherit the parent's sg_layout (subgroup counts) and derive sg_data
+      // (elements per subgroup) for every dim, reduced or not.
+      sgLayout = parentSgLayout;
+      for (int i = 0; i < srcShapeSize; i++)
+        sgData[i] = srcShape[i] / sgLayout[i];
+      break;
+    case xegpu::LayoutKind::InstData:
+      // Clamp the default instruction tile to the source shape.
+      for (int i = 0; i < srcShapeSize; i++)
+        instData[i] = std::min(defaultInstData[i], srcShape[i]);
+      break;
+    case xegpu::LayoutKind::Lane:
+      laneLayout = parentLaneLayout;
+      for (int i = 0; i < srcShapeSize; i++) {
+        assert((srcShape[i] % laneLayout[i] == 0) &&
+               "source shape not divisible by lane layout");
+        laneData[i] = srcShape[i] / laneLayout[i];
+      }
+      break;
+    default:
+      llvm_unreachable("unsupported layout kind");
     }
-
   } else {
     // Strategy 2: Construct a new layout aligned with consumer's sg_layout for
     // the result (non-reduction dims) then distribute remaining subgroups
     // across reduced dimensions
 
-    SmallVector<int64_t> cplSgLayout =
+    SmallVector<int64_t> consumerSgLayout =
         consumerLayout.getEffectiveSgLayoutAsInt();
-    SmallVector<int64_t> cplSgData = consumerLayout.getEffectiveSgDataAsInt();
-    SmallVector<int64_t> cplInstData =
-        consumerLayout.getEffectiveInstDataAsInt();
+    SmallVector<int64_t> consumerLaneLayout =
+        consumerLayout.getEffectiveLaneLayoutAsInt();
     int remainingSgCount = workgroupSize;
-    SmallVector<int64_t> remainingDims;
-    int cplId = cplSgLayout.size() - 1;
-    // For non-reduction dimensions, try to match consumer's sg_layout
-    // This ensures the result after reduction has the expected distribution
-    for (int i = srcShapeSize - 1; i >= 0; i--) {
-      if (!llvm::is_contained(reductionDims, i) && cplId >= 0) {
-        assert((srcShape[i] % cplSgLayout[cplId] == 0) &&
-               "source shape not divisible by consumer sg_layout");
-        sgLayout[i] = cplSgLayout[cplId];
-        sgData[i] = srcShape[i] / sgLayout[i];
-        instData[i] = std::min(cplInstData[cplId], sgData[i]);
-        remainingSgCount /= sgLayout[i];
-        cplId--;
-      }
-    }
-
-    // Second pass: Distribute remaining subgroups across unhandled dimensions
-    // This handles reduction dimensions and dimensions that didn't align with
-    // consumer
-    for (int i = srcShapeSize - 1; i >= 0; i--) {
-      if (llvm::is_contained(reductionDims, i)) {
-        sgLayout[i] = std::min((srcShape[i] / laneLayout[i]),
-                               static_cast<int64_t>(remainingSgCount));
-        sgData[i] = srcShape[i] / sgLayout[i];
-        instData[i] = std::min(instData[i], sgData[i]);
-        remainingSgCount /= sgLayout[i];
-      }
+    int remainingLaneCount = subgroupSize;
+    int consumerSgId, consumerLaneId;
+
+    switch (layoutKind) {
+    case xegpu::LayoutKind::Subgroup:
+      consumerSgId = consumerSgLayout.size() - 1;
+      // For non-reduction dimensions, try to match consumer's sg_layout
+      // This ensures the result after reduction has the expected distribution
+      for (int i = srcShapeSize - 1; i >= 0; i--)
+        if (!llvm::is_contained(reductionDims, i) && consumerSgId >= 0) {
+          sgLayout[i] = consumerSgLayout[consumerSgId];
+          assert((srcShape[i] % sgLayout[i] == 0) &&
+                 "source shape not divisible by consumer sg_layout");
+          sgData[i] = srcShape[i] / sgLayout[i];
+          remainingSgCount /= sgLayout[i];
+          consumerSgId--;
+        }
+      // Second pass: Distribute remaining subgroups across unhandled dimensions
+      // This handles reduction dimensions that don't necessarily align with
+      // consumer
+      for (int i = srcShapeSize - 1; i >= 0; i--)
+        if (llvm::is_contained(reductionDims, i)) {
+          sgLayout[i] = std::min((srcShape[i] / defaultLaneLayout[i]),
+                                 static_cast<int64_t>(remainingSgCount));
+          assert((srcShape[i] % sgLayout[i] == 0) &&
+                 "source shape not divisible by consumer sg_layout");
+          sgData[i] = srcShape[i] / sgLayout[i];
+          remainingSgCount /= sgLayout[i];
+        }
+      assert(remainingSgCount == 1 &&
+             "not all subgroups have been distributed");
+      break;
+    case xegpu::LayoutKind::InstData:
+      for (int i = 0; i < srcShapeSize; i++)
+        instData[i] = std::min(defaultInstData[i], srcShape[i]);
+      break;
+    case xegpu::LayoutKind::Lane:
+      consumerLaneId = consumerLaneLayout.size() - 1;
+      // For non-reduction dimensions, match consumer's lane_layout, pairing
+      // dims from the innermost outwards exactly like the sg_layout case.
+      for (int i = srcShapeSize - 1; i >= 0; i--)
+        if (!llvm::is_contained(reductionDims, i) && consumerLaneId >= 0) {
+          laneLayout[i] = consumerLaneLayout[consumerLaneId];
+          assert((srcShape[i] % laneLayout[i] == 0) &&
+                 "source shape not divisible by consumer lane_layout");
+          laneData[i] = srcShape[i] / laneLayout[i];
+          remainingLaneCount /= laneLayout[i];
+          consumerLaneId--;
+        }
+      for (int i = 0; i < srcShapeSize; i++)
+        if (llvm::is_contained(reductionDims, i)) {
+          laneLayout[i] = std::min((srcShape[i] / defaultLaneLayout[i]),
+                                   static_cast<int64_t>(remainingLaneCount));
+          assert((srcShape[i] % laneLayout[i] == 0) &&
+                 "source shape not divisible by consumer lane_layout");
+          laneData[i] = srcShape[i] / laneLayout[i];
+          remainingLaneCount /= laneLayout[i];
+        }
+      assert(remainingLaneCount == 1 && "not all lanes have been distributed");
+      break;
+    default:
+      llvm_unreachable("unsupported layout kind");
     }
-    assert(remainingSgCount == 1 && "not all subgroups have been distributed");
   }
 
   SmallVector<int32_t> sgLayout32(sgLayout.begin(), sgLayout.end());
@@ -897,37 +739,76 @@ xegpu::reductionLayoutSetupRule(ArrayRef<int64_t> srcShape,
   SmallVector<int32_t> instData32(instData.begin(), instData.end());
   SmallVector<int32_t> laneLayout32(laneLayout.begin(), laneLayout.end());
   SmallVector<int32_t> laneData32(laneData.begin(), laneData.end());
+
   proposedSrcLayout = xegpu::LayoutAttr::get(
       context, DenseI32ArrayAttr::get(context, sgLayout32),
       DenseI32ArrayAttr::get(context, sgData32),
       DenseI32ArrayAttr::get(context, instData32),
       DenseI32ArrayAttr::get(context, laneLayout32),
       DenseI32ArrayAttr::get(context, laneData32), consumerLayout.getOrder());
-  xegpu::SliceAttr reductionResLayout =
+
+  xegpu::SliceAttr resLayout =
       xegpu::SliceAttr::get(context, proposedSrcLayout,
                             DenseI64ArrayAttr::get(context, reductionDims));
-  return reductionResLayout;
+  return resLayout;
 }
 
 xegpu::DistributeLayoutAttr
-xegpu::bitCastLayoutSetupRule(xegpu::DistributeLayoutAttr resLayout,
-                              int resElemTyBitWidth, int srcElemTyBitWidth) {
-  SmallVector<int64_t> sgData = resLayout.getEffectiveSgDataAsInt();
-  SmallVector<int64_t> instData = resLayout.getEffectiveInstDataAsInt();
-  SmallVector<int64_t> laneData = resLayout.getEffectiveLaneDataAsInt();
-  size_t dim = sgData.size() - 1;
+xegpu::bitCastSetupResultLayout(xegpu::LayoutKind layoutKind,
+                                ArrayRef<int64_t> srcShape,
+                                DistributeLayoutAttr consumerLayout,
+                                int resElemTyBitWidth, int srcElemTyBitWidth) {
+  SmallVector<int64_t> sgData = consumerLayout.getEffectiveSgDataAsInt();
+  SmallVector<int64_t> instData = consumerLayout.getEffectiveInstDataAsInt();
+  SmallVector<int64_t> laneData = consumerLayout.getEffectiveLaneDataAsInt();
+  size_t dim = srcShape.size() - 1;
   int64_t sgDataValue, instDataValue, laneDataValue;
-  if (srcElemTyBitWidth < resElemTyBitWidth) {
-    int bitWidthRatio = resElemTyBitWidth / srcElemTyBitWidth;
-    sgDataValue = (dim < sgData.size()) ? sgData[dim] * bitWidthRatio : -1;
-    instDataValue =
-        (dim < instData.size()) ? instData[dim] * bitWidthRatio : -1;
-    laneDataValue =
-        (dim < laneData.size()) ? laneData[dim] * bitWidthRatio : -1;
+  const int subgroupSize = 16; // assuming 16 lanes per subgroup
+
+  if (srcElemTyBitWidth > resElemTyBitWidth) {
+    // When casting to a smaller bitwidth, multiply the result layout
+    // accordingly to ensure it can be divided by the ratio back to the
+    // source layout.
+    int bitWidthRatio = srcElemTyBitWidth / resElemTyBitWidth;
+    int innermostDimLaneLayout = subgroupSize;
+    switch (layoutKind) {
+    case xegpu::LayoutKind::Subgroup:
+      assert(sgData.size() == srcShape.size() &&
+             "sgData must be available for all dimensions");
+      sgDataValue = sgData[dim];
+      break;
+    case xegpu::LayoutKind::InstData:
+      assert(instData.size() == srcShape.size() &&
+             "instData must be available for all dimensions");
+      instDataValue = instData[dim];
+      // adjust instDataValue so it still fits within an instruction after
+      // dividing by bitWidthRatio
+      while ((instDataValue <= srcShape[dim]) &&
+             (instDataValue % (innermostDimLaneLayout * bitWidthRatio) != 0))
+        instDataValue *= 2;
+      assert(srcShape[dim] % instDataValue == 0 &&
+             "srcShape, instData, and lanelayout for innermost must be 2^n !");
+      break;
+    case xegpu::LayoutKind::Lane:
+      assert(laneData.size() == srcShape.size() &&
+             "laneData must be available for all dimensions");
+      laneDataValue = laneData[dim];
+      while ((laneDataValue <= srcShape[dim]) &&
+             (laneDataValue % bitWidthRatio != 0))
+        laneDataValue *= 2;
+      break;
+    default:
+      llvm_unreachable("unsupported layout kind");
+    }
+  } else {
+    sgDataValue = sgData[dim];
+    instDataValue = instData[dim];
+    laneDataValue = laneData[dim];
   }
+
   // Now set only instData and laneData, preserving sgData
-  xegpu::DistributeLayoutAttr finalResLayout;
-  finalResLayout =
-      resLayout.setDimData(dim, sgDataValue, instDataValue, laneDataValue);
-  return finalResLayout;
+  xegpu::DistributeLayoutAttr resLayout;
+  resLayout =
+      consumerLayout.setDimData(dim, sgDataValue, instDataValue, laneDataValue);
+  return resLayout;
 }
\ No newline at end of file

>From a0537d61e9c0aa5debc09008b3b34f630dbdb5dd Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Mon, 26 Jan 2026 19:50:50 +0000
Subject: [PATCH 09/35] revert the slice attribute hint setting for
 store_matrix generated from cross-sg reduction

---
 .../Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp   | 9 +--------
 mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir    | 8 ++++----
 2 files changed, 5 insertions(+), 12 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 72c8ea3de9976..88dc670e0ed6a 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -1489,15 +1489,8 @@ struct WgToSgMultiDimReductionOp
 
     SmallVector<OpFoldResult> storeOffsets2D = {rowOffsetStore, colOffset};
 
-    auto storeMatrixLayout = xegpu::SliceAttr::get(
-        rewriter.getContext(),
-        xegpu::LayoutAttr::get(rewriter.getContext(), /*sg_layout =*/nullptr,
-                               /*sg_data =*/nullptr,
-                               /*inst_data =*/nullptr, /*lane_layout =*/nullptr,
-                               /*lane_data =*/nullptr, /*order =*/nullptr),
-        dyn_cast<xegpu::SliceAttr>(layout).getDims());
     xegpu::StoreMatrixOp::create(rewriter, loc, storeData, memDesc.getResult(),
-                                 storeOffsets2D, /*layout=*/storeMatrixLayout);
+                                 storeOffsets2D, /*layout=*/nullptr);
 
     gpu::BarrierOp::create(rewriter, loc);
 
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
index 1fc2328d09046..9cb96775b4ee4 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -674,7 +674,7 @@ gpu.module @test_distribution {
     // CHECK-DAG: %[[MUL3:.*]] = arith.muli %[[AFFINE3]], %[[C1:.*]] : index
     // CHECK-DAG: %[[ADD2:.*]] = arith.addi %[[ADD1]], %[[MUL3]] : index
     // CHECK-DAG: %[[COL_OFFSET:.*]] = arith.muli %[[ADD2]], %[[C32:.*]] : index
-    // CHECK-DAG: xegpu.store_matrix %[[SHAPE_CAST]], %[[MEM_DESC]][%[[ROW_OFFSET]], %[[COL_OFFSET]]] <{layout = #xegpu.slice<#xegpu.layout<>, dims = [1]>}>: vector<1x32xf32>, !xegpu.mem_desc<32x32xf32>, index, index
+    // CHECK-DAG: xegpu.store_matrix %[[SHAPE_CAST]], %[[MEM_DESC]][%[[ROW_OFFSET]], %[[COL_OFFSET]]] : vector<1x32xf32>, !xegpu.mem_desc<32x32xf32>, index, index
     // CHECK-DAG: gpu.barrier
     // CHECK-DAG: %[[LOAD_SLM:.*]] = xegpu.load_matrix %[[MEM_DESC]][%[[C0:.*]], %[[COL_OFFSET]]] : !xegpu.mem_desc<32x32xf32>, index, index -> vector<32x32xf32>
     // CHECK-DAG: %[[CST_3:.*]] = arith.constant dense<0.000000e+00> : vector<32xf32>
@@ -717,7 +717,7 @@ gpu.module @test_distribution {
     // CHECK-DAG: %[[MUL4:.*]] = arith.muli {{.*}}, %[[C1:.*]] : index
     // CHECK-DAG: %[[ADD1:.*]] = arith.addi %[[C0:.*]], %[[MUL4]] : index
     // CHECK-DAG: %[[COL_OFFSET:.*]] = arith.muli %[[ADD1]], %[[C32:.*]] : index
-    // CHECK-DAG: xegpu.store_matrix %[[SHAPE_CAST]], %[[MEM_DESC]][%[[ROW_OFFSET]], %[[COL_OFFSET]]] <{layout = #xegpu.slice<#xegpu.layout<>, dims = [0]>}>: vector<1x32xf32>, !xegpu.mem_desc<8x128xf32>, index, index
+    // CHECK-DAG: xegpu.store_matrix %[[SHAPE_CAST]], %[[MEM_DESC]][%[[ROW_OFFSET]], %[[COL_OFFSET]]] : vector<1x32xf32>, !xegpu.mem_desc<8x128xf32>, index, index
     // CHECK-DAG: gpu.barrier
     // CHECK-DAG: %[[LOAD_SLM:.*]] = xegpu.load_matrix %[[MEM_DESC]][%[[C0:.*]], %[[COL_OFFSET]]] : !xegpu.mem_desc<8x128xf32>, index, index -> vector<8x32xf32>
     // CHECK-DAG: %[[CST_CROSS_SG_1:.*]] = arith.constant dense<0.000000e+00> : vector<32xf32>
@@ -766,7 +766,7 @@ gpu.module @test_distribution {
     // CHECK-DAG: %[[MUL4:.*]] = arith.muli {{.*}}, %[[C2:.*]] : index
     // CHECK-DAG: %[[ADD3:.*]] = arith.addi %[[ADD2]], %[[MUL4]] : index
     // CHECK-DAG: %[[COL_OFFSET:.*]] = arith.muli %[[ADD3]], %[[C1:.*]] : index
-    // CHECK-DAG: xegpu.store_matrix %[[SHAPE_CAST]], %[[MEM_DESC]][%[[ROW_OFFSET]], %[[COL_OFFSET]]] <{layout = #xegpu.slice<#xegpu.layout<>, dims = [2, 3]>}>: vector<1x1xf32>, !xegpu.mem_desc<16x4xf32>, index, index
+    // CHECK-DAG: xegpu.store_matrix %[[SHAPE_CAST]], %[[MEM_DESC]][%[[ROW_OFFSET]], %[[COL_OFFSET]]] : vector<1x1xf32>, !xegpu.mem_desc<16x4xf32>, index, index
     // CHECK-DAG: gpu.barrier
     // CHECK-DAG: %[[LOAD_SLM:.*]] = xegpu.load_matrix %[[MEM_DESC]][%[[C0:.*]], %[[COL_OFFSET]]] : !xegpu.mem_desc<16x4xf32>, index, index -> vector<16x1xf32>
     // CHECK-DAG: %[[CST_3:.*]] = arith.constant dense<0.000000e+00> : vector<1xf32>
@@ -810,7 +810,7 @@ gpu.module @test_distribution {
     // CHECK-DAG: %[[MUL4:.*]] = arith.muli {{.*}}, %[[C2:.*]] : index
     // CHECK-DAG: %[[ADD3:.*]] = arith.addi %[[ADD2]], %[[MUL4]] : index
     // CHECK-DAG: %[[COL_OFFSET:.*]] = arith.muli %[[ADD3]], %[[C256:.*]] : index
-    // CHECK-DAG: xegpu.store_matrix %[[SHAPE_CAST]], %[[MEM_DESC]][%[[ROW_OFFSET]], %[[COL_OFFSET]]] <{layout = #xegpu.slice<#xegpu.layout<>, dims = [2, 3]>}>: vector<1x256xf32>, !xegpu.mem_desc<16x1024xf32>, index, index
+    // CHECK-DAG: xegpu.store_matrix %[[SHAPE_CAST]], %[[MEM_DESC]][%[[ROW_OFFSET]], %[[COL_OFFSET]]] : vector<1x256xf32>, !xegpu.mem_desc<16x1024xf32>, index, index
     // CHECK-DAG: gpu.barrier
     // CHECK-DAG: %[[LOAD_SLM:.*]] = xegpu.load_matrix %[[MEM_DESC]][%[[C0:.*]], %[[COL_OFFSET]]] : !xegpu.mem_desc<16x1024xf32>, index, index -> vector<16x256xf32>
     // CHECK-DAG: %[[CST_3:.*]] = arith.constant dense<0.000000e+00> : vector<256xf32>

>From cbdbf38a02d1562935427f87aa65f0e7f3b5f1f2 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Tue, 27 Jan 2026 00:01:46 +0000
Subject: [PATCH 10/35] add store_matrix support

---
 .../Dialect/XeGPU/Utils/XeGPULayoutUtils.h    | 15 ++-
 .../XeGPU/Transforms/XeGPUPropagateLayout.cpp | 55 +++++++----
 .../Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp  | 92 +++++++++++++++----
 mlir/test/Dialect/XeGPU/propagate-layout.mlir |  2 +-
 4 files changed, 122 insertions(+), 42 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h
index 206b3d85df30b..39afa4d7c1b8c 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h
@@ -12,6 +12,7 @@
 
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
 #include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
+#include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/OpDefinition.h"
 
@@ -100,7 +101,8 @@ DistributeLayoutAttr inferShapeCastSourceLayout(DistributeLayoutAttr resLayout,
 SliceAttr reductionSetupResultLayout(xegpu::LayoutKind layoutKind,
                                      ArrayRef<int64_t> srcShape,
                                      DistributeLayoutAttr consumerLayout,
-                                     SmallVector<int64_t> reductionDims);
+                                     SmallVector<int64_t> reductionDims,
+                                     const uArch::uArch *uArch);
 
 /// Setup the result layout attribute for a bitcast operation based on element
 /// type bitwidths. This ensures the source layout can always be derived from
@@ -112,10 +114,15 @@ SliceAttr reductionSetupResultLayout(xegpu::LayoutKind layoutKind,
 /// maintains the invariant that the source layout can be recovered by inverse
 /// scaling during layout inference.
 DistributeLayoutAttr
-bitCastSetupResultLayout(xegpu::LayoutKind layoutKind,
-                         ArrayRef<int64_t> srcShape,
+bitCastSetupResultLayout(LayoutKind layoutKind, ArrayRef<int64_t> srcShape,
                          DistributeLayoutAttr consumerLayout,
-                         int resElemTyBitWidth, int srcElemTyBitWidth);
+                         int resElemTyBitWidth, int srcElemTyBitWidth,
+                         const uArch::uArch *uArch);
+
+// Setup the anchor layout attribute for a storeMatrix operation
+DistributeLayoutAttr storeMatrixSetupAnchorLayout(LayoutKind layoutKind,
+                                                  VectorType vectorTy,
+                                                  const uArch::uArch *uArch);
 
 } // namespace xegpu
 
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 07cef91ff3291..e7aad5337d59a 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -649,6 +649,10 @@ void LayoutInfoPropagation::visitVectorMultiReductionOp(
   if (!resLayoutInfo.isAssigned())
     return;
 
+  // debug print resLayoutInfo
+  LLVM_DEBUG(DBGS() << "visitVectorMultiReductionOp: resLayoutInfo = ";
+             resLayoutInfo.print(llvm::dbgs()); llvm::dbgs() << "\n");
+
   VectorType sourceTy =
       llvm::dyn_cast<VectorType>(reduction.getSourceVectorType());
   SmallVector<int64_t> reductionDims(reduction.getReductionDims().begin(),
@@ -672,18 +676,28 @@ void LayoutInfoPropagation::visitVectorMultiReductionOp(
 
   LLVM_DEBUG(DBGS() << "visitVectorMultiReductionOp: consumerLayoutAttr = "
                     << consumerLayoutAttr << "\n");
-  // An dominant layout is for the result and represents the layout requirements
-  // for the operation it is recorded to anchor layout or temporary layout it
-  // must be honored for current op and may conflict with the layout propagated
-  // from consumer op the conflict is resolved in later phase by converting the
-  // dominant layout to the source layout
-
+  // The required result layout represents the layout requirements
+  // for the operation it is recorded to anchor layout or temporary layout.
+  // it must be honored for current op and may conflict with the layout
+  // propagated from consumer op, the conflict is resolved in later phase by
+  // converting the required result layout to the consumer layout
+  auto uArch = getUArch(xegpu::getChipStr(reduction).value_or(""));
   auto requiredResLayoutAttr = xegpu::reductionSetupResultLayout(
-      layoutKind, srcShape, consumerLayoutAttr, reductionDims);
+      layoutKind, srcShape, consumerLayoutAttr, reductionDims, uArch);
   LLVM_DEBUG(DBGS() << "visitVectorMultiReductionOp: requiredResLayoutAttr = "
                     << requiredResLayoutAttr << "\n");
 
-  resLayoutInfo.set(requiredResLayoutAttr);
+  // resLayoutInfo.set(requiredResLayoutAttr);
+  xegpu::setTemporaryLayout(reduction->getResult(0), requiredResLayoutAttr);
+
+  // debug print resLayoutInfo
+  LLVM_DEBUG(
+      DBGS() << "visitVectorMultiReductionOp: after change resLayoutInfo = ";
+      resLayoutInfo.print(llvm::dbgs()); llvm::dbgs() << "\n");
+
+  LLVM_DEBUG(DBGS() << "visitVectorMultiReductionOp: after change "
+                       "results[0]->getValue() = ";
+             results[0]->getValue().print(llvm::dbgs()); llvm::dbgs() << "\n");
 
   // derive the source layout from the dominant layout and reduction dims
   auto srcLayoutAttr =
@@ -694,7 +708,8 @@ void LayoutInfoPropagation::visitVectorMultiReductionOp(
 
   propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
   // Accumulator should have the same layout as the result.
-  propagateIfChanged(operands[1], operands[1]->meet(resLayoutInfo));
+  propagateIfChanged(operands[1],
+                     operands[1]->meet(LayoutInfo(requiredResLayoutAttr)));
 }
 
 void LayoutInfoPropagation::visitVectorBroadCastOp(
@@ -1102,12 +1117,12 @@ void LayoutInfoPropagation::visitVectorBitcastOp(
   ArrayRef<int64_t> srcShape = bitcast.getSourceVectorType().getShape();
   auto consumerLayoutAttr =
       dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
-
+  auto uArch = getUArch(xegpu::getChipStr(bitcast).value_or(""));
   auto requiredResLayoutAttr =
       bitCastSetupResultLayout(layoutKind, srcShape, consumerLayoutAttr,
-                               outElemTyBitWidth, inElemTyBitWidth);
+                               outElemTyBitWidth, inElemTyBitWidth, uArch);
 
-  resLayoutInfo.set(requiredResLayoutAttr);
+  xegpu::setTemporaryLayout(bitcast->getResult(0), requiredResLayoutAttr);
 
   // derive the source layout from the dominant layout and reduction dims
   auto srcLayoutAttr = xegpu::inferBitCastSourceLayout(
@@ -1327,12 +1342,10 @@ void LayoutInfoPropagation::visitStoreMatrixOp(
     VectorType payloadTy = llvm::cast<VectorType>(operand.getType());
     assert(payloadTy.getRank() == 2 && "Expecting 2D vector for store matrix.");
     auto uArch = getUArch(getChipStr(storeMatrix).value_or(""));
-    SmallVector<int> instData = {1, uArch->getSubgroupSize()};
-    if (layoutKind == xegpu::LayoutKind::InstData)
-      layout = LayoutInfo(
-          xegpu::LayoutAttr::get(storeMatrix.getContext(), instData));
-    else
-      layout = getSIMTLayoutInfoScatterIO(payloadTy, uArch);
+    auto requiredAnchorLayoutAttr =
+        storeMatrixSetupAnchorLayout(layoutKind, payloadTy, uArch);
+    storeMatrix.setLayoutAttr(requiredAnchorLayoutAttr);
+    layout = LayoutInfo(requiredAnchorLayoutAttr);
   }
 
   propagateIfChanged(operands[index], operands[index]->meet(layout));
@@ -1612,6 +1625,12 @@ void XeGPUPropagateLayoutPass::runOnOperation() {
     LayoutInfo layout = analysis.getLayoutInfo(val);
     if (!layout.isAssigned())
       return {};
+    if (auto opResult = dyn_cast<OpResult>(val)) {
+      xegpu::DistributeLayoutAttr requiredResLayoutAttr =
+          xegpu::getTemporaryLayout(opResult);
+      if (requiredResLayoutAttr != nullptr)
+        return requiredResLayoutAttr;
+    }
     xegpu::DistributeLayoutAttr layoutAttr =
         cast<xegpu::DistributeLayoutAttr>(layout.get());
     if (layout.isSliceLayout())
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp
index c21e22c966f3d..d42f0d30ce8c0 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp
@@ -553,16 +553,12 @@ xegpu::inferShapeCastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
 
 xegpu::SliceAttr xegpu::reductionSetupResultLayout(
     xegpu::LayoutKind layoutKind, ArrayRef<int64_t> srcShape,
-    DistributeLayoutAttr consumerLayout, SmallVector<int64_t> reductionDims) {
+    DistributeLayoutAttr consumerLayout, SmallVector<int64_t> reductionDims,
+    const xegpu::uArch::uArch *uArch) {
 
   xegpu::SliceAttr consumerSliceLayout =
       dyn_cast<xegpu::SliceAttr>(consumerLayout);
 
-  // Hardware constraints (TODO: these should ideally be queried from device
-  // capabilities)
-  const int workgroupSize = 16; // Total number of subgroups in a workgroup
-  const int subgroupSize = 16;  // Number of SIMD lanes per subgroup
-  const int vectorSize = 8;     // Elements processed per vector instruction
   int srcShapeSize = srcShape.size();
   xegpu::DistributeLayoutAttr proposedSrcLayout;
   auto context = consumerLayout.getContext();
@@ -576,6 +572,23 @@ xegpu::SliceAttr xegpu::reductionSetupResultLayout(
   SmallVector<int64_t> laneLayout(srcShapeSize);
   SmallVector<int64_t> laneData(srcShapeSize);
 
+  // recover workgroup and subgroup size from consumer layout
+  DistributeLayoutAttr origPlainLayout;
+  if (consumerSliceLayout) {
+    origPlainLayout = consumerSliceLayout.flatten().getParent();
+  } else {
+    origPlainLayout = consumerLayout;
+  }
+
+  const int workgroupSize =
+      std::accumulate(origPlainLayout.getEffectiveSgLayoutAsInt().begin(),
+                      origPlainLayout.getEffectiveSgLayoutAsInt().end(), 1,
+                      std::multiplies<int64_t>());
+
+  const int subgroupSize = uArch->getSubgroupSize();
+
+  const int vectorSize = 16; // vector size from SPRIV vector restriction
+
   SmallVector<int64_t> defaultInstData(srcShapeSize, 1);
   defaultInstData[srcShapeSize - 1] = subgroupSize;
   defaultInstData[srcShapeSize - 2] =
@@ -702,8 +715,11 @@ xegpu::SliceAttr xegpu::reductionSetupResultLayout(
              "not all subgroups have been distributed");
       break;
     case xegpu::LayoutKind::InstData:
-      for (int i = 0; i < srcShapeSize; i++)
+      for (int i = 0; i < srcShapeSize; i++) {
         instData[i] = std::min(defaultInstData[i], srcShape[i]);
+        llvm::dbgs() << "MultiReductionOp: Strategy 2: instData [" << i
+                     << "] = " << instData[i] << "\n";
+      }
       break;
     case xegpu::LayoutKind::Lane:
       consumerLaneId = consumerLaneLayout.size() - 1;
@@ -740,12 +756,24 @@ xegpu::SliceAttr xegpu::reductionSetupResultLayout(
   SmallVector<int32_t> laneLayout32(laneLayout.begin(), laneLayout.end());
   SmallVector<int32_t> laneData32(laneData.begin(), laneData.end());
 
-  proposedSrcLayout = xegpu::LayoutAttr::get(
-      context, DenseI32ArrayAttr::get(context, sgLayout32),
-      DenseI32ArrayAttr::get(context, sgData32),
-      DenseI32ArrayAttr::get(context, instData32),
-      DenseI32ArrayAttr::get(context, laneLayout32),
-      DenseI32ArrayAttr::get(context, laneData32), consumerLayout.getOrder());
+  switch (layoutKind) {
+  case xegpu::LayoutKind::Subgroup:
+    proposedSrcLayout = xegpu::LayoutAttr::get(
+        context, DenseI32ArrayAttr::get(context, sgLayout32),
+        DenseI32ArrayAttr::get(context, sgData32), consumerLayout.getOrder());
+    break;
+  case xegpu::LayoutKind::InstData:
+    proposedSrcLayout = xegpu::LayoutAttr::get(
+        context, DenseI32ArrayAttr::get(context, instData32));
+    break;
+  case xegpu::LayoutKind::Lane:
+    proposedSrcLayout = xegpu::LayoutAttr::get(
+        context, DenseI32ArrayAttr::get(context, laneLayout32),
+        DenseI32ArrayAttr::get(context, laneData32), consumerLayout.getOrder());
+    break;
+  default:
+    llvm_unreachable("unsupported layout kind");
+  }
 
   xegpu::SliceAttr resLayout =
       xegpu::SliceAttr::get(context, proposedSrcLayout,
@@ -753,17 +781,17 @@ xegpu::SliceAttr xegpu::reductionSetupResultLayout(
   return resLayout;
 }
 
-xegpu::DistributeLayoutAttr
-xegpu::bitCastSetupResultLayout(xegpu::LayoutKind layoutKind,
-                                ArrayRef<int64_t> srcShape,
-                                DistributeLayoutAttr consumerLayout,
-                                int resElemTyBitWidth, int srcElemTyBitWidth) {
+xegpu::DistributeLayoutAttr xegpu::bitCastSetupResultLayout(
+    xegpu::LayoutKind layoutKind, ArrayRef<int64_t> srcShape,
+    DistributeLayoutAttr consumerLayout, int resElemTyBitWidth,
+    int srcElemTyBitWidth, const xegpu::uArch::uArch *uArch) {
   SmallVector<int64_t> sgData = consumerLayout.getEffectiveSgDataAsInt();
   SmallVector<int64_t> instData = consumerLayout.getEffectiveInstDataAsInt();
   SmallVector<int64_t> laneData = consumerLayout.getEffectiveLaneDataAsInt();
   size_t dim = srcShape.size() - 1;
   int64_t sgDataValue, instDataValue, laneDataValue;
-  const int subgroupSize = 16; // assuming 16 lanes per subgroup
+
+  const int subgroupSize = uArch->getSubgroupSize();
 
   if (srcElemTyBitWidth > resElemTyBitWidth) {
     // When casting to a smaller bitwidth, multiply the result layout
@@ -811,4 +839,30 @@ xegpu::bitCastSetupResultLayout(xegpu::LayoutKind layoutKind,
   resLayout =
       consumerLayout.setDimData(dim, sgDataValue, instDataValue, laneDataValue);
   return resLayout;
+}
+
+xegpu::DistributeLayoutAttr
+xegpu::storeMatrixSetupAnchorLayout(xegpu::LayoutKind layoutKind,
+                                    VectorType vectorTy,
+                                    const xegpu::uArch::uArch *uArch) {
+
+  xegpu::DistributeLayoutAttr requiredLayout;
+  SmallVector<int> instData = {1, uArch->getSubgroupSize()};
+  switch (layoutKind) {
+  case xegpu::LayoutKind::Subgroup:
+    assert(false &&
+           "subgroup layout assignment not supported yet for storeMatrix.");
+    break;
+  case xegpu::LayoutKind::InstData:
+    requiredLayout = xegpu::LayoutAttr::get(vectorTy.getContext(), instData);
+    break;
+  case xegpu::LayoutKind::Lane:
+    requiredLayout = xegpu::LayoutAttr::get(
+        vectorTy.getContext(), {1, uArch->getSubgroupSize()}, {1, 1});
+
+    break;
+  default:
+    llvm_unreachable("unsupported layout kind");
+  }
+  return requiredLayout;
 }
\ No newline at end of file
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout.mlir b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
index 9f309917a7d60..acfd2e34c805c 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
@@ -702,7 +702,7 @@ func.func @vector_broadcast_scalar_to_vector(%arg0: !xegpu.tensor_desc<16x16xf16
 // -----
 gpu.module @test {
 // CHECK-LABEL: func.func @store_matrix(
-// CHECK:         %[[CST:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>} dense<0.000000e+00> : vector<16x16xf16>
+// CHECK:         %[[CST:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} dense<0.000000e+00> : vector<16x16xf16>
 // CHECK-NEXT:     xegpu.store_matrix %[[CST]], %arg0[8, 8] : vector<16x16xf16>, !xegpu.mem_desc<16x64xf16>
 
 func.func @store_matrix(%arg0: !xegpu.mem_desc<16x64xf16>) {

>From 6c4bdb442c9fd8ff2b180692556bec097ded1893 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Tue, 27 Jan 2026 04:33:00 +0000
Subject: [PATCH 11/35] add more storescatter layout utilities

---
 .../Dialect/XeGPU/Utils/XeGPULayoutUtils.h    |  20 +-
 .../XeGPU/Transforms/XeGPUPropagateLayout.cpp | 186 +++++++++---------
 .../Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp  |  98 ++++++++-
 3 files changed, 197 insertions(+), 107 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h
index 39afa4d7c1b8c..fe23510845ae3 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h
@@ -99,7 +99,7 @@ DistributeLayoutAttr inferShapeCastSourceLayout(DistributeLayoutAttr resLayout,
 /// The SliceAttr for the result is then created based on the derived source
 /// layout and the specified reduction dimensions.
 SliceAttr reductionSetupResultLayout(xegpu::LayoutKind layoutKind,
-                                     ArrayRef<int64_t> srcShape,
+                                     VectorType srcVectorTy,
                                      DistributeLayoutAttr consumerLayout,
                                      SmallVector<int64_t> reductionDims,
                                      const uArch::uArch *uArch);
@@ -113,17 +113,25 @@ SliceAttr reductionSetupResultLayout(xegpu::LayoutKind layoutKind,
 /// (inst_data, lane_data) are scaled up by the bitwidth ratio. This
 /// maintains the invariant that the source layout can be recovered by inverse
 /// scaling during layout inference.
-DistributeLayoutAttr
-bitCastSetupResultLayout(LayoutKind layoutKind, ArrayRef<int64_t> srcShape,
-                         DistributeLayoutAttr consumerLayout,
-                         int resElemTyBitWidth, int srcElemTyBitWidth,
-                         const uArch::uArch *uArch);
+DistributeLayoutAttr bitCastSetupResultLayout(
+    LayoutKind layoutKind, VectorType srcVectorTy, VectorType resVectorTy,
+    DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch);
 
 // Setup the anchor layout attribute for a storeMatrix operation
 DistributeLayoutAttr storeMatrixSetupAnchorLayout(LayoutKind layoutKind,
                                                   VectorType vectorTy,
                                                   const uArch::uArch *uArch);
+xegpu::DistributeLayoutAttr
+xegpu::loadMatrixSetupAnchorLayout(LayoutKind layoutKind, VectorType vectorTy,
+                                   const uArch::uArch *uArch);
+
+xegpu::DistributeLayoutAttr
+xegpu::loadGatherSetupAnchorLayout(LayoutKind layoutKind, VectorType vectorTy,
+                                   const uArch::uArch *uArch);
 
+xegpu::DistributeLayoutAttr
+xegpu::storeScatterSetupAnchorLayout(LayoutKind layoutKind, VectorType vectorTy,
+                                     const uArch::uArch *uArch);
 } // namespace xegpu
 
 } // namespace mlir
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index e7aad5337d59a..6f356f490726b 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -419,6 +419,10 @@ class LayoutInfoPropagation
                         ArrayRef<LayoutInfoLattice *> operands,
                         ArrayRef<const LayoutInfoLattice *> results);
 
+  void visitLoadMatrixOp(xegpu::LoadMatrixOp load,
+                         ArrayRef<LayoutInfoLattice *> operands,
+                         ArrayRef<const LayoutInfoLattice *> results);
+
   void visitStoreMatrixOp(xegpu::StoreMatrixOp store,
                           ArrayRef<LayoutInfoLattice *> operands,
                           ArrayRef<const LayoutInfoLattice *> results);
@@ -496,6 +500,9 @@ LogicalResult LayoutInfoPropagation::visitOperation(
       .Case<vector::ShapeCastOp>([&](auto shapeCastOp) {
         visitShapeCastOp(shapeCastOp, operands, results);
       })
+      .Case<xegpu::LoadMatrixOp>([&](auto loadMatrixOp) {
+        visitLoadMatrixOp(loadMatrixOp, operands, results);
+      })
       .Case<xegpu::StoreMatrixOp>([&](auto storeMatrixOp) {
         visitStoreMatrixOp(storeMatrixOp, operands, results);
       })
@@ -658,11 +665,9 @@ void LayoutInfoPropagation::visitVectorMultiReductionOp(
   SmallVector<int64_t> reductionDims(reduction.getReductionDims().begin(),
                                      reduction.getReductionDims().end());
 
-  auto srcShape = sourceTy.getShape();
-
   LLVM_DEBUG(DBGS() << "visitVectorMultiReductionOp: srcShape = [";
              for (auto dim
-                  : srcShape) llvm::dbgs()
+                  : sourceTy.getShape()) llvm::dbgs()
              << dim << " ";
              llvm::dbgs() << "]\n");
   LLVM_DEBUG(DBGS() << "visitVectorMultiReductionOp: reductionDims = [";
@@ -683,7 +688,7 @@ void LayoutInfoPropagation::visitVectorMultiReductionOp(
   // converting the required result layout to the consumer layout
   auto uArch = getUArch(xegpu::getChipStr(reduction).value_or(""));
   auto requiredResLayoutAttr = xegpu::reductionSetupResultLayout(
-      layoutKind, srcShape, consumerLayoutAttr, reductionDims, uArch);
+      layoutKind, sourceTy, consumerLayoutAttr, reductionDims, uArch);
   LLVM_DEBUG(DBGS() << "visitVectorMultiReductionOp: requiredResLayoutAttr = "
                     << requiredResLayoutAttr << "\n");
 
@@ -1109,21 +1114,21 @@ void LayoutInfoPropagation::visitVectorBitcastOp(
   LayoutInfo resLayoutInfo = results[0]->getValue();
   if (!resLayoutInfo.isAssigned())
     return;
-  int inElemTyBitWidth =
-      bitcast.getSourceVectorType().getElementType().getIntOrFloatBitWidth();
-  int outElemTyBitWidth =
-      bitcast.getResultVectorType().getElementType().getIntOrFloatBitWidth();
 
-  ArrayRef<int64_t> srcShape = bitcast.getSourceVectorType().getShape();
+  auto srcVecType = bitcast.getSourceVectorType();
+  auto resVecType = bitcast.getResultVectorType();
+
   auto consumerLayoutAttr =
       dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
   auto uArch = getUArch(xegpu::getChipStr(bitcast).value_or(""));
-  auto requiredResLayoutAttr =
-      bitCastSetupResultLayout(layoutKind, srcShape, consumerLayoutAttr,
-                               outElemTyBitWidth, inElemTyBitWidth, uArch);
+  auto requiredResLayoutAttr = bitCastSetupResultLayout(
+      layoutKind, srcVecType, resVecType, consumerLayoutAttr, uArch);
 
   xegpu::setTemporaryLayout(bitcast->getResult(0), requiredResLayoutAttr);
 
+  int inElemTyBitWidth = srcVecType.getElementType().getIntOrFloatBitWidth();
+  int outElemTyBitWidth = resVecType.getElementType().getIntOrFloatBitWidth();
+
   // derive the source layout from the dominant layout and reduction dims
   auto srcLayoutAttr = xegpu::inferBitCastSourceLayout(
       requiredResLayoutAttr, outElemTyBitWidth, inElemTyBitWidth);
@@ -1250,107 +1255,96 @@ void LayoutInfoPropagation::visitStoreScatterOp(
     xegpu::StoreScatterOp storeScatter, ArrayRef<LayoutInfoLattice *> operands,
     ArrayRef<const LayoutInfoLattice *> results) {
 
-  LayoutInfo payloadLayout;
-  LayoutInfo maskLayout;
+  LayoutInfo srcLayoutInfo;
+  LayoutInfo maskLayoutInfo;
+  xegpu::DistributeLayoutAttr requiredAnchorLayoutAttr;
   xegpu::DistributeLayoutAttr anchorLayout = storeScatter.getLayoutAttr();
   auto uArch = getUArch(getChipStr(storeScatter).value_or(""));
-  const int subgroupSize = uArch->getSubgroupSize();
+  VectorType srcVecTy = storeScatter.getValueType();
+  VectorType maskTy =
+      llvm::dyn_cast<VectorType>(storeScatter.getMask().getType());
 
   if (hasParamsOfLayoutKind(anchorLayout)) {
-    payloadLayout = LayoutInfo(anchorLayout);
-    maskLayout = payloadLayout;
+    requiredAnchorLayoutAttr = LayoutInfo(anchorLayout);
   } else {
-    // Currently, for 2D StoreScatterOp we expect that the height dimension of
-    // the tensor descriptor is equal to the subgroup size. This is ensured by
-    // the op verifier.
-    VectorType payloadTy = storeScatter.getValueType();
-    if (!payloadTy) {
+
+    if (!srcVecTy) {
       storeScatter.emitWarning("Not propagating, non-vector payload supplied.");
       return;
     }
-
-    if (layoutKind == xegpu::LayoutKind::InstData) {
-      const auto *uArchInstruction =
-          dyn_cast<xegpu::uArch::StoreScatterInstruction>(uArch->getInstruction(
-              xegpu::uArch::InstructionKind::StoreScatter));
-      const int subgroupSize = uArch->getSubgroupSize();
-      SmallVector<int> instDataUarch{subgroupSize};
-      if (payloadTy.getRank() != 1) {
-        if (payloadTy.getRank() != 2) {
-          storeScatter.emitWarning("Expected 2D payload for StoreScatterOp.");
-          return;
-        }
-        instDataUarch.push_back(
-            (std::min(static_cast<int>(payloadTy.getShape().back()),
-                      uArchInstruction->getMaxLaneLoadStoreSize())));
-      }
-      payloadLayout = LayoutInfo(
-          xegpu::LayoutAttr::get(storeScatter.getContext(), instDataUarch));
-    } else {
-      auto payloadShape = payloadTy.getShape();
-      if (payloadShape.size() > 1)
-        assert(payloadShape[0] == subgroupSize &&
-               "Expected the first dimension of 2D tensor descriptor to be "
-               "equal to "
-               "subgroup size.");
-      payloadLayout = getSIMTLayoutInfoScatterIO(payloadTy, uArch);
-    }
-
-    storeScatter.setLayoutAttr(
-        dyn_cast<xegpu::DistributeLayoutAttr>(payloadLayout.get()));
+    assert(srcVecTy.getRank() == 2 && "Expecting 2D vector for store matrix.");
+    auto uArch = getUArch(getChipStr(storeScatter).value_or(""));
+    requiredAnchorLayoutAttr =
+        xegpu::storeScatterSetupAnchorLayout(layoutKind, srcVecTy, uArch);
+    storeScatter.setLayoutAttr(requiredAnchorLayoutAttr);
   }
 
-  // If no user-defined anchor or we deal with a chunked op, set the default
-  // mask layout.
-  // Rank 1 data : Keep the mask layout aligned with data.
-  // Rank >1 data: Enforce the default xegpu 1D layout for mask.
-  if (!hasParamsOfLayoutKind(anchorLayout) ||
-      storeScatter.getValueType().getRank() > 1) {
-    if (layoutKind == xegpu::LayoutKind::InstData)
-      maskLayout = LayoutInfo(
-          xegpu::LayoutAttr::get(storeScatter->getContext(), {subgroupSize}));
-    else if (layoutKind == xegpu::LayoutKind::Lane)
-      maskLayout =
-          getDefaultSIMTLayoutInfo(storeScatter->getContext(), 1, subgroupSize);
+  srcLayoutInfo = LayoutInfo(requiredAnchorLayoutAttr);
+  maskLayoutInfo = srcLayoutInfo;
+  if (maskTy.getRank() < srcVecTy.getRank()) {
+    assert((maskTy.getRank() == (srcVecTy.getRank() - 1)) &&
+           "Expecting mask vector only 1 dimension less than value vector.");
+    maskLayoutInfo = LayoutInfo(requiredAnchorLayoutAttr.dropDim(-1));
+
+    // Propagate the payload operand layout
+    propagateIfChanged(operands[0], operands[0]->meet(srcLayoutInfo));
+    // Propagate the destination (if tdesc) operand layout
+    if (isa<xegpu::TensorDescType>(storeScatter.getDestType()))
+      propagateIfChanged(operands[1], operands[1]->meet(srcLayoutInfo));
+    // Propagate the new layout to the mask and optional offset operand.
+    propagateIfChanged(operands[2], operands[2]->meet(maskLayoutInfo));
+    if (storeScatter.getOffsets())
+      propagateIfChanged(operands[3], operands[3]->meet(maskLayoutInfo));
   }
 
-  // Propagate the payload operand layout
-  propagateIfChanged(operands[0], operands[0]->meet(payloadLayout));
-  // Propagate the destination (if tdesc) operand layout
-  if (isa<xegpu::TensorDescType>(storeScatter.getDestType()))
-    propagateIfChanged(operands[1], operands[1]->meet(payloadLayout));
-  // Propagate the new layout to the mask and optional offset operand.
-  propagateIfChanged(operands[2], operands[2]->meet(maskLayout));
-  if (storeScatter.getOffsets())
-    propagateIfChanged(operands[3], operands[3]->meet(maskLayout));
-}
+  void LayoutInfoPropagation::visitLoadMatrixOp(
+      xegpu::LoadMatrixOp loadMatrixOp, ArrayRef<LayoutInfoLattice *> operands,
+      ArrayRef<const LayoutInfoLattice *> results) {
+    Value resVal = loadMatrixOp.getRes();
+    unsigned index =
+        std::distance(loadMatrixOp.operand_begin(),
+                      llvm::find(loadMatrixOp.getOperands(), operand));
+
+    xegpu::DistributeLayoutAttr anchorLayout = loadMatrixOp.getLayoutAttr();
+    LayoutInfo layout;
+    if (hasParamsOfLayoutKind(anchorLayout)) {
+      layout = LayoutInfo(anchorLayout);
+    } else {
+      VectorType resVecTy = llvm::cast<VectorType>(resVal.getType());
+      assert(resVecTy.getRank() == 2 &&
+             "Expecting 2D vector for store matrix.");
+      auto uArch = getUArch(getChipStr(loadMatrixOp).value_or(""));
+      auto requiredAnchorLayoutAttr =
+          xegpu::loadMatrixSetupAnchorLayout(layoutKind, resVecTy, uArch);
+      loadMatrixOp.setLayoutAttr(requiredAnchorLayoutAttr);
+      layout = LayoutInfo(requiredAnchorLayoutAttr);
+    }
+  }
 
 // Store matrix is a flavor of scattered store for 2D shapes.
-void LayoutInfoPropagation::visitStoreMatrixOp(
-    xegpu::StoreMatrixOp storeMatrix, ArrayRef<LayoutInfoLattice *> operands,
-    ArrayRef<const LayoutInfoLattice *> results) {
-  Value operand = storeMatrix.getData();
-  unsigned index =
-      std::distance(storeMatrix.operand_begin(),
-                    llvm::find(storeMatrix->getOperands(), operand));
+  void LayoutInfoPropagation::visitStoreMatrixOp(
+      xegpu::StoreMatrixOp storeMatrix, ArrayRef<LayoutInfoLattice *> operands,
+      ArrayRef<const LayoutInfoLattice *> results) {
+
+    xegpu::DistributeLayoutAttr anchorLayout = storeMatrix.getLayoutAttr();
+    LayoutInfo layout;
+    if (hasParamsOfLayoutKind(anchorLayout)) {
+      layout = LayoutInfo(anchorLayout);
+    } else {
+      VectorType srcVecTy =
+          llvm::cast<VectorType>(storeMatrix.getData().getType());
+      assert(srcVecTy.getRank() == 2 &&
+             "Expecting 2D vector for store matrix.");
+      auto uArch = getUArch(getChipStr(storeMatrix).value_or(""));
+      auto requiredAnchorLayoutAttr =
+          xegpu::storeMatrixSetupAnchorLayout(layoutKind, srcVecTy, uArch);
+      storeMatrix.setLayoutAttr(requiredAnchorLayoutAttr);
+      layout = LayoutInfo(requiredAnchorLayoutAttr);
+    }
 
-  xegpu::DistributeLayoutAttr anchorLayout = storeMatrix.getLayoutAttr();
-  LayoutInfo layout;
-  if (hasParamsOfLayoutKind(anchorLayout)) {
-    layout = LayoutInfo(anchorLayout);
-  } else {
-    VectorType payloadTy = llvm::cast<VectorType>(operand.getType());
-    assert(payloadTy.getRank() == 2 && "Expecting 2D vector for store matrix.");
-    auto uArch = getUArch(getChipStr(storeMatrix).value_or(""));
-    auto requiredAnchorLayoutAttr =
-        storeMatrixSetupAnchorLayout(layoutKind, payloadTy, uArch);
-    storeMatrix.setLayoutAttr(requiredAnchorLayoutAttr);
-    layout = LayoutInfo(requiredAnchorLayoutAttr);
+    propagateIfChanged(operands[0], operands[0]->meet(layout));
   }
 
-  propagateIfChanged(operands[index], operands[index]->meet(layout));
-}
-
 namespace {
 //===----------------------------------------------------------------------===//
 // RunLayoutInfoPropagation
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp
index d42f0d30ce8c0..fb806d5859c97 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp
@@ -552,10 +552,11 @@ xegpu::inferShapeCastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
 ///    the consumer.
 
 xegpu::SliceAttr xegpu::reductionSetupResultLayout(
-    xegpu::LayoutKind layoutKind, ArrayRef<int64_t> srcShape,
+    xegpu::LayoutKind layoutKind, VectorType srcVecTy,
     DistributeLayoutAttr consumerLayout, SmallVector<int64_t> reductionDims,
     const xegpu::uArch::uArch *uArch) {
 
+  auto srcShape = srcVecTy.getShape();
   xegpu::SliceAttr consumerSliceLayout =
       dyn_cast<xegpu::SliceAttr>(consumerLayout);
 
@@ -782,9 +783,13 @@ xegpu::SliceAttr xegpu::reductionSetupResultLayout(
 }
 
 xegpu::DistributeLayoutAttr xegpu::bitCastSetupResultLayout(
-    xegpu::LayoutKind layoutKind, ArrayRef<int64_t> srcShape,
-    DistributeLayoutAttr consumerLayout, int resElemTyBitWidth,
-    int srcElemTyBitWidth, const xegpu::uArch::uArch *uArch) {
+    xegpu::LayoutKind layoutKind, VectorType srcVecTy, VectorType resVecTy,
+    DistributeLayoutAttr consumerLayout, const xegpu::uArch::uArch *uArch) {
+
+  int srcElemTyBitWidth = srcVecTy.getElementType().getIntOrFloatBitWidth();
+  int resElemTyBitWidth = resVecTy.getElementType().getIntOrFloatBitWidth();
+
+  ArrayRef<int64_t> srcShape = srcVecTy.getShape();
   SmallVector<int64_t> sgData = consumerLayout.getEffectiveSgDataAsInt();
   SmallVector<int64_t> instData = consumerLayout.getEffectiveInstDataAsInt();
   SmallVector<int64_t> laneData = consumerLayout.getEffectiveLaneDataAsInt();
@@ -850,7 +855,7 @@ xegpu::storeMatrixSetupAnchorLayout(xegpu::LayoutKind layoutKind,
   SmallVector<int> instData = {1, uArch->getSubgroupSize()};
   switch (layoutKind) {
   case xegpu::LayoutKind::Subgroup:
-    assert(false &&
+    assert(true &&
            "subgroup layout assignment not supported yet for storeMatrix.");
     break;
   case xegpu::LayoutKind::InstData:
@@ -865,4 +870,87 @@ xegpu::storeMatrixSetupAnchorLayout(xegpu::LayoutKind layoutKind,
     llvm_unreachable("unsupported layout kind");
   }
   return requiredLayout;
+}
+
+xegpu::DistributeLayoutAttr
+xegpu::loadMatrixSetupAnchorLayout(xegpu::LayoutKind layoutKind,
+                                   VectorType vectorTy,
+                                   const xegpu::uArch::uArch *uArch) {
+  xegpu::DistributeLayoutAttr requiredLayout;
+  SmallVector<int> instData = {1, uArch->getSubgroupSize()};
+  switch (layoutKind) {
+  case xegpu::LayoutKind::Subgroup:
+    assert(true &&
+           "subgroup layout assignment not supported yet for loadMatrix.");
+    break;
+  case xegpu::LayoutKind::InstData:
+    requiredLayout = xegpu::LayoutAttr::get(vectorTy.getContext(), instData);
+    break;
+  case xegpu::LayoutKind::Lane:
+    requiredLayout = xegpu::LayoutAttr::get(
+        vectorTy.getContext(), {1, uArch->getSubgroupSize()}, {1, 1});
+    break;
+  default:
+    llvm_unreachable("unsupported layout kind");
+  }
+  return requiredLayout;
+}
+
+static xegpu::DistributeLayoutAttr
+getDefaultSIMTLaneLayoutAttr(mlir::MLIRContext *ctx, unsigned rank,
+                             int subgroupSize) {
+  assert((rank == 1 || rank == 2) && "Expected 1D or 2D vector.");
+  if (rank == 1) {
+    return xegpu::LayoutAttr::get(ctx, {subgroupSize}, {1});
+  }
+  return xegpu::LayoutAttr::get(ctx, {1, subgroupSize}, {1, 1});
+}
+
+xegpu::DistributeLayoutAttr
+xegpu::loadGatherSetupAnchorLayout(LayoutKind layoutKind, VectorType vectorTy,
+                                   const uArch::uArch *uArch) {}
+
+xegpu::DistributeLayoutAttr
+xegpu::storeScatterSetupAnchorLayout(LayoutKind layoutKind, VectorType srcVecTy,
+                                     const uArch::uArch *uArch) {
+
+  xegpu::DistributeLayoutAttr requiredLayout;
+  const int subgroupSize = uArch->getSubgroupSize();
+  const int spirVectorSize = 16; // vector size from SPRIV vector restriction
+
+  auto srcShape = srcVecTy.getShape();
+  int srcShapeSize = srcVecTy.getShape().size();
+
+  SmallVector<int64_t> instData(srcShapeSize);
+  SmallVector<int64_t> laneLayout(srcShapeSize);
+  SmallVector<int64_t> laneData(srcShapeSize);
+
+  const auto *uArchInstruction =
+      dyn_cast<xegpu::uArch::StoreScatterInstruction>(
+          uArch->getInstruction(xegpu::uArch::InstructionKind::StoreScatter));
+
+  switch (layoutKind) {
+  case xegpu::LayoutKind::Subgroup:
+    assert(false &&
+           "subgroup layout assignment not supported yet for loadMatrix.");
+    break;
+  case xegpu::LayoutKind::InstData:
+    assert((srcVecTy.getRank() > 2) && "StoreScatterOp can access 2D tensor "
+                                       "tile at maximum at subgroup level.");
+    if (srcVecTy.getRank() == 1)
+      instData[0] = subgroupSize;
+    else {
+      instData[0] = std::min(srcShape[0], static_cast<int64_t>(spirVectorSize));
+      instData[1] = subgroupSize;
+    }
+    requiredLayout = xegpu::LayoutAttr::get(srcVecTy.getContext(), instData);
+    break;
+  case xegpu::LayoutKind::Lane:
+    requiredLayout = getDefaultSIMTLaneLayoutAttr(
+        srcVecTy.getContext(), srcVecTy.getRank(), subgroupSize);
+    break;
+  default:
+    llvm_unreachable("unsupported layout kind");
+  }
+  return requiredLayout;
 }
\ No newline at end of file

>From 3a4b60fa9b94e8b7fe2e586f6d311f0357db6390 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Tue, 27 Jan 2026 07:03:20 +0000
Subject: [PATCH 12/35] code polish

---
 .../Dialect/XeGPU/Utils/XeGPULayoutUtils.h    |  17 +--
 .../XeGPU/Transforms/XeGPUPropagateLayout.cpp | 122 ++++++++--------
 .../Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp  | 136 ++++++++++++++----
 3 files changed, 182 insertions(+), 93 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h
index fe23510845ae3..c083619292d0c 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h
@@ -117,21 +117,22 @@ DistributeLayoutAttr bitCastSetupResultLayout(
     LayoutKind layoutKind, VectorType srcVectorTy, VectorType resVectorTy,
     DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch);
 
-// Setup the anchor layout attribute for a storeMatrix operation
-DistributeLayoutAttr storeMatrixSetupAnchorLayout(LayoutKind layoutKind,
-                                                  VectorType vectorTy,
-                                                  const uArch::uArch *uArch);
 xegpu::DistributeLayoutAttr
 xegpu::loadMatrixSetupAnchorLayout(LayoutKind layoutKind, VectorType vectorTy,
+                                   xegpu::DistributeLayoutAttr consumerLayout,
                                    const uArch::uArch *uArch);
 
-xegpu::DistributeLayoutAttr
-xegpu::loadGatherSetupAnchorLayout(LayoutKind layoutKind, VectorType vectorTy,
-                                   const uArch::uArch *uArch);
+DistributeLayoutAttr storeMatrixSetupAnchorLayout(LayoutKind layoutKind,
+                                                  VectorType vectorTy,
+                                                  const uArch::uArch *uArch);
+
+xegpu::DistributeLayoutAttr xegpu::loadGatherSetupAnchorLayout(
+    LayoutKind layoutKind, VectorType vectorTy, int chunkSize,
+    DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch);
 
 xegpu::DistributeLayoutAttr
 xegpu::storeScatterSetupAnchorLayout(LayoutKind layoutKind, VectorType vectorTy,
-                                     const uArch::uArch *uArch);
+                                     int chunkSize, const uArch::uArch *uArch);
 } // namespace xegpu
 
 } // namespace mlir
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 6f356f490726b..0236ba3437b6b 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -427,6 +427,14 @@ class LayoutInfoPropagation
                           ArrayRef<LayoutInfoLattice *> operands,
                           ArrayRef<const LayoutInfoLattice *> results);
 
+  void visitLoadGatherOp(xegpu::LoadMatrixOp load,
+                         ArrayRef<LayoutInfoLattice *> operands,
+                         ArrayRef<const LayoutInfoLattice *> results);
+
+  void visitStoreScatterOp(xegpu::StoreMatrixOp store,
+                           ArrayRef<LayoutInfoLattice *> operands,
+                           ArrayRef<const LayoutInfoLattice *> results);
+
   bool hasParamsOfLayoutKind(xegpu::DistributeLayoutAttr anchorLayout);
 
 public:
@@ -1263,19 +1271,17 @@ void LayoutInfoPropagation::visitStoreScatterOp(
   VectorType srcVecTy = storeScatter.getValueType();
   VectorType maskTy =
       llvm::dyn_cast<VectorType>(storeScatter.getMask().getType());
+  int chunkSize = storeScatter.getChunkSize().value_or(1);
 
   if (hasParamsOfLayoutKind(anchorLayout)) {
     requiredAnchorLayoutAttr = LayoutInfo(anchorLayout);
   } else {
-
     if (!srcVecTy) {
       storeScatter.emitWarning("Not propagating, non-vector payload supplied.");
       return;
     }
-    assert(srcVecTy.getRank() == 2 && "Expecting 2D vector for store matrix.");
-    auto uArch = getUArch(getChipStr(storeScatter).value_or(""));
-    requiredAnchorLayoutAttr =
-        xegpu::storeScatterSetupAnchorLayout(layoutKind, srcVecTy, uArch);
+    requiredAnchorLayoutAttr = xegpu::storeScatterSetupAnchorLayout(
+        layoutKind, srcVecTy, chunkSize, uArch);
     storeScatter.setLayoutAttr(requiredAnchorLayoutAttr);
   }
 
@@ -1284,67 +1290,67 @@ void LayoutInfoPropagation::visitStoreScatterOp(
   if (maskTy.getRank() < srcVecTy.getRank()) {
     assert((maskTy.getRank() == (srcVecTy.getRank() - 1)) &&
            "Expecting mask vector only 1 dimension less than value vector.");
-    maskLayoutInfo = LayoutInfo(requiredAnchorLayoutAttr.dropDim(-1));
-
-    // Propagate the payload operand layout
-    propagateIfChanged(operands[0], operands[0]->meet(srcLayoutInfo));
-    // Propagate the destination (if tdesc) operand layout
-    if (isa<xegpu::TensorDescType>(storeScatter.getDestType()))
-      propagateIfChanged(operands[1], operands[1]->meet(srcLayoutInfo));
-    // Propagate the new layout to the mask and optional offset operand.
-    propagateIfChanged(operands[2], operands[2]->meet(maskLayoutInfo));
-    if (storeScatter.getOffsets())
-      propagateIfChanged(operands[3], operands[3]->meet(maskLayoutInfo));
+    // ToDO: Infer the proper mask layout based on the value layout.
+    maskLayoutInfo = LayoutInfo(requiredAnchorLayoutAttr);
   }
+  // Propagate the payload operand layout
+  propagateIfChanged(operands[0], operands[0]->meet(srcLayoutInfo));
+  // Propagate the destination (if tdesc) operand layout
+  if (isa<xegpu::TensorDescType>(storeScatter.getDestType()))
+    propagateIfChanged(operands[1], operands[1]->meet(srcLayoutInfo));
+  // Propagate the new layout to the mask and optional offset operand.
+  propagateIfChanged(operands[2], operands[2]->meet(maskLayoutInfo));
+  if (storeScatter.getOffsets())
+    propagateIfChanged(operands[3], operands[3]->meet(maskLayoutInfo));
+}
 
-  void LayoutInfoPropagation::visitLoadMatrixOp(
-      xegpu::LoadMatrixOp loadMatrixOp, ArrayRef<LayoutInfoLattice *> operands,
-      ArrayRef<const LayoutInfoLattice *> results) {
-    Value resVal = loadMatrixOp.getRes();
-    unsigned index =
-        std::distance(loadMatrixOp.operand_begin(),
-                      llvm::find(loadMatrixOp.getOperands(), operand));
-
-    xegpu::DistributeLayoutAttr anchorLayout = loadMatrixOp.getLayoutAttr();
-    LayoutInfo layout;
-    if (hasParamsOfLayoutKind(anchorLayout)) {
-      layout = LayoutInfo(anchorLayout);
-    } else {
-      VectorType resVecTy = llvm::cast<VectorType>(resVal.getType());
-      assert(resVecTy.getRank() == 2 &&
-             "Expecting 2D vector for store matrix.");
-      auto uArch = getUArch(getChipStr(loadMatrixOp).value_or(""));
-      auto requiredAnchorLayoutAttr =
-          xegpu::loadMatrixSetupAnchorLayout(layoutKind, resVecTy, uArch);
-      loadMatrixOp.setLayoutAttr(requiredAnchorLayoutAttr);
-      layout = LayoutInfo(requiredAnchorLayoutAttr);
-    }
+void LayoutInfoPropagation::visitLoadMatrixOp(
+    xegpu::LoadMatrixOp loadMatrixOp, ArrayRef<LayoutInfoLattice *> operands,
+    ArrayRef<const LayoutInfoLattice *> results) {
+
+  LayoutInfo resLayoutInfo = results[0]->getValue();
+  if (!resLayoutInfo.isAssigned())
+    return;
+  auto consumerLayoutAttr =
+      dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
+
+  xegpu::DistributeLayoutAttr anchorLayout = loadMatrixOp.getLayoutAttr();
+
+  // only need to set anchor layout, no need to porpagate to memdesc and offset
+  if (!hasParamsOfLayoutKind(anchorLayout)) {
+    VectorType resVecTy =
+        llvm::cast<VectorType>(loadMatrixOp.getRes().getType());
+    assert(resVecTy.getRank() == 2 && "Expecting 2D vector for store matrix.");
+    auto uArch = getUArch(getChipStr(loadMatrixOp).value_or(""));
+    auto requiredAnchorLayoutAttr = xegpu::loadMatrixSetupAnchorLayout(
+        layoutKind, resVecTy, consumerLayoutAttr, uArch);
+    loadMatrixOp.setLayoutAttr(requiredAnchorLayoutAttr);
   }
+}
 
 // Store matrix is a flavor of scattered store for 2D shapes.
-  void LayoutInfoPropagation::visitStoreMatrixOp(
-      xegpu::StoreMatrixOp storeMatrix, ArrayRef<LayoutInfoLattice *> operands,
-      ArrayRef<const LayoutInfoLattice *> results) {
-
-    xegpu::DistributeLayoutAttr anchorLayout = storeMatrix.getLayoutAttr();
-    LayoutInfo layout;
-    if (hasParamsOfLayoutKind(anchorLayout)) {
-      layout = LayoutInfo(anchorLayout);
-    } else {
-      VectorType srcVecTy =
-          llvm::cast<VectorType>(storeMatrix.getData().getType());
-      assert(srcVecTy.getRank() == 2 &&
-             "Expecting 2D vector for store matrix.");
-      auto uArch = getUArch(getChipStr(storeMatrix).value_or(""));
-      auto requiredAnchorLayoutAttr =
-          xegpu::storeMatrixSetupAnchorLayout(layoutKind, srcVecTy, uArch);
-      storeMatrix.setLayoutAttr(requiredAnchorLayoutAttr);
-      layout = LayoutInfo(requiredAnchorLayoutAttr);
-    }
+void LayoutInfoPropagation::visitStoreMatrixOp(
+    xegpu::StoreMatrixOp storeMatrix, ArrayRef<LayoutInfoLattice *> operands,
+    ArrayRef<const LayoutInfoLattice *> results) {
 
-    propagateIfChanged(operands[0], operands[0]->meet(layout));
+  xegpu::DistributeLayoutAttr anchorLayout = storeMatrix.getLayoutAttr();
+  LayoutInfo layout;
+  if (hasParamsOfLayoutKind(anchorLayout)) {
+    layout = LayoutInfo(anchorLayout);
+  } else {
+    VectorType srcVecTy =
+        llvm::cast<VectorType>(storeMatrix.getData().getType());
+    assert(srcVecTy.getRank() == 2 && "Expecting 2D vector for store matrix.");
+    auto uArch = getUArch(getChipStr(storeMatrix).value_or(""));
+    auto requiredAnchorLayoutAttr =
+        xegpu::storeMatrixSetupAnchorLayout(layoutKind, srcVecTy, uArch);
+    storeMatrix.setLayoutAttr(requiredAnchorLayoutAttr);
+    layout = LayoutInfo(requiredAnchorLayoutAttr);
   }
 
+  propagateIfChanged(operands[0], operands[0]->meet(layout));
+}
+
 namespace {
 //===----------------------------------------------------------------------===//
 // RunLayoutInfoPropagation
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp
index fb806d5859c97..3968f0f2db0af 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp
@@ -875,20 +875,33 @@ xegpu::storeMatrixSetupAnchorLayout(xegpu::LayoutKind layoutKind,
 xegpu::DistributeLayoutAttr
 xegpu::loadMatrixSetupAnchorLayout(xegpu::LayoutKind layoutKind,
                                    VectorType vectorTy,
+                                   xegpu::DistributeLayoutAttr consumerLayout,
                                    const xegpu::uArch::uArch *uArch) {
   xegpu::DistributeLayoutAttr requiredLayout;
   SmallVector<int> instData = {1, uArch->getSubgroupSize()};
+
+  SmallVector<int64_t> consumerSgLayout =
+      consumerLayout.getEffectiveSgLayoutAsInt();
+  SmallVector<int64_t> consumerSgData =
+      consumerLayout.getEffectiveSgDataAsInt();
+  SmallVector<int32_t> sgLayout32(consumerSgLayout.begin(),
+                                  consumerSgLayout.end());
+  SmallVector<int32_t> sgData32(consumerSgData.begin(), consumerSgData.end());
+
+  auto context = vectorTy.getContext();
+
   switch (layoutKind) {
   case xegpu::LayoutKind::Subgroup:
-    assert(true &&
-           "subgroup layout assignment not supported yet for loadMatrix.");
+    requiredLayout = xegpu::LayoutAttr::get(
+        context, DenseI32ArrayAttr::get(context, sgLayout32),
+        DenseI32ArrayAttr::get(context, sgData32), consumerLayout.getOrder());
     break;
   case xegpu::LayoutKind::InstData:
-    requiredLayout = xegpu::LayoutAttr::get(vectorTy.getContext(), instData);
+    requiredLayout = xegpu::LayoutAttr::get(context, instData);
     break;
   case xegpu::LayoutKind::Lane:
-    requiredLayout = xegpu::LayoutAttr::get(
-        vectorTy.getContext(), {1, uArch->getSubgroupSize()}, {1, 1});
+    requiredLayout =
+        xegpu::LayoutAttr::get(context, {1, uArch->getSubgroupSize()}, {1, 1});
     break;
   default:
     llvm_unreachable("unsupported layout kind");
@@ -897,37 +910,99 @@ xegpu::loadMatrixSetupAnchorLayout(xegpu::LayoutKind layoutKind,
 }
 
 static xegpu::DistributeLayoutAttr
-getDefaultSIMTLaneLayoutAttr(mlir::MLIRContext *ctx, unsigned rank,
-                             int subgroupSize) {
+getDefaultLaneLayoutAttr(mlir::MLIRContext *ctx, unsigned rank,
+                         const xegpu::uArch::uArch *uArch) {
   assert((rank == 1 || rank == 2) && "Expected 1D or 2D vector.");
   if (rank == 1) {
-    return xegpu::LayoutAttr::get(ctx, {subgroupSize}, {1});
+    return xegpu::LayoutAttr::get(ctx, {uArch->getSubgroupSize()}, {1});
   }
-  return xegpu::LayoutAttr::get(ctx, {1, subgroupSize}, {1, 1});
+  return xegpu::LayoutAttr::get(ctx, {1, uArch->getSubgroupSize()}, {1, 1});
 }
 
-xegpu::DistributeLayoutAttr
-xegpu::loadGatherSetupAnchorLayout(LayoutKind layoutKind, VectorType vectorTy,
-                                   const uArch::uArch *uArch) {}
+xegpu::DistributeLayoutAttr xegpu::loadGatherSetupAnchorLayout(
+    LayoutKind layoutKind, VectorType resVecTy, int chunkSize,
+    DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch) {
+
+  xegpu::DistributeLayoutAttr requiredLayout;
+  const int subgroupSize = uArch->getSubgroupSize();
+  const int spirVectorSize = 16; // vector size from SPRIV vector restriction
+
+  auto resShape = resVecTy.getShape();
+  int resShapeSize = resShape.size();
+  SmallVector<int64_t> instData(subgroupSize);
+  auto context = resVecTy.getContext();
+
+  const auto *uArchInstruction =
+      dyn_cast<xegpu::uArch::StoreScatterInstruction>(
+          uArch->getInstruction(xegpu::uArch::InstructionKind::StoreScatter));
+
+  SmallVector<int64_t> consumerSgLayout =
+      consumerLayout.getEffectiveSgLayoutAsInt();
+  SmallVector<int64_t> consumerSgData =
+      consumerLayout.getEffectiveSgDataAsInt();
+  SmallVector<int32_t> sgLayout32(consumerSgLayout.begin(),
+                                  consumerSgLayout.end());
+  SmallVector<int32_t> sgData32(consumerSgData.begin(), consumerSgData.end());
+
+  SmallVector<int64_t> consumerInstData =
+      consumerLayout.getEffectiveInstDataAsInt();
+  SmallVector<int32_t> instData32;
+
+  switch (layoutKind) {
+  case xegpu::LayoutKind::Subgroup:
+    requiredLayout = xegpu::LayoutAttr::get(
+        context, DenseI32ArrayAttr::get(context, sgLayout32),
+        DenseI32ArrayAttr::get(context, sgData32), consumerLayout.getOrder());
+    break;
+  case xegpu::LayoutKind::InstData:
+    if (resVecTy.getRank() == 1) {
+      instData[0] = subgroupSize;
+    } else {
+      assert((resVecTy.getRank() > 2) && "StoreScatterOp can access 2D tensor "
+                                         "tile at maximum at subgroup level.");
+      if (chunkSize == 1) {
+        instData[0] =
+            std::min(resShape[0], static_cast<int64_t>(spirVectorSize));
+        instData[0] = std::min(instData[0], consumerInstData[0]);
+        instData[1] = subgroupSize;
+      } else {
+        instData[0] = subgroupSize;
+        instData[1] = std::min(
+            resShape[1],
+            static_cast<int64_t>(uArchInstruction->getMaxLaneLoadStoreSize()));
+        instData[1] = std::min(instData[1], consumerInstData[1]);
+      }
+    }
+    instData32 = SmallVector<int32_t>(instData.begin(), instData.end());
+    requiredLayout = xegpu::LayoutAttr::get(
+        context, DenseI32ArrayAttr::get(context, instData32));
+    break;
+  case xegpu::LayoutKind::Lane:
+    requiredLayout =
+        getDefaultLaneLayoutAttr(context, resVecTy.getRank(), uArch);
+    break;
+  default:
+    llvm_unreachable("unsupported layout kind");
+  }
+  return requiredLayout;
+}
 
 xegpu::DistributeLayoutAttr
 xegpu::storeScatterSetupAnchorLayout(LayoutKind layoutKind, VectorType srcVecTy,
-                                     const uArch::uArch *uArch) {
+                                     int chunkSize, const uArch::uArch *uArch) {
 
   xegpu::DistributeLayoutAttr requiredLayout;
   const int subgroupSize = uArch->getSubgroupSize();
   const int spirVectorSize = 16; // vector size from SPRIV vector restriction
 
   auto srcShape = srcVecTy.getShape();
-  int srcShapeSize = srcVecTy.getShape().size();
-
-  SmallVector<int64_t> instData(srcShapeSize);
-  SmallVector<int64_t> laneLayout(srcShapeSize);
-  SmallVector<int64_t> laneData(srcShapeSize);
+  int srcShapeSize = srcShape.size();
+  SmallVector<int64_t> instData(subgroupSize);
 
   const auto *uArchInstruction =
       dyn_cast<xegpu::uArch::StoreScatterInstruction>(
           uArch->getInstruction(xegpu::uArch::InstructionKind::StoreScatter));
+  auto context = srcVecTy.getContext();
 
   switch (layoutKind) {
   case xegpu::LayoutKind::Subgroup:
@@ -935,19 +1010,26 @@ xegpu::storeScatterSetupAnchorLayout(LayoutKind layoutKind, VectorType srcVecTy,
            "subgroup layout assignment not supported yet for loadMatrix.");
     break;
   case xegpu::LayoutKind::InstData:
-    assert((srcVecTy.getRank() > 2) && "StoreScatterOp can access 2D tensor "
-                                       "tile at maximum at subgroup level.");
-    if (srcVecTy.getRank() == 1)
+    if (srcVecTy.getRank() == 1) {
       instData[0] = subgroupSize;
-    else {
-      instData[0] = std::min(srcShape[0], static_cast<int64_t>(spirVectorSize));
-      instData[1] = subgroupSize;
+    } else {
+      assert((srcVecTy.getRank() > 2) && "StoreScatterOp can access 2D tensor "
+                                         "tile at maximum at subgroup level.");
+      if (chunkSize == 1) {
+        instData[0] =
+            std::min(srcShape[0], static_cast<int64_t>(spirVectorSize));
+        instData[1] = subgroupSize;
+      } else {
+        instData[0] = subgroupSize;
+        instData[1] = std::min(
+            srcShape[1],
+            static_cast<int64_t>(uArchInstruction->getMaxLaneLoadStoreSize()));
+      }
     }
-    requiredLayout = xegpu::LayoutAttr::get(srcVecTy.getContext(), instData);
     break;
   case xegpu::LayoutKind::Lane:
-    requiredLayout = getDefaultSIMTLaneLayoutAttr(
-        srcVecTy.getContext(), srcVecTy.getRank(), subgroupSize);
+    requiredLayout =
+        getDefaultLaneLayoutAttr(context, srcVecTy.getRank(), uArch);
     break;
   default:
     llvm_unreachable("unsupported layout kind");

>From ef5a279afb8394e0bdfd9c46de48e4fb4324ba67 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Tue, 27 Jan 2026 20:08:28 +0000
Subject: [PATCH 13/35] add scatter IO mask handling

---
 .../Dialect/XeGPU/Utils/XeGPULayoutUtils.h    |  26 ++--
 mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp    |  52 +++----
 .../XeGPU/Transforms/XeGPUPropagateLayout.cpp | 135 +++++++-----------
 .../Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp  | 130 ++++++++++++-----
 4 files changed, 184 insertions(+), 159 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h
index c083619292d0c..bba2b1e06129a 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h
@@ -90,6 +90,12 @@ DistributeLayoutAttr inferShapeCastSourceLayout(DistributeLayoutAttr resLayout,
                                                 ArrayRef<int64_t> resShape,
                                                 ArrayRef<int64_t> srcShape);
 
+/// Infers the source layout attribute for mask operand of scatter IO operation
+/// given the result layout attribute, value shape, and mask shape.
+DistributeLayoutAttr inferScatterIOMaskLayout(DistributeLayoutAttr resLayout,
+                                              ArrayRef<int64_t> valShape,
+                                              ArrayRef<int64_t> maskShape);
+
 /// Sets up layout for reduction operations by creating a SliceAttr for the
 /// result.
 ///
@@ -98,11 +104,11 @@ DistributeLayoutAttr inferShapeCastSourceLayout(DistributeLayoutAttr resLayout,
 /// consumer's preferred layout. This minimizes data redistribution overhead.
 /// The SliceAttr for the result is then created based on the derived source
 /// layout and the specified reduction dimensions.
-SliceAttr reductionSetupResultLayout(xegpu::LayoutKind layoutKind,
-                                     VectorType srcVectorTy,
-                                     DistributeLayoutAttr consumerLayout,
-                                     SmallVector<int64_t> reductionDims,
-                                     const uArch::uArch *uArch);
+SliceAttr setupMultiReductionResultLayout(xegpu::LayoutKind layoutKind,
+                                          VectorType srcVectorTy,
+                                          DistributeLayoutAttr consumerLayout,
+                                          SmallVector<int64_t> reductionDims,
+                                          const uArch::uArch *uArch);
 
 /// Setup the result layout attribute for a bitcast operation based on element
 /// type bitwidths. This ensures the source layout can always be derived from
@@ -113,25 +119,25 @@ SliceAttr reductionSetupResultLayout(xegpu::LayoutKind layoutKind,
 /// (inst_data, lane_data) are scaled up by the bitwidth ratio. This
 /// maintains the invariant that the source layout can be recovered by inverse
 /// scaling during layout inference.
-DistributeLayoutAttr bitCastSetupResultLayout(
+DistributeLayoutAttr setupBitCastResultLayout(
     LayoutKind layoutKind, VectorType srcVectorTy, VectorType resVectorTy,
     DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch);
 
 xegpu::DistributeLayoutAttr
-xegpu::loadMatrixSetupAnchorLayout(LayoutKind layoutKind, VectorType vectorTy,
+xegpu::setupLoadMatrixAnchorLayout(LayoutKind layoutKind, VectorType vectorTy,
                                    xegpu::DistributeLayoutAttr consumerLayout,
                                    const uArch::uArch *uArch);
 
-DistributeLayoutAttr storeMatrixSetupAnchorLayout(LayoutKind layoutKind,
+DistributeLayoutAttr setupStoreMatrixAnchorLayout(LayoutKind layoutKind,
                                                   VectorType vectorTy,
                                                   const uArch::uArch *uArch);
 
-xegpu::DistributeLayoutAttr xegpu::loadGatherSetupAnchorLayout(
+xegpu::DistributeLayoutAttr xegpu::setupLoadGatherAnchorLayout(
     LayoutKind layoutKind, VectorType vectorTy, int chunkSize,
     DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch);
 
 xegpu::DistributeLayoutAttr
-xegpu::storeScatterSetupAnchorLayout(LayoutKind layoutKind, VectorType vectorTy,
+xegpu::setupStoreScatterAnchorLayout(LayoutKind layoutKind, VectorType vectorTy,
                                      int chunkSize, const uArch::uArch *uArch);
 } // namespace xegpu
 
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 5bf57376cbbaf..78d1ad50dd1dd 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -513,15 +513,11 @@ LayoutAttr::collapseDims(SmallVector<SmallVector<int64_t>> dimGroups) const {
   SmallVector<int64_t> laneData = getEffectiveLaneDataAsInt();
 
   DenseI32ArrayAttr orderAttr = getOrder();
-  SmallVector<int64_t> order;
+  SmallVector<int64_t> orderVec;
   if (orderAttr && !orderAttr.empty()) {
-    order = llvm::to_vector(
+    orderVec = llvm::to_vector(
         llvm::map_range(orderAttr.asArrayRef(),
                         [](int32_t idx) { return static_cast<int64_t>(idx); }));
-  } else {
-    // Default order: [1, 0] for 2D (row-major), [2, 1, 0] for 3D, etc.
-    order =
-        llvm::to_vector(llvm::reverse(llvm::seq<int64_t>(0, sgLayout.size())));
   }
 
   SmallVector<int64_t> collapsedSgLayout;
@@ -555,22 +551,6 @@ LayoutAttr::collapseDims(SmallVector<SmallVector<int64_t>> dimGroups) const {
     collapsedOrder.push_back(collapsedOrderValue);
   }
 
-  // go through the values inside collapsedOrder, and re-map the order values to
-  // be in range of [0, N-1] where N is the number of dimensions in collapsed
-  // shape
-  int64_t orderSize = static_cast<int64_t>(collapsedOrder.size());
-  SmallVector<int64_t> remappedOrder(orderSize, -1);
-  for (int64_t i = 0; i < orderSize; ++i) {
-    int64_t originalOrderValue = collapsedOrder[i];
-    // count how many values in collapsedOrder are less than originalOrderValue
-    int64_t count = 0;
-    for (int64_t j = 0; j < orderSize; ++j) {
-      if (collapsedOrder[j] < originalOrderValue)
-        count++;
-    }
-    remappedOrder[i] = count;
-  }
-
   // Create collapsed layout
   SmallVector<int32_t> collapsedSgLayout32(collapsedSgLayout.begin(),
                                            collapsedSgLayout.end());
@@ -582,8 +562,28 @@ LayoutAttr::collapseDims(SmallVector<SmallVector<int64_t>> dimGroups) const {
                                              collapsedLaneLayout.end());
   SmallVector<int32_t> collapsedLaneData32(collapsedLaneData.begin(),
                                            collapsedLaneData.end());
-  SmallVector<int32_t> remappedOrder32(remappedOrder.begin(),
-                                       remappedOrder.end());
+
+  // go through the values inside collapsedOrder, and re-map the order values to
+  // be in range of [0, N-1] where N is the number of dimensions in collapsed
+  // shape
+  SmallVector<int32_t> remappedOrder32;
+  if (!orderVec.empty()) {
+    int64_t orderSize = static_cast<int64_t>(collapsedOrder.size());
+    SmallVector<int64_t> remappedOrder(orderSize, -1);
+    for (int64_t i = 0; i < orderSize; ++i) {
+      int64_t originalOrderValue = collapsedOrder[i];
+      // count how many values in collapsedOrder are less than
+      // originalOrderValue
+      int64_t count = 0;
+      for (int64_t j = 0; j < orderSize; ++j) {
+        if (collapsedOrder[j] < originalOrderValue)
+          count++;
+      }
+      remappedOrder[i] = count;
+    }
+    remappedOrder32 =
+        SmallVector<int32_t>(remappedOrder.begin(), remappedOrder.end());
+  }
 
   auto collapsedLayout = xegpu::LayoutAttr::get(
       getContext(), DenseI32ArrayAttr::get(getContext(), collapsedSgLayout32),
@@ -591,7 +591,9 @@ LayoutAttr::collapseDims(SmallVector<SmallVector<int64_t>> dimGroups) const {
       DenseI32ArrayAttr::get(getContext(), collapsedInstData32),
       DenseI32ArrayAttr::get(getContext(), collapsedLaneLayout32),
       DenseI32ArrayAttr::get(getContext(), collapsedLaneData32),
-      DenseI32ArrayAttr::get(getContext(), remappedOrder32));
+      remappedOrder32.empty()
+          ? nullptr
+          : DenseI32ArrayAttr::get(getContext(), remappedOrder32));
   return collapsedLayout;
 }
 
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 0236ba3437b6b..e6fa8bfd12474 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -1150,88 +1150,43 @@ void LayoutInfoPropagation::visitLoadGatherOp(
     xegpu::LoadGatherOp load, ArrayRef<LayoutInfoLattice *> operands,
     ArrayRef<const LayoutInfoLattice *> results) {
 
-  LayoutInfo loadLayout;
-  LayoutInfo maskLayout;
+  xegpu::DistributeLayoutAttr requiredAnchorLayoutAttr;
+  xegpu::DistributeLayoutAttr anchorLayoutAttr = load.getLayoutAttr();
   auto uArch = getUArch(getChipStr(load).value_or(""));
-  const int subgroupSize = uArch->getSubgroupSize();
-  xegpu::DistributeLayoutAttr anchorLayout = load.getLayoutAttr();
-  if (hasParamsOfLayoutKind(anchorLayout)) {
-    loadLayout = LayoutInfo(anchorLayout);
-    maskLayout = loadLayout;
-  } else {
-    LayoutInfo valueLayout = results[0]->getValue();
-    // Need the layout of the value to propagate to the tensor descriptor.
-    if (!valueLayout.isAssigned())
-      return;
+  auto subgroupSize = uArch->getSubgroupSize();
+  VectorType resVecTy = load.getValueType();
+  VectorType maskTy = llvm::dyn_cast<VectorType>(load.getMask().getType());
+  int chunkSize = load.getChunkSize().value_or(1);
 
-    auto resAttr = dyn_cast<xegpu::DistributeLayoutAttr>(valueLayout.get());
-    auto instDataIncoming = resAttr.getEffectiveInstDataAsInt();
-    if (auto sliceAttr = dyn_cast<xegpu::SliceAttr>(resAttr))
-      instDataIncoming = SmallVector<int64_t>(
-          cast<xegpu::LayoutAttr>(sliceAttr.flatten().getParent())
-              .getInstData()
-              .asArrayRef());
-
-    VectorType payloadTy = load.getValueType();
-    if (!payloadTy) {
+  if (hasParamsOfLayoutKind(anchorLayoutAttr)) {
+    requiredAnchorLayoutAttr = anchorLayoutAttr;
+  } else {
+    if (!resVecTy) {
       load.emitWarning("Not propagating, non-vector payload supplied.");
       return;
     }
-    const auto *uArchInstruction =
-        dyn_cast<xegpu::uArch::LoadGatherInstruction>(
-            uArch->getInstruction(xegpu::uArch::InstructionKind::LoadGather));
-
-    // Check if value inst_data complies with uArch
-    if (layoutKind == xegpu::LayoutKind::InstData) {
-      // Each lane loads either one element
-      SmallVector<int> instDataUarch{subgroupSize};
-      // Or multiple elements as 2D with lane's elements in the inner dimension
-      if (payloadTy.getRank() != 1) {
-        if (payloadTy.getRank() != 2) {
-          load.emitWarning("Expected 2D payload for LoadGatherOp.");
-          return;
-        }
-        instDataUarch.push_back(
-            (std::min(static_cast<int>(payloadTy.getShape().back()),
-                      uArchInstruction->getMaxLaneLoadStoreSize())));
-      }
-      // If inst data does not match, enforce the uArch-based one
-      if (!llvm::equal(instDataIncoming, instDataUarch)) {
-        xegpu::LayoutAttr sourceAttr = dyn_cast<xegpu::LayoutAttr>(resAttr);
-        if (auto sliceAttr = dyn_cast<xegpu::SliceAttr>(resAttr)) {
-          sourceAttr = cast<xegpu::LayoutAttr>(sliceAttr.flatten().getParent());
-        }
-        assert(sourceAttr);
-        xegpu::DistributeLayoutAttr updatedLayoutAttr = xegpu::LayoutAttr::get(
-            load.getContext(), sourceAttr.getSgLayout(), sourceAttr.getSgData(),
-            DenseI32ArrayAttr::get(load.getContext(), instDataUarch),
-            sourceAttr.getLaneLayout(), sourceAttr.getLaneData(),
-            sourceAttr.getOrder());
-
-        if (auto sliceAttr = dyn_cast<xegpu::SliceAttr>(resAttr))
-          updatedLayoutAttr = xegpu::SliceAttr::get(
-              load.getContext(), updatedLayoutAttr, sliceAttr.getDims());
-        valueLayout = LayoutInfo(updatedLayoutAttr);
-      }
-    }
-    loadLayout = valueLayout;
-    load.setLayoutAttr(dyn_cast<xegpu::DistributeLayoutAttr>(loadLayout.get()));
+    requiredAnchorLayoutAttr = xegpu::setupLoadGatherAnchorLayout(
+        layoutKind, resVecTy, chunkSize, uArch);
+    load.setLayoutAttr(requiredAnchorLayoutAttr);
   }
 
-  // If no user-defined anchor or we deal with a chunked op, set the default
-  // mask layout.
-  // Rank 1 data : Keep the mask layout aligned with data.
-  // Rank >1 data: Enforce the default xegpu 1D layout for mask.
-  if (!hasParamsOfLayoutKind(anchorLayout) ||
-      load.getValueType().getRank() > 1) {
+  auto maskLayoutAttr = anchorLayoutAttr;
+  // Special handling mask layout for chunked ops: Enforce the default xegpu 1D
+  // layout for mask.
+  if (chunkSize > 1) {
     if (layoutKind == xegpu::LayoutKind::InstData)
-      maskLayout = LayoutInfo(
-          xegpu::LayoutAttr::get(load->getContext(), {subgroupSize}));
+      maskLayoutAttr =
+          xegpu::LayoutAttr::get(load->getContext(), {subgroupSize});
     else if (layoutKind == xegpu::LayoutKind::Lane)
-      maskLayout =
-          getDefaultSIMTLayoutInfo(load->getContext(), 1, subgroupSize);
+      maskLayoutAttr =
+          xegpu::LayoutAttr::get(load->getContext(), {subgroupSize}, {1});
+    else
+      assert(false &&
+             "chunked StoreScatterOp should not be used at workgroup level");
   }
 
+  LayoutInfo maskLayoutInfo = LayoutInfo(maskLayoutAttr);
+
   // Propagate the new layout to the tensor descriptor operand.
   if (isa<xegpu::TensorDescType>(load.getSourceType()))
     propagateIfChanged(operands[0], operands[0]->meet(loadLayout));
@@ -1263,36 +1218,45 @@ void LayoutInfoPropagation::visitStoreScatterOp(
     xegpu::StoreScatterOp storeScatter, ArrayRef<LayoutInfoLattice *> operands,
     ArrayRef<const LayoutInfoLattice *> results) {
 
-  LayoutInfo srcLayoutInfo;
-  LayoutInfo maskLayoutInfo;
   xegpu::DistributeLayoutAttr requiredAnchorLayoutAttr;
-  xegpu::DistributeLayoutAttr anchorLayout = storeScatter.getLayoutAttr();
+  xegpu::DistributeLayoutAttr anchorLayoutAttr = storeScatter.getLayoutAttr();
   auto uArch = getUArch(getChipStr(storeScatter).value_or(""));
+  auto subgroupSize = uArch->getSubgroupSize();
   VectorType srcVecTy = storeScatter.getValueType();
   VectorType maskTy =
       llvm::dyn_cast<VectorType>(storeScatter.getMask().getType());
   int chunkSize = storeScatter.getChunkSize().value_or(1);
 
-  if (hasParamsOfLayoutKind(anchorLayout)) {
-    requiredAnchorLayoutAttr = LayoutInfo(anchorLayout);
+  if (hasParamsOfLayoutKind(anchorLayoutAttr)) {
+    requiredAnchorLayoutAttr = anchorLayoutAttr;
   } else {
     if (!srcVecTy) {
       storeScatter.emitWarning("Not propagating, non-vector payload supplied.");
       return;
     }
-    requiredAnchorLayoutAttr = xegpu::storeScatterSetupAnchorLayout(
+    requiredAnchorLayoutAttr = xegpu::setupStoreScatterAnchorLayout(
         layoutKind, srcVecTy, chunkSize, uArch);
     storeScatter.setLayoutAttr(requiredAnchorLayoutAttr);
   }
 
-  srcLayoutInfo = LayoutInfo(requiredAnchorLayoutAttr);
-  maskLayoutInfo = srcLayoutInfo;
-  if (maskTy.getRank() < srcVecTy.getRank()) {
-    assert((maskTy.getRank() == (srcVecTy.getRank() - 1)) &&
-           "Expecting mask vector only 1 dimension less than value vector.");
-    // ToDO: Infer the proper mask layout based on the value layout.
-    maskLayoutInfo = LayoutInfo(requiredAnchorLayoutAttr);
+  LayoutInfo srcLayoutInfo = LayoutInfo(requiredAnchorLayoutAttr);
+  auto maskLayoutAttr = anchorLayoutAttr;
+  // Special handling mask layout for chunked ops: Enforce the default xegpu 1D
+  // layout for mask.
+  if (chunkSize > 1) {
+    if (layoutKind == xegpu::LayoutKind::InstData)
+      maskLayoutAttr =
+          xegpu::LayoutAttr::get(storeScatter->getContext(), {subgroupSize});
+    else if (layoutKind == xegpu::LayoutKind::Lane)
+      maskLayoutAttr = xegpu::LayoutAttr::get(storeScatter->getContext(),
+                                              {subgroupSize}, {1});
+    else
+      assert(false &&
+             "chunked StoreScatterOp should not be used at workgroup level");
   }
+
+  LayoutInfo maskLayoutInfo = LayoutInfo(maskLayoutAttr);
+
   // Propagate the payload operand layout
   propagateIfChanged(operands[0], operands[0]->meet(srcLayoutInfo));
   // Propagate the destination (if tdesc) operand layout
@@ -1316,7 +1280,8 @@ void LayoutInfoPropagation::visitLoadMatrixOp(
 
   xegpu::DistributeLayoutAttr anchorLayout = loadMatrixOp.getLayoutAttr();
 
-  // only need to set anchor layout, no need to porpagate to memdesc and offset
+  // only need to set anchor layout, no need to porpagate to memdesc and
+  // offset
   if (!hasParamsOfLayoutKind(anchorLayout)) {
     VectorType resVecTy =
         llvm::cast<VectorType>(loadMatrixOp.getRes().getType());
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp
index 3968f0f2db0af..8d7d1b96b363e 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp
@@ -532,6 +532,39 @@ xegpu::inferShapeCastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
   return nullptr;
 }
 
+/// Infers the layout of the mask operand of a scatter IO operation
+/// (load_gather / store_scatter) from the layout of its value operand.
+///
+/// If the mask has the same rank as the value, the value layout is returned
+/// as-is. If the mask has exactly one dimension fewer (chunked access), the
+/// innermost value dimension is removed from the layout via setUnitDimData()
+/// followed by collapseDim(). A null `resLayout` yields a null result.
+xegpu::DistributeLayoutAttr
+xegpu::inferScatterIOMaskLayout(xegpu::DistributeLayoutAttr resLayout,
+                                ArrayRef<int64_t> valShape,
+                                ArrayRef<int64_t> maskShape) {
+  if (!resLayout)
+    return nullptr;
+
+  auto resPlainLayout = dyn_cast<xegpu::LayoutAttr>(resLayout);
+  assert(resPlainLayout &&
+         "Expecting plain layout for scatter IO mask inference.");
+
+  // Same rank: the mask is distributed exactly like the value.
+  if (maskShape.size() >= valShape.size())
+    return resPlainLayout;
+
+  assert(maskShape.size() + 1 == valShape.size() &&
+         "Expecting mask vector only 1 dimension less than value vector.");
+
+  // The innermost value dimension has no counterpart in the mask, so it is
+  // reduced to unit data and then folded out of the layout.
+  const int innermostDim = static_cast<int>(valShape.size()) - 1;
+  xegpu::LayoutAttr maskLayout = resPlainLayout.setUnitDimData(innermostDim);
+  maskLayout = maskLayout.collapseDim(innermostDim);
+  return maskLayout;
+}
+
 /// Sets up layout for reduction operations by creating a SliceAttr for the
 /// result.
 ///
@@ -551,7 +606,7 @@ xegpu::inferShapeCastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
 ///    reduced result stays with the same subgroup distribution as expected by
 ///    the consumer.
 
-xegpu::SliceAttr xegpu::reductionSetupResultLayout(
+xegpu::SliceAttr xegpu::setupMultiReductionResultLayout(
     xegpu::LayoutKind layoutKind, VectorType srcVecTy,
     DistributeLayoutAttr consumerLayout, SmallVector<int64_t> reductionDims,
     const xegpu::uArch::uArch *uArch) {
@@ -782,7 +837,7 @@ xegpu::SliceAttr xegpu::reductionSetupResultLayout(
   return resLayout;
 }
 
-xegpu::DistributeLayoutAttr xegpu::bitCastSetupResultLayout(
+xegpu::DistributeLayoutAttr xegpu::setupBitCastResultLayout(
     xegpu::LayoutKind layoutKind, VectorType srcVecTy, VectorType resVecTy,
     DistributeLayoutAttr consumerLayout, const xegpu::uArch::uArch *uArch) {
 
@@ -847,7 +902,7 @@ xegpu::DistributeLayoutAttr xegpu::bitCastSetupResultLayout(
 }
 
 xegpu::DistributeLayoutAttr
-xegpu::storeMatrixSetupAnchorLayout(xegpu::LayoutKind layoutKind,
+xegpu::setupStoreMatrixAnchorLayout(xegpu::LayoutKind layoutKind,
                                     VectorType vectorTy,
                                     const xegpu::uArch::uArch *uArch) {
 
@@ -873,28 +928,18 @@ xegpu::storeMatrixSetupAnchorLayout(xegpu::LayoutKind layoutKind,
 }
 
 xegpu::DistributeLayoutAttr
-xegpu::loadMatrixSetupAnchorLayout(xegpu::LayoutKind layoutKind,
+xegpu::setupLoadMatrixAnchorLayout(xegpu::LayoutKind layoutKind,
                                    VectorType vectorTy,
                                    xegpu::DistributeLayoutAttr consumerLayout,
                                    const xegpu::uArch::uArch *uArch) {
   xegpu::DistributeLayoutAttr requiredLayout;
   SmallVector<int> instData = {1, uArch->getSubgroupSize()};
 
-  SmallVector<int64_t> consumerSgLayout =
-      consumerLayout.getEffectiveSgLayoutAsInt();
-  SmallVector<int64_t> consumerSgData =
-      consumerLayout.getEffectiveSgDataAsInt();
-  SmallVector<int32_t> sgLayout32(consumerSgLayout.begin(),
-                                  consumerSgLayout.end());
-  SmallVector<int32_t> sgData32(consumerSgData.begin(), consumerSgData.end());
-
   auto context = vectorTy.getContext();
 
   switch (layoutKind) {
   case xegpu::LayoutKind::Subgroup:
-    requiredLayout = xegpu::LayoutAttr::get(
-        context, DenseI32ArrayAttr::get(context, sgLayout32),
-        DenseI32ArrayAttr::get(context, sgData32), consumerLayout.getOrder());
+    requiredLayout = consumerLayout;
     break;
   case xegpu::LayoutKind::InstData:
     requiredLayout = xegpu::LayoutAttr::get(context, instData);
@@ -919,7 +964,7 @@ getDefaultLaneLayoutAttr(mlir::MLIRContext *ctx, unsigned rank,
   return xegpu::LayoutAttr::get(ctx, {1, uArch->getSubgroupSize()}, {1, 1});
 }
 
-xegpu::DistributeLayoutAttr xegpu::loadGatherSetupAnchorLayout(
+xegpu::DistributeLayoutAttr xegpu::setupLoadGatherAnchorLayout(
     LayoutKind layoutKind, VectorType resVecTy, int chunkSize,
     DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch) {
 
@@ -936,23 +981,13 @@ xegpu::DistributeLayoutAttr xegpu::loadGatherSetupAnchorLayout(
       dyn_cast<xegpu::uArch::StoreScatterInstruction>(
           uArch->getInstruction(xegpu::uArch::InstructionKind::StoreScatter));
 
-  SmallVector<int64_t> consumerSgLayout =
-      consumerLayout.getEffectiveSgLayoutAsInt();
-  SmallVector<int64_t> consumerSgData =
-      consumerLayout.getEffectiveSgDataAsInt();
-  SmallVector<int32_t> sgLayout32(consumerSgLayout.begin(),
-                                  consumerSgLayout.end());
-  SmallVector<int32_t> sgData32(consumerSgData.begin(), consumerSgData.end());
-
   SmallVector<int64_t> consumerInstData =
       consumerLayout.getEffectiveInstDataAsInt();
   SmallVector<int32_t> instData32;
 
   switch (layoutKind) {
   case xegpu::LayoutKind::Subgroup:
-    requiredLayout = xegpu::LayoutAttr::get(
-        context, DenseI32ArrayAttr::get(context, sgLayout32),
-        DenseI32ArrayAttr::get(context, sgData32), consumerLayout.getOrder());
+    requiredLayout = consumerLayout;
     break;
   case xegpu::LayoutKind::InstData:
     if (resVecTy.getRank() == 1) {
@@ -961,9 +996,7 @@ xegpu::DistributeLayoutAttr xegpu::loadGatherSetupAnchorLayout(
       assert((resVecTy.getRank() > 2) && "StoreScatterOp can access 2D tensor "
                                          "tile at maximum at subgroup level.");
       if (chunkSize == 1) {
-        instData[0] =
-            std::min(resShape[0], static_cast<int64_t>(spirVectorSize));
-        instData[0] = std::min(instData[0], consumerInstData[0]);
+        instData[0] = 1;
         instData[1] = subgroupSize;
       } else {
         instData[0] = subgroupSize;
@@ -978,8 +1011,18 @@ xegpu::DistributeLayoutAttr xegpu::loadGatherSetupAnchorLayout(
         context, DenseI32ArrayAttr::get(context, instData32));
     break;
   case xegpu::LayoutKind::Lane:
-    requiredLayout =
-        getDefaultLaneLayoutAttr(context, resVecTy.getRank(), uArch);
+    if (chunkSize == 1)
+      requiredLayout =
+          getDefaultLaneLayoutAttr(context, resVecTy.getRank(), uArch);
+    else {
+      assert((resVecTy.getRank() <= 2) && "StoreScatterOp can access 2D tensor "
+                                          "tile at maximum at subgroup level.");
+      assert(resShape[1] <= static_cast<int64_t>(
+                                uArchInstruction->getMaxLaneLoadStoreSize()) &&
+             "StoreScatterOp lane size exceeds max lane load/store size.");
+      requiredLayout = xegpu::LayoutAttr::get(
+          context, {subgroupSize, 1}, {1, static_cast<int>(resShape[1])});
+    }
     break;
   default:
     llvm_unreachable("unsupported layout kind");
@@ -988,7 +1031,7 @@ xegpu::DistributeLayoutAttr xegpu::loadGatherSetupAnchorLayout(
 }
 
 xegpu::DistributeLayoutAttr
-xegpu::storeScatterSetupAnchorLayout(LayoutKind layoutKind, VectorType srcVecTy,
+xegpu::setupStoreScatterAnchorLayout(LayoutKind layoutKind, VectorType srcVecTy,
                                      int chunkSize, const uArch::uArch *uArch) {
 
   xegpu::DistributeLayoutAttr requiredLayout;
@@ -1013,11 +1056,10 @@ xegpu::storeScatterSetupAnchorLayout(LayoutKind layoutKind, VectorType srcVecTy,
     if (srcVecTy.getRank() == 1) {
       instData[0] = subgroupSize;
     } else {
-      assert((srcVecTy.getRank() > 2) && "StoreScatterOp can access 2D tensor "
-                                         "tile at maximum at subgroup level.");
+      assert((srcVecTy.getRank() <= 2) && "StoreScatterOp can access 2D tensor "
+                                          "tile at maximum at subgroup level.");
       if (chunkSize == 1) {
-        instData[0] =
-            std::min(srcShape[0], static_cast<int64_t>(spirVectorSize));
+        instData[0] = 1;
         instData[1] = subgroupSize;
       } else {
         instData[0] = subgroupSize;
@@ -1028,8 +1070,18 @@ xegpu::storeScatterSetupAnchorLayout(LayoutKind layoutKind, VectorType srcVecTy,
     }
     break;
   case xegpu::LayoutKind::Lane:
-    requiredLayout =
-        getDefaultLaneLayoutAttr(context, srcVecTy.getRank(), uArch);
+    if (chunkSize == 1)
+      requiredLayout =
+          getDefaultLaneLayoutAttr(context, srcVecTy.getRank(), uArch);
+    else {
+      assert((srcVecTy.getRank() > 2) && "StoreScatterOp can access 2D tensor "
+                                         "tile at maximum at subgroup level.");
+      assert(srcShape[1] <= static_cast<int64_t>(
+                                uArchInstruction->getMaxLaneLoadStoreSize()) &&
+             "StoreScatterOp lane size exceeds max lane load/store size.");
+      requiredLayout = xegpu::LayoutAttr::get(
+          context, {subgroupSize, 1}, {1, static_cast<int>(srcShape[1])});
+    }
     break;
   default:
     llvm_unreachable("unsupported layout kind");

>From a717336e6c639abf8b3636f300e119c9cefc4ea2 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Tue, 27 Jan 2026 22:48:06 +0000
Subject: [PATCH 14/35] setting correct inst layout for 1dreduction example

---
 .../Dialect/XeGPU/Utils/XeGPULayoutUtils.h    |  34 ++--
 mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp    |   3 +-
 .../XeGPU/Transforms/XeGPUPropagateLayout.cpp | 126 +++++++++++++--
 .../Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp  | 145 +++++++++---------
 4 files changed, 200 insertions(+), 108 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h
index bba2b1e06129a..95122104f22e5 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h
@@ -74,8 +74,8 @@ DistributeLayoutAttr inferBroadcastSourceLayout(DistributeLayoutAttr resLayout,
 /// Infers the source layout attribute for a reduction operation given the
 /// result layout attribute and reduced dims.
 DistributeLayoutAttr
-inferReductionSourceLayout(DistributeLayoutAttr resLayout,
-                           SmallVector<int64_t> reduceDims);
+inferMultiReductionSourceLayout(DistributeLayoutAttr resLayout,
+                                SmallVector<int64_t> reduceDims);
 
 /// Infers the source layout attribute for a bitcast operation given the
 /// result layout attribute, result element type bitwidth, and source element
@@ -90,12 +90,6 @@ DistributeLayoutAttr inferShapeCastSourceLayout(DistributeLayoutAttr resLayout,
                                                 ArrayRef<int64_t> resShape,
                                                 ArrayRef<int64_t> srcShape);
 
-/// Infers the source layout attribute for mask operand of scatter IO operation
-/// given the result layout attribute, value shape, and mask shape.
-DistributeLayoutAttr inferScatterIOMaskLayout(DistributeLayoutAttr resLayout,
-                                              ArrayRef<int64_t> valShape,
-                                              ArrayRef<int64_t> maskShape);
-
 /// Sets up layout for reduction operations by creating a SliceAttr for the
 /// result.
 ///
@@ -123,22 +117,24 @@ DistributeLayoutAttr setupBitCastResultLayout(
     LayoutKind layoutKind, VectorType srcVectorTy, VectorType resVectorTy,
     DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch);
 
-xegpu::DistributeLayoutAttr
-xegpu::setupLoadMatrixAnchorLayout(LayoutKind layoutKind, VectorType vectorTy,
-                                   xegpu::DistributeLayoutAttr consumerLayout,
-                                   const uArch::uArch *uArch);
+DistributeLayoutAttr
+setupLoadMatrixAnchorLayout(LayoutKind layoutKind, VectorType vectorTy,
+                            DistributeLayoutAttr consumerLayout,
+                            const uArch::uArch *uArch);
 
 DistributeLayoutAttr setupStoreMatrixAnchorLayout(LayoutKind layoutKind,
                                                   VectorType vectorTy,
                                                   const uArch::uArch *uArch);
 
-xegpu::DistributeLayoutAttr xegpu::setupLoadGatherAnchorLayout(
-    LayoutKind layoutKind, VectorType vectorTy, int chunkSize,
-    DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch);
-
-xegpu::DistributeLayoutAttr
-xegpu::setupStoreScatterAnchorLayout(LayoutKind layoutKind, VectorType vectorTy,
-                                     int chunkSize, const uArch::uArch *uArch);
+DistributeLayoutAttr
+setupLoadGatherAnchorLayout(LayoutKind layoutKind, VectorType vectorTy,
+                            int chunkSize, DistributeLayoutAttr consumerLayout,
+                            const uArch::uArch *uArch);
+
+DistributeLayoutAttr setupStoreScatterAnchorLayout(LayoutKind layoutKind,
+                                                   VectorType vectorTy,
+                                                   int chunkSize,
+                                                   const uArch::uArch *uArch);
 } // namespace xegpu
 
 } // namespace mlir
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 78d1ad50dd1dd..c7b73dd59767e 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -540,7 +540,8 @@ LayoutAttr::collapseDims(SmallVector<SmallVector<int64_t>> dimGroups) const {
       collapsedInst *= instData[dimIdx];
       collapsedLaneL *= laneLayout[dimIdx];
       collapsedLaneD *= laneData[dimIdx];
-      collapsedOrderValue = order[dimIdx]; // take the last one's order
+      if (!orderVec.empty())
+        collapsedOrderValue = orderVec[dimIdx]; // take the last one's order
     }
 
     collapsedSgLayout.push_back(collapsedSg);
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index e6fa8bfd12474..44127d27de522 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -695,7 +695,7 @@ void LayoutInfoPropagation::visitVectorMultiReductionOp(
   // propagated from consumer op, the conflict is resolved in later phase by
   // converting the required result layout to the consumer layout
   auto uArch = getUArch(xegpu::getChipStr(reduction).value_or(""));
-  auto requiredResLayoutAttr = xegpu::reductionSetupResultLayout(
+  auto requiredResLayoutAttr = xegpu::setupMultiReductionResultLayout(
       layoutKind, sourceTy, consumerLayoutAttr, reductionDims, uArch);
   LLVM_DEBUG(DBGS() << "visitVectorMultiReductionOp: requiredResLayoutAttr = "
                     << requiredResLayoutAttr << "\n");
@@ -713,8 +713,8 @@ void LayoutInfoPropagation::visitVectorMultiReductionOp(
              results[0]->getValue().print(llvm::dbgs()); llvm::dbgs() << "\n");
 
   // derive the source layout from the dominant layout and reduction dims
-  auto srcLayoutAttr =
-      xegpu::inferReductionSourceLayout(requiredResLayoutAttr, reductionDims);
+  auto srcLayoutAttr = xegpu::inferMultiReductionSourceLayout(
+      requiredResLayoutAttr, reductionDims);
 
   LLVM_DEBUG(DBGS() << "visitVectorMultiReductionOp: srcLayoutAttr = "
                     << srcLayoutAttr << "\n");
@@ -1129,7 +1129,7 @@ void LayoutInfoPropagation::visitVectorBitcastOp(
   auto consumerLayoutAttr =
       dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
   auto uArch = getUArch(xegpu::getChipStr(bitcast).value_or(""));
-  auto requiredResLayoutAttr = bitCastSetupResultLayout(
+  auto requiredResLayoutAttr = setupBitCastResultLayout(
       layoutKind, srcVecType, resVecType, consumerLayoutAttr, uArch);
 
   xegpu::setTemporaryLayout(bitcast->getResult(0), requiredResLayoutAttr);
@@ -1144,6 +1144,40 @@ void LayoutInfoPropagation::visitVectorBitcastOp(
   propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
 }
 
+// /// For vector::BitCastOp, the lane_data of the source layout is changed
+// based
+// /// on the bit width of the source and result types.
+// void LayoutInfoPropagation::visitVectorInsertStridedSliceOp(
+//     vector::BitCastOp bitcast, ArrayRef<LayoutInfoLattice *> operands,
+//     ArrayRef<const LayoutInfoLattice *> results) {
+//   // Need the layout of bitcast result to propagate to the operands.
+//   LayoutInfo resLayoutInfo = results[0]->getValue();
+//   if (!resLayoutInfo.isAssigned())
+//     return;
+
+//   auto srcVecType = bitcast.getSourceVectorType();
+//   auto resVecType = bitcast.getResultVectorType();
+
+//   auto consumerLayoutAttr =
+//       dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
+//   auto uArch = getUArch(xegpu::getChipStr(bitcast).value_or(""));
+//   auto requiredResLayoutAttr = setupInsertStridedSliceResultLayout(
+//       layoutKind, srcVecType, resVecType, consumerLayoutAttr, uArch);
+
+//   xegpu::setTemporaryLayout(bitcast->getResult(0), requiredResLayoutAttr);
+
+//   int inElemTyBitWidth = srcVecType.getElementType().getIntOrFloatBitWidth();
+//   int outElemTyBitWidth =
+//   resVecType.getElementType().getIntOrFloatBitWidth();
+
+//   // derive the source layout from the dominant layout and reduction dims
+//   auto srcLayoutAttr = xegpu::inferBitCastSourceLayout(
+//       requiredResLayoutAttr, outElemTyBitWidth, inElemTyBitWidth);
+
+//   propagateIfChanged(operands[0],
+//   operands[0]->meet(LayoutInfo(srcLayoutAttr)));
+// }
+
 /// Propagate the layout of the result to the tensor descriptor, mask and offset
 /// operands in LoadGatherOp.
 void LayoutInfoPropagation::visitLoadGatherOp(
@@ -1155,9 +1189,14 @@ void LayoutInfoPropagation::visitLoadGatherOp(
   auto uArch = getUArch(getChipStr(load).value_or(""));
   auto subgroupSize = uArch->getSubgroupSize();
   VectorType resVecTy = load.getValueType();
-  VectorType maskTy = llvm::dyn_cast<VectorType>(load.getMask().getType());
   int chunkSize = load.getChunkSize().value_or(1);
 
+  LayoutInfo resLayoutInfo = results[0]->getValue();
+  if (!resLayoutInfo.isAssigned())
+    return;
+  auto consumerLayoutAttr =
+      dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
+
   if (hasParamsOfLayoutKind(anchorLayoutAttr)) {
     requiredAnchorLayoutAttr = anchorLayoutAttr;
   } else {
@@ -1166,11 +1205,11 @@ void LayoutInfoPropagation::visitLoadGatherOp(
       return;
     }
     requiredAnchorLayoutAttr = xegpu::setupLoadGatherAnchorLayout(
-        layoutKind, resVecTy, chunkSize, uArch);
+        layoutKind, resVecTy, chunkSize, consumerLayoutAttr, uArch);
     load.setLayoutAttr(requiredAnchorLayoutAttr);
   }
 
-  auto maskLayoutAttr = anchorLayoutAttr;
+  auto maskLayoutAttr = requiredAnchorLayoutAttr;
   // Special handling mask layout for chunked ops: Enforce the default xegpu 1D
   // layout for mask.
   if (chunkSize > 1) {
@@ -1186,14 +1225,15 @@ void LayoutInfoPropagation::visitLoadGatherOp(
   }
 
   LayoutInfo maskLayoutInfo = LayoutInfo(maskLayoutAttr);
+  auto loadLayoutInfo = LayoutInfo(requiredAnchorLayoutAttr);
 
   // Propagate the new layout to the tensor descriptor operand.
   if (isa<xegpu::TensorDescType>(load.getSourceType()))
-    propagateIfChanged(operands[0], operands[0]->meet(loadLayout));
+    propagateIfChanged(operands[0], operands[0]->meet(loadLayoutInfo));
   // Propagate the new layout to the mask and optional offset operand.
-  propagateIfChanged(operands[1], operands[1]->meet(maskLayout));
+  propagateIfChanged(operands[1], operands[1]->meet(maskLayoutInfo));
   if (load.getOffsets())
-    propagateIfChanged(operands[2], operands[2]->meet(maskLayout));
+    propagateIfChanged(operands[2], operands[2]->meet(maskLayoutInfo));
 }
 
 /// Propagate the layout of the descriptor to the vector offset operand in
@@ -1218,6 +1258,8 @@ void LayoutInfoPropagation::visitStoreScatterOp(
     xegpu::StoreScatterOp storeScatter, ArrayRef<LayoutInfoLattice *> operands,
     ArrayRef<const LayoutInfoLattice *> results) {
 
+  LLVM_DEBUG(DBGS() << "visitStoreScatterOp: Processing store scatter op\n");
+
   xegpu::DistributeLayoutAttr requiredAnchorLayoutAttr;
   xegpu::DistributeLayoutAttr anchorLayoutAttr = storeScatter.getLayoutAttr();
   auto uArch = getUArch(getChipStr(storeScatter).value_or(""));
@@ -1227,9 +1269,19 @@ void LayoutInfoPropagation::visitStoreScatterOp(
       llvm::dyn_cast<VectorType>(storeScatter.getMask().getType());
   int chunkSize = storeScatter.getChunkSize().value_or(1);
 
+  LLVM_DEBUG(DBGS() << "visitStoreScatterOp: anchorLayoutAttr = "
+                    << anchorLayoutAttr << "\n");
+  LLVM_DEBUG(DBGS() << "visitStoreScatterOp: subgroupSize = " << subgroupSize
+                    << "\n");
+  LLVM_DEBUG(DBGS() << "visitStoreScatterOp: chunkSize = " << chunkSize
+                    << "\n");
+  LLVM_DEBUG(DBGS() << "visitStoreScatterOp: srcVecTy = " << srcVecTy << "\n");
+
   if (hasParamsOfLayoutKind(anchorLayoutAttr)) {
+    LLVM_DEBUG(DBGS() << "visitStoreScatterOp: Using existing anchor layout\n");
     requiredAnchorLayoutAttr = anchorLayoutAttr;
   } else {
+    LLVM_DEBUG(DBGS() << "visitStoreScatterOp: Setting up new anchor layout\n");
     if (!srcVecTy) {
       storeScatter.emitWarning("Not propagating, non-vector payload supplied.");
       return;
@@ -1239,11 +1291,16 @@ void LayoutInfoPropagation::visitStoreScatterOp(
     storeScatter.setLayoutAttr(requiredAnchorLayoutAttr);
   }
 
+  LLVM_DEBUG(DBGS() << "visitStoreScatterOp: requiredAnchorLayoutAttr = "
+                    << requiredAnchorLayoutAttr << "\n");
+
   LayoutInfo srcLayoutInfo = LayoutInfo(requiredAnchorLayoutAttr);
-  auto maskLayoutAttr = anchorLayoutAttr;
+  auto maskLayoutAttr = requiredAnchorLayoutAttr;
   // Special handling mask layout for chunked ops: Enforce the default xegpu 1D
   // layout for mask.
   if (chunkSize > 1) {
+    LLVM_DEBUG(DBGS() << "visitStoreScatterOp: Setting mask layout for chunked "
+                         "operation\n");
     if (layoutKind == xegpu::LayoutKind::InstData)
       maskLayoutAttr =
           xegpu::LayoutAttr::get(storeScatter->getContext(), {subgroupSize});
@@ -1256,41 +1313,72 @@ void LayoutInfoPropagation::visitStoreScatterOp(
   }
 
   LayoutInfo maskLayoutInfo = LayoutInfo(maskLayoutAttr);
+  LLVM_DEBUG(DBGS() << "visitStoreScatterOp: maskLayoutAttr = "
+                    << maskLayoutAttr << "\n");
 
   // Propagate the payload operand layout
+  LLVM_DEBUG(DBGS() << "visitStoreScatterOp: Propagating payload layout\n");
   propagateIfChanged(operands[0], operands[0]->meet(srcLayoutInfo));
   // Propagate the destination (if tdesc) operand layout
-  if (isa<xegpu::TensorDescType>(storeScatter.getDestType()))
+  if (isa<xegpu::TensorDescType>(storeScatter.getDestType())) {
+    LLVM_DEBUG(
+        DBGS() << "visitStoreScatterOp: Propagating destination layout\n");
     propagateIfChanged(operands[1], operands[1]->meet(srcLayoutInfo));
+  }
   // Propagate the new layout to the mask and optional offset operand.
+  LLVM_DEBUG(DBGS() << "visitStoreScatterOp: Propagating mask layout\n");
   propagateIfChanged(operands[2], operands[2]->meet(maskLayoutInfo));
-  if (storeScatter.getOffsets())
+  if (storeScatter.getOffsets()) {
+    LLVM_DEBUG(DBGS() << "visitStoreScatterOp: Propagating offset layout\n");
     propagateIfChanged(operands[3], operands[3]->meet(maskLayoutInfo));
+  }
+  LLVM_DEBUG(DBGS() << "visitStoreScatterOp: Done\n");
 }
 
 void LayoutInfoPropagation::visitLoadMatrixOp(
     xegpu::LoadMatrixOp loadMatrixOp, ArrayRef<LayoutInfoLattice *> operands,
     ArrayRef<const LayoutInfoLattice *> results) {
 
+  LLVM_DEBUG(DBGS() << "visitLoadMatrixOp: Processing load matrix op\n");
+
   LayoutInfo resLayoutInfo = results[0]->getValue();
-  if (!resLayoutInfo.isAssigned())
+  if (!resLayoutInfo.isAssigned()) {
+    LLVM_DEBUG(DBGS() << "visitLoadMatrixOp: Result layout not assigned\n");
     return;
+  }
+
+  LLVM_DEBUG(DBGS() << "visitLoadMatrixOp: resLayoutInfo = ";
+             resLayoutInfo.print(llvm::dbgs()); llvm::dbgs() << "\n");
+
   auto consumerLayoutAttr =
       dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
 
+  LLVM_DEBUG(DBGS() << "visitLoadMatrixOp: consumerLayoutAttr = "
+                    << consumerLayoutAttr << "\n");
+
   xegpu::DistributeLayoutAttr anchorLayout = loadMatrixOp.getLayoutAttr();
 
+  LLVM_DEBUG(DBGS() << "visitLoadMatrixOp: anchorLayout = " << anchorLayout
+                    << "\n");
+
   // only need to set anchor layout, no need to porpagate to memdesc and
   // offset
   if (!hasParamsOfLayoutKind(anchorLayout)) {
+    LLVM_DEBUG(DBGS() << "visitLoadMatrixOp: Setting up new anchor layout\n");
     VectorType resVecTy =
         llvm::cast<VectorType>(loadMatrixOp.getRes().getType());
     assert(resVecTy.getRank() == 2 && "Expecting 2D vector for store matrix.");
+    LLVM_DEBUG(DBGS() << "visitLoadMatrixOp: resVecTy = " << resVecTy << "\n");
     auto uArch = getUArch(getChipStr(loadMatrixOp).value_or(""));
-    auto requiredAnchorLayoutAttr = xegpu::loadMatrixSetupAnchorLayout(
+    auto requiredAnchorLayoutAttr = xegpu::setupLoadMatrixAnchorLayout(
         layoutKind, resVecTy, consumerLayoutAttr, uArch);
+    LLVM_DEBUG(DBGS() << "visitLoadMatrixOp: requiredAnchorLayoutAttr = "
+                      << requiredAnchorLayoutAttr << "\n");
     loadMatrixOp.setLayoutAttr(requiredAnchorLayoutAttr);
+  } else {
+    LLVM_DEBUG(DBGS() << "visitLoadMatrixOp: Using existing anchor layout\n");
   }
+  LLVM_DEBUG(DBGS() << "visitLoadMatrixOp: Done\n");
 }
 
 // Store matrix is a flavor of scattered store for 2D shapes.
@@ -1308,7 +1396,7 @@ void LayoutInfoPropagation::visitStoreMatrixOp(
     assert(srcVecTy.getRank() == 2 && "Expecting 2D vector for store matrix.");
     auto uArch = getUArch(getChipStr(storeMatrix).value_or(""));
     auto requiredAnchorLayoutAttr =
-        xegpu::storeMatrixSetupAnchorLayout(layoutKind, srcVecTy, uArch);
+        xegpu::setupStoreMatrixAnchorLayout(layoutKind, srcVecTy, uArch);
     storeMatrix.setLayoutAttr(requiredAnchorLayoutAttr);
     layout = LayoutInfo(requiredAnchorLayoutAttr);
   }
@@ -1591,6 +1679,12 @@ void XeGPUPropagateLayoutPass::runOnOperation() {
     if (!layout.isAssigned())
       return {};
     if (auto opResult = dyn_cast<OpResult>(val)) {
+
+      Operation *defOp = opResult.getDefiningOp();
+      if (auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(defOp)) {
+        return anchorOp.getAnchorLayout();
+      }
+
       xegpu::DistributeLayoutAttr requiredResLayoutAttr =
           xegpu::getTemporaryLayout(opResult);
       if (requiredResLayoutAttr != nullptr)
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp
index 8d7d1b96b363e..50f2caaaba1ea 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp
@@ -362,8 +362,8 @@ xegpu::inferBroadcastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
 /// Infers the source layout attribute for a reduction operation given the
 /// result layout attribute and reduced dims.
 xegpu::DistributeLayoutAttr
-xegpu::inferReductionSourceLayout(xegpu::DistributeLayoutAttr resLayout,
-                                  SmallVector<int64_t> reduceDims) {
+xegpu::inferMultiReductionSourceLayout(xegpu::DistributeLayoutAttr resLayout,
+                                       SmallVector<int64_t> reduceDims) {
 
   //  assert the resLayout must be slice layout
   assert(isa<xegpu::SliceAttr>(resLayout) &&
@@ -532,61 +532,6 @@ xegpu::inferShapeCastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
   return nullptr;
 }
 
-xegpu::DistributeLayoutAttr
-inferScatterIOMaskLayout(xegpu::DistributeLayoutAttr resLayout,
-                         ArrayRef<int64_t> valShape,
-                         ArrayRef<int64_t> maskShape) {
-
-  xegpu::LayoutAttr resPlainLayout = dyn_cast<xegpu::LayoutAttr>(resLayout);
-  assert(resPlainLayout &&
-         "Expecting plain layout for scatter IO mask inference.");
-  xegpu::LayoutAttr maskLayout = resPlainLayout;
-
-  if (maskShape.size() < valShape.size()) {
-    int valShapeSize = valShape.size();
-    int maskShapeSize = maskShape.size();
-    assert((maskShapeSize == (valShapeSize - 1)) &&
-           "Expecting mask vector only 1 dimension less than value vector.");
-
-    maskLayout = resPlainLayout.setUnitDimData(valShapeSize - 1);
-    maskLayout = maskLayout.collapseDim(valShapeSize - 1);
-    // SmallVector<int> sgLayout(valShapeSize);
-    // sgLayout = resPlainLayout.getSgLayout();
-    // SmallVector<int> sgData(valShapeSize) = resPlainLayout.getSgData();
-    // SmallVector<int> instData(valShapeSize) = resPlainLayout.getSgLayout();
-    // SmallVector<int> laneLayout(valShapeSize) =
-    // resPlainLayout.getLaneLayout(); SmallVector<int> laneData(valShapeSize) =
-    // resPlainLayout.getSgLayout(); SmallVector<int> order(valShapeSize) =
-    // resPlainLayout.getOrder();
-
-    // // drop the innermost dimension for mask layout
-    // sgLayout.pop_back();
-    // sgData.pop_back();
-    // instData.pop_back();
-    // laneLayout.pop_back();
-    // laneData.pop_back();
-    // order.pop_back();
-
-    // SmallVector<int64_t> remappedOrder(maskShapeSize, -1);
-    // for (int64_t i = 0; i < maskShapeSize; ++i) {
-    //   int64_t originalOrderValue = order[i];
-    //   // count how many values in collapsed Order are less than
-    //   // originalOrderValue
-    //   int64_t count = 0;
-    //   for (int64_t j = 0; j < maskShapeSize; ++j) {
-    //     if (order[j] < originalOrderValue)
-    //       count++;
-    //   }
-    //   remappedOrder[i] = count;
-    // }
-
-    // maskLayout =
-    //     xegpu::LayoutAttr::get(resLayout.getContext(), sgLayout, sgData,
-    //                            instData, laneLayout, laneData, order);
-  }
-  return maskLayout;
-}
-
 /// Sets up layout for reduction operations by creating a SliceAttr for the
 /// result.
 ///
@@ -968,15 +913,28 @@ xegpu::DistributeLayoutAttr xegpu::setupLoadGatherAnchorLayout(
     LayoutKind layoutKind, VectorType resVecTy, int chunkSize,
     DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch) {
 
+  llvm::dbgs() << "setupLoadGatherAnchorLayout: layoutKind="
+               << static_cast<int>(layoutKind) << ", chunkSize=" << chunkSize
+               << "\n";
+
   xegpu::DistributeLayoutAttr requiredLayout;
   const int subgroupSize = uArch->getSubgroupSize();
   const int spirVectorSize = 16; // vector size from SPRIV vector restriction
 
   auto resShape = resVecTy.getShape();
   int resShapeSize = resShape.size();
-  SmallVector<int64_t> instData(subgroupSize);
+  SmallVector<int> instData(resShapeSize);
   auto context = resVecTy.getContext();
 
+  llvm::dbgs() << "setupLoadGatherAnchorLayout: resVecTy.getRank()="
+               << resVecTy.getRank() << ", resShape=[";
+  for (size_t i = 0; i < resShape.size(); ++i) {
+    if (i > 0)
+      llvm::dbgs() << ", ";
+    llvm::dbgs() << resShape[i];
+  }
+  llvm::dbgs() << "]\n";
+
   const auto *uArchInstruction =
       dyn_cast<xegpu::uArch::StoreScatterInstruction>(
           uArch->getInstruction(xegpu::uArch::InstructionKind::StoreScatter));
@@ -985,32 +943,48 @@ xegpu::DistributeLayoutAttr xegpu::setupLoadGatherAnchorLayout(
       consumerLayout.getEffectiveInstDataAsInt();
   SmallVector<int32_t> instData32;
 
+  llvm::dbgs() << "setupLoadGatherAnchorLayout: consumerInstData=[";
+  for (size_t i = 0; i < consumerInstData.size(); ++i) {
+    if (i > 0)
+      llvm::dbgs() << ", ";
+    llvm::dbgs() << consumerInstData[i];
+  }
+  llvm::dbgs() << "]\n";
+
   switch (layoutKind) {
   case xegpu::LayoutKind::Subgroup:
+    llvm::dbgs() << "setupLoadGatherAnchorLayout: LayoutKind::Subgroup\n";
     requiredLayout = consumerLayout;
     break;
   case xegpu::LayoutKind::InstData:
+    llvm::dbgs() << "setupLoadGatherAnchorLayout: LayoutKind::InstData\n";
     if (resVecTy.getRank() == 1) {
       instData[0] = subgroupSize;
+      llvm::dbgs() << "setupLoadGatherAnchorLayout: 1D case, instData[0]="
+                   << instData[0] << "\n";
     } else {
-      assert((resVecTy.getRank() > 2) && "StoreScatterOp can access 2D tensor "
-                                         "tile at maximum at subgroup level.");
+      assert((resVecTy.getRank() == 2) && "StoreScatterOp can access 2D tensor "
+                                          "tile at maximum at subgroup level.");
       if (chunkSize == 1) {
         instData[0] = 1;
         instData[1] = subgroupSize;
+        llvm::dbgs() << "setupLoadGatherAnchorLayout: chunkSize==1, instData=["
+                     << instData[0] << ", " << instData[1] << "]\n";
       } else {
         instData[0] = subgroupSize;
-        instData[1] = std::min(
-            resShape[1],
-            static_cast<int64_t>(uArchInstruction->getMaxLaneLoadStoreSize()));
-        instData[1] = std::min(instData[1], consumerInstData[1]);
+        instData[1] = std::min(static_cast<int>(resShape[1]),
+                               uArchInstruction->getMaxLaneLoadStoreSize());
+        instData[1] =
+            std::min(instData[1], static_cast<int>(consumerInstData[1]));
+        llvm::dbgs() << "setupLoadGatherAnchorLayout: chunkSize>1, instData=["
+                     << instData[0] << ", " << instData[1] << "]\n";
       }
     }
-    instData32 = SmallVector<int32_t>(instData.begin(), instData.end());
     requiredLayout = xegpu::LayoutAttr::get(
-        context, DenseI32ArrayAttr::get(context, instData32));
+        context, DenseI32ArrayAttr::get(context, instData));
     break;
   case xegpu::LayoutKind::Lane:
+    llvm::dbgs() << "setupLoadGatherAnchorLayout: LayoutKind::Lane\n";
     if (chunkSize == 1)
       requiredLayout =
           getDefaultLaneLayoutAttr(context, resVecTy.getRank(), uArch);
@@ -1027,6 +1001,7 @@ xegpu::DistributeLayoutAttr xegpu::setupLoadGatherAnchorLayout(
   default:
     llvm_unreachable("unsupported layout kind");
   }
+  llvm::dbgs() << "setupLoadGatherAnchorLayout: returning requiredLayout\n";
   return requiredLayout;
 }
 
@@ -1034,13 +1009,26 @@ xegpu::DistributeLayoutAttr
 xegpu::setupStoreScatterAnchorLayout(LayoutKind layoutKind, VectorType srcVecTy,
                                      int chunkSize, const uArch::uArch *uArch) {
 
+  llvm::dbgs() << "setupStoreScatterAnchorLayout: layoutKind="
+               << static_cast<int>(layoutKind) << ", chunkSize=" << chunkSize
+               << "\n";
+
   xegpu::DistributeLayoutAttr requiredLayout;
   const int subgroupSize = uArch->getSubgroupSize();
   const int spirVectorSize = 16; // vector size from SPRIV vector restriction
 
   auto srcShape = srcVecTy.getShape();
   int srcShapeSize = srcShape.size();
-  SmallVector<int64_t> instData(subgroupSize);
+  SmallVector<int> instData(srcShapeSize);
+
+  llvm::dbgs() << "setupStoreScatterAnchorLayout: srcVecTy.getRank()="
+               << srcVecTy.getRank() << ", srcShape=[";
+  for (size_t i = 0; i < srcShape.size(); ++i) {
+    if (i > 0)
+      llvm::dbgs() << ", ";
+    llvm::dbgs() << srcShape[i];
+  }
+  llvm::dbgs() << "]\n";
 
   const auto *uArchInstruction =
       dyn_cast<xegpu::uArch::StoreScatterInstruction>(
@@ -1049,27 +1037,39 @@ xegpu::setupStoreScatterAnchorLayout(LayoutKind layoutKind, VectorType srcVecTy,
 
   switch (layoutKind) {
   case xegpu::LayoutKind::Subgroup:
-    assert(false &&
-           "subgroup layout assignment not supported yet for loadMatrix.");
+    llvm::dbgs() << "setupStoreScatterAnchorLayout: LayoutKind::Subgroup\n";
+    assert(
+        false &&
+        "subgroup layout assignment not supported yet for store scatter op.");
     break;
   case xegpu::LayoutKind::InstData:
+    llvm::dbgs() << "setupStoreScatterAnchorLayout: LayoutKind::InstData\n";
     if (srcVecTy.getRank() == 1) {
       instData[0] = subgroupSize;
+      llvm::dbgs() << "setupStoreScatterAnchorLayout: 1D case, instData[0]="
+                   << instData[0] << "\n";
     } else {
       assert((srcVecTy.getRank() <= 2) && "StoreScatterOp can access 2D tensor "
                                           "tile at maximum at subgroup level.");
       if (chunkSize == 1) {
         instData[0] = 1;
         instData[1] = subgroupSize;
+        llvm::dbgs()
+            << "setupStoreScatterAnchorLayout: chunkSize==1, instData=["
+            << instData[0] << ", " << instData[1] << "]\n";
       } else {
         instData[0] = subgroupSize;
-        instData[1] = std::min(
-            srcShape[1],
-            static_cast<int64_t>(uArchInstruction->getMaxLaneLoadStoreSize()));
+        instData[1] = std::min(static_cast<int>(srcShape[1]),
+                               uArchInstruction->getMaxLaneLoadStoreSize());
+        llvm::dbgs() << "setupStoreScatterAnchorLayout: chunkSize>1, instData=["
+                     << instData[0] << ", " << instData[1] << "]\n";
       }
     }
+    requiredLayout = xegpu::LayoutAttr::get(
+        context, DenseI32ArrayAttr::get(context, instData));
     break;
   case xegpu::LayoutKind::Lane:
+    llvm::dbgs() << "setupStoreScatterAnchorLayout: LayoutKind::Lane\n";
     if (chunkSize == 1)
       requiredLayout =
           getDefaultLaneLayoutAttr(context, srcVecTy.getRank(), uArch);
@@ -1086,5 +1086,6 @@ xegpu::setupStoreScatterAnchorLayout(LayoutKind layoutKind, VectorType srcVecTy,
   default:
     llvm_unreachable("unsupported layout kind");
   }
+  llvm::dbgs() << "setupStoreScatterAnchorLayout: returning requiredLayout\n";
   return requiredLayout;
 }
\ No newline at end of file

>From c7f1c8603d45d68e26b82b219a9f38aa5c60f06a Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Wed, 28 Jan 2026 20:01:57 +0000
Subject: [PATCH 15/35] add insert_strided_slice and improve shape_cast rules

---
 .../Dialect/XeGPU/Utils/XeGPULayoutUtils.h    |  18 +-
 .../XeGPU/Transforms/XeGPUBlocking.cpp        |   4 +-
 .../XeGPU/Transforms/XeGPUPropagateLayout.cpp |  52 +++++-
 .../Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp  | 166 ++++++++++++++----
 4 files changed, 194 insertions(+), 46 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h
index 95122104f22e5..110496bb34fb3 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h
@@ -90,14 +90,19 @@ DistributeLayoutAttr inferShapeCastSourceLayout(DistributeLayoutAttr resLayout,
                                                 ArrayRef<int64_t> resShape,
                                                 ArrayRef<int64_t> srcShape);
 
+DistributeLayoutAttr
+inferInsertStridedSliceSourceLayout(DistributeLayoutAttr resLayout,
+                                    ArrayRef<int64_t> resShape,
+                                    ArrayRef<int64_t> srcShape);
+
 /// Sets up layout for reduction operations by creating a SliceAttr for the
 /// result.
 ///
-/// This function first attempts to construct a source layout that, when sliced
-/// along reduction dimensions, produces a result layout compatible with the
-/// consumer's preferred layout. This minimizes data redistribution overhead.
-/// The SliceAttr for the result is then created based on the derived source
-/// layout and the specified reduction dimensions.
+/// This function first attempts to construct a source layout that, when
+/// sliced along reduction dimensions, produces a result layout compatible
+/// with the consumer's preferred layout. This minimizes data redistribution
+/// overhead. The SliceAttr for the result is then created based on the
+/// derived source layout and the specified reduction dimensions.
 SliceAttr setupMultiReductionResultLayout(xegpu::LayoutKind layoutKind,
                                           VectorType srcVectorTy,
                                           DistributeLayoutAttr consumerLayout,
@@ -117,6 +122,9 @@ DistributeLayoutAttr setupBitCastResultLayout(
     LayoutKind layoutKind, VectorType srcVectorTy, VectorType resVectorTy,
     DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch);
 
+DistributeLayoutAttr setupInsertStridedSliceResultLayout(
+    LayoutKind layoutKind, VectorType resVectorTy, const uArch::uArch *uArch);
+
 DistributeLayoutAttr
 setupLoadMatrixAnchorLayout(LayoutKind layoutKind, VectorType vectorTy,
                             DistributeLayoutAttr consumerLayout,
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
index 7af7622375ef4..29cd74ad461dd 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
@@ -162,8 +162,8 @@ XeGPUBlockingPass::getTileShape(const T &operandOrResult) const {
           ownerOp &&
           (isa<xegpu::CreateNdDescOp, xegpu::DpasOp, xegpu::ConvertLayoutOp,
                xegpu::LoadMatrixOp, xegpu::StoreMatrixOp, xegpu::AtomicRMWOp,
-               xegpu::LoadNdOp, xegpu::StoreNdOp, xegpu::PrefetchNdOp,
-               vector::TransposeOp, vector::ShapeCastOp,
+               xegpu::LoadGatherOp, xegpu::LoadNdOp, xegpu::StoreNdOp,
+               xegpu::PrefetchNdOp, vector::TransposeOp, vector::ShapeCastOp,
                vector::MultiDimReductionOp, vector::BroadcastOp>(ownerOp));
       if (!skipLeadingUnitDimRemoval) {
         auto it = llvm::find_if(instData, [](auto val) { return val != 1; });
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 44127d27de522..5cb0e259c503b 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -418,6 +418,10 @@ class LayoutInfoPropagation
   void visitShapeCastOp(vector::ShapeCastOp shapeCast,
                         ArrayRef<LayoutInfoLattice *> operands,
                         ArrayRef<const LayoutInfoLattice *> results);
+  void
+  visitInsertStridedSliceOp(vector::InsertStridedSliceOp insertStridedSlice,
+                            ArrayRef<LayoutInfoLattice *> operands,
+                            ArrayRef<const LayoutInfoLattice *> results);
 
   void visitLoadMatrixOp(xegpu::LoadMatrixOp load,
                          ArrayRef<LayoutInfoLattice *> operands,
@@ -508,6 +512,9 @@ LogicalResult LayoutInfoPropagation::visitOperation(
       .Case<vector::ShapeCastOp>([&](auto shapeCastOp) {
         visitShapeCastOp(shapeCastOp, operands, results);
       })
+      .Case<vector::InsertStridedSliceOp>([&](auto insertStridedSliceOp) {
+        visitInsertStridedSliceOp(insertStridedSliceOp, operands, results);
+      })
       .Case<xegpu::LoadMatrixOp>([&](auto loadMatrixOp) {
         visitLoadMatrixOp(loadMatrixOp, operands, results);
       })
@@ -795,7 +802,6 @@ void LayoutInfoPropagation::visitUpdateNdOffsetOp(
 void LayoutInfoPropagation::visitDpasOp(
     xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands,
     ArrayRef<const LayoutInfoLattice *> results) {
-
   LayoutInfo dpasALayout;
   LayoutInfo dpasBLayout;
   LayoutInfo dpasCDLayout;
@@ -992,7 +998,6 @@ void LayoutInfoPropagation::visitDpasOp(
 void LayoutInfoPropagation::visitStoreNdOp(
     xegpu::StoreNdOp store, ArrayRef<LayoutInfoLattice *> operands,
     ArrayRef<const LayoutInfoLattice *> results) {
-
   LayoutInfo storeLayout;
   xegpu::DistributeLayoutAttr anchorLayout = store.getLayoutAttr();
   if (hasParamsOfLayoutKind(anchorLayout)) {
@@ -1073,7 +1078,6 @@ void LayoutInfoPropagation::visitStoreNdOp(
 void LayoutInfoPropagation::visitLoadNdOp(
     xegpu::LoadNdOp load, ArrayRef<LayoutInfoLattice *> operands,
     ArrayRef<const LayoutInfoLattice *> results) {
-
   LayoutInfo loadLayout;
   xegpu::DistributeLayoutAttr anchorLayout = load.getLayoutAttr();
   if (hasParamsOfLayoutKind(anchorLayout)) {
@@ -1144,6 +1148,37 @@ void LayoutInfoPropagation::visitVectorBitcastOp(
   propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
 }
 
+void LayoutInfoPropagation::visitInsertStridedSliceOp(
+    vector::InsertStridedSliceOp insertStridedSlice,
+    ArrayRef<LayoutInfoLattice *> operands,
+    ArrayRef<const LayoutInfoLattice *> results) {
+  // The layout of the result must be present.
+  LayoutInfo resLayoutInfo = results[0]->getValue();
+  if (!resLayoutInfo.isAssigned())
+    return;
+
+  auto srcVecType = insertStridedSlice.getSourceVectorType();
+  auto resVecType = insertStridedSlice.getDestVectorType();
+
+  auto consumerLayoutAttr =
+      dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
+  auto uArch = getUArch(xegpu::getChipStr(insertStridedSlice).value_or(""));
+
+  auto requiredResLayoutAttr =
+      xegpu::setupInsertStridedSliceResultLayout(layoutKind, resVecType, uArch);
+
+  xegpu::setTemporaryLayout(insertStridedSlice->getResult(0),
+                            requiredResLayoutAttr);
+
+  auto srcLayoutAttr = xegpu::inferInsertStridedSliceSourceLayout(
+      requiredResLayoutAttr, resVecType.getShape(), srcVecType.getShape());
+
+  propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
+  propagateIfChanged(operands[1],
+                     operands[1]->meet(LayoutInfo(requiredResLayoutAttr)));
+  return;
+}
+
 // /// For vector::BitCastOp, the lane_data of the source layout is changed
 // based
 // /// on the bit width of the source and result types.
@@ -1183,7 +1218,6 @@ void LayoutInfoPropagation::visitVectorBitcastOp(
 void LayoutInfoPropagation::visitLoadGatherOp(
     xegpu::LoadGatherOp load, ArrayRef<LayoutInfoLattice *> operands,
     ArrayRef<const LayoutInfoLattice *> results) {
-
   xegpu::DistributeLayoutAttr requiredAnchorLayoutAttr;
   xegpu::DistributeLayoutAttr anchorLayoutAttr = load.getLayoutAttr();
   auto uArch = getUArch(getChipStr(load).value_or(""));
@@ -1257,7 +1291,6 @@ void LayoutInfoPropagation::visitCreateDescOp(
 void LayoutInfoPropagation::visitStoreScatterOp(
     xegpu::StoreScatterOp storeScatter, ArrayRef<LayoutInfoLattice *> operands,
     ArrayRef<const LayoutInfoLattice *> results) {
-
   LLVM_DEBUG(DBGS() << "visitStoreScatterOp: Processing store scatter op\n");
 
   xegpu::DistributeLayoutAttr requiredAnchorLayoutAttr;
@@ -1338,7 +1371,6 @@ void LayoutInfoPropagation::visitStoreScatterOp(
 void LayoutInfoPropagation::visitLoadMatrixOp(
     xegpu::LoadMatrixOp loadMatrixOp, ArrayRef<LayoutInfoLattice *> operands,
     ArrayRef<const LayoutInfoLattice *> results) {
-
   LLVM_DEBUG(DBGS() << "visitLoadMatrixOp: Processing load matrix op\n");
 
   LayoutInfo resLayoutInfo = results[0]->getValue();
@@ -1385,7 +1417,6 @@ void LayoutInfoPropagation::visitLoadMatrixOp(
 void LayoutInfoPropagation::visitStoreMatrixOp(
     xegpu::StoreMatrixOp storeMatrix, ArrayRef<LayoutInfoLattice *> operands,
     ArrayRef<const LayoutInfoLattice *> results) {
-
   xegpu::DistributeLayoutAttr anchorLayout = storeMatrix.getLayoutAttr();
   LayoutInfo layout;
   if (hasParamsOfLayoutKind(anchorLayout)) {
@@ -1681,7 +1712,14 @@ void XeGPUPropagateLayoutPass::runOnOperation() {
     if (auto opResult = dyn_cast<OpResult>(val)) {
 
       Operation *defOp = opResult.getDefiningOp();
+      LLVM_DEBUG(DBGS() << "Try op: " << *defOp << "\n");
       if (auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(defOp)) {
+        // inject debug print here
+        LLVM_DEBUG(DBGS() << "AnchorLayoutInterface found for op: " << *defOp
+                          << "\n");
+        // print the anchor layout
+        LLVM_DEBUG(DBGS() << "Anchor layout: " << anchorOp.getAnchorLayout()
+                          << "\n");
         return anchorOp.getAnchorLayout();
       }
 
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp
index 50f2caaaba1ea..9cfcad834a924 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp
@@ -434,9 +434,9 @@ xegpu::inferShapeCastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
   // 2. split dim of a high-rank dimension (e.g., 1D to 2D): to setup tensor
   // for multi-stage reduction
   // 3. combines all dims to a single dim and put in the innermost dim in 2d as
-  // [1, combinedData]. only used after workgroup distribution to save
-  // multidimension data to 1D slm buffer so no need to handle sg_layout and
-  // sg_data.
+  // [1, combinedData] or [combinedData]. Only used after workgroup
+  // distribution. Example like cross-sg reduction saves multidimension data to
+  // 1D slm buffer, shapecast inserted by cse/canonicalization passes.
 
   // Use case 1: Check if shapes only differ by expanding unit dimensions (like
   // expand_dims)
@@ -505,27 +505,47 @@ xegpu::inferShapeCastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
   auto checkCombineToInnerMostDim = [&](ArrayRef<int64_t> src,
                                         ArrayRef<int64_t> dst) -> bool {
     // only one non-unit dim in dst which is the innermost dim
-    assert((dst.size() == 2) && "dst shape must be 2D");
+    if ((dst.size() != 2) && (dst.size() != 1))
+      return false;
     int64_t srcSize = std::accumulate(src.begin(), src.end(), 1LL,
                                       std::multiplies<int64_t>());
+    if (dst.size() == 1)
+      return (dst[0] == srcSize);
     return (dst[0] == 1) && (dst[1] == srcSize);
   };
 
   if (checkCombineToInnerMostDim(srcShape, resShape)) {
     const int subgroupSize = 16; // assuming 16 lanes per subgroup
-    const int vectorSize = 8;    // assuming 8 elements per vector lane
     int srcShapeSize = srcShape.size();
+    auto context = resLayout.getContext();
+    auto resInstData = resLayout.getEffectiveInstDataAsInt();
+    auto resLaneLayout = resLayout.getEffectiveLaneLayoutAsInt();
+    auto resLaneData = resLayout.getEffectiveLaneDataAsInt();
+    if (resInstData.size() != 0) {
+      if (resInstData.size() == 2)
+        assert(resInstData[0] == 1 &&
+               "only innermost dim can have inst_data for combine-to-1d");
+      // construct source inst_data layout like [1, ..., 1, subgroupSize]
+      SmallVector<int> inferredInstData(srcShapeSize, 1);
+      inferredInstData[srcShapeSize - 1] = resInstData[resInstData.size() - 1];
+      return xegpu::LayoutAttr::get(context, inferredInstData);
+    }
 
-    SmallVector<int64_t> instData(srcShapeSize, 1);
-    instData[srcShapeSize - 1] = subgroupSize;
-    instData[srcShapeSize - 2] =
-        vectorSize; // assuming 8 elements per instruction as starting point
-
-    // construct a vector layout with lane_layout = [1, ..., 1, subgroupSize]
-    SmallVector<int64_t> laneLayout(srcShapeSize, 1);
-    laneLayout[srcShapeSize - 1] = subgroupSize;
-    // construct a vector layout with lane_data = [1, ..., 1]
-    SmallVector<int64_t> laneData(srcShapeSize, 1);
+    if (resLaneLayout.size() != 0) {
+      if (resInstData.size() == 2)
+        assert(resInstData[0] == 1 &&
+               "only innermost dim can have inst_data for combine-to-1d");
+      // construct source lane_layout like [1, ..., 1, subgroupSize]
+      SmallVector<int> inferredLaneLayout(srcShapeSize, 1);
+      SmallVector<int> inferredLaneData(srcShapeSize, 1);
+      inferredLaneLayout[srcShapeSize - 1] =
+          resLaneLayout[resLaneLayout.size() - 1];
+
+      inferredLaneData[srcShapeSize - 1] =
+          resLaneLayout[resLaneLayout.size() - 1];
+      return xegpu::LayoutAttr::get(context, inferredLaneLayout,
+                                    inferredLaneData);
+    }
   }
 
   // TODO: Complete implementation for other shape cast scenarios
@@ -846,24 +866,58 @@ xegpu::DistributeLayoutAttr xegpu::setupBitCastResultLayout(
   return resLayout;
 }
 
+xegpu::DistributeLayoutAttr
+xegpu::setupLoadMatrixAnchorLayout(xegpu::LayoutKind layoutKind,
+                                   VectorType resVectorTy,
+                                   xegpu::DistributeLayoutAttr consumerLayout,
+                                   const xegpu::uArch::uArch *uArch) {
+  xegpu::DistributeLayoutAttr requiredLayout;
+  auto subgroupSize = uArch->getSubgroupSize();
+  SmallVector<int> defaultInstData = {1, subgroupSize};
+  SmallVector<int> defaultLaneLayout = {1, subgroupSize};
+  SmallVector<int> defaultLaneData = {1, 1};
+  auto context = resVectorTy.getContext();
+
+  switch (layoutKind) {
+  case xegpu::LayoutKind::Subgroup:
+    requiredLayout = consumerLayout;
+    break;
+  case xegpu::LayoutKind::InstData:
+    requiredLayout = xegpu::LayoutAttr::get(context, defaultInstData);
+    break;
+  case xegpu::LayoutKind::Lane:
+    requiredLayout =
+        xegpu::LayoutAttr::get(context, defaultLaneData, defaultLaneLayout);
+    break;
+  default:
+    llvm_unreachable("unsupported layout kind");
+  }
+  return requiredLayout;
+}
+
 xegpu::DistributeLayoutAttr
 xegpu::setupStoreMatrixAnchorLayout(xegpu::LayoutKind layoutKind,
-                                    VectorType vectorTy,
+                                    VectorType srcVectorTy,
                                     const xegpu::uArch::uArch *uArch) {
 
   xegpu::DistributeLayoutAttr requiredLayout;
-  SmallVector<int> instData = {1, uArch->getSubgroupSize()};
+  auto subgroupSize = uArch->getSubgroupSize();
+  SmallVector<int> defaultInstData = {1, subgroupSize};
+  SmallVector<int> defaultLaneLayout = {1, subgroupSize};
+  SmallVector<int> defaultLaneData = {1, 1};
+  auto context = srcVectorTy.getContext();
+
   switch (layoutKind) {
   case xegpu::LayoutKind::Subgroup:
     assert(true &&
            "subgroup layout assignment not supported yet for storeMatrix.");
     break;
   case xegpu::LayoutKind::InstData:
-    requiredLayout = xegpu::LayoutAttr::get(vectorTy.getContext(), instData);
+    requiredLayout = xegpu::LayoutAttr::get(context, defaultInstData);
     break;
   case xegpu::LayoutKind::Lane:
-    requiredLayout = xegpu::LayoutAttr::get(
-        vectorTy.getContext(), {1, uArch->getSubgroupSize()}, {1, 1});
+    requiredLayout =
+        xegpu::LayoutAttr::get(context, defaultLaneData, defaultLaneLayout);
 
     break;
   default:
@@ -873,30 +927,78 @@ xegpu::setupStoreMatrixAnchorLayout(xegpu::LayoutKind layoutKind,
 }
 
 xegpu::DistributeLayoutAttr
-xegpu::setupLoadMatrixAnchorLayout(xegpu::LayoutKind layoutKind,
-                                   VectorType vectorTy,
-                                   xegpu::DistributeLayoutAttr consumerLayout,
-                                   const xegpu::uArch::uArch *uArch) {
-  xegpu::DistributeLayoutAttr requiredLayout;
-  SmallVector<int> instData = {1, uArch->getSubgroupSize()};
+xegpu::setupInsertStridedSliceResultLayout(xegpu::LayoutKind layoutKind,
+                                           VectorType resVectorTy,
+                                           const xegpu::uArch::uArch *uArch) {
+
+  xegpu::DistributeLayoutAttr requiredResLayout;
+  auto subgroupSize = uArch->getSubgroupSize();
+  auto context = resVectorTy.getContext();
+  auto resShape = resVectorTy.getShape();
+  int resShapeSize = resShape.size();
+  SmallVector<int> defaultInstData(resShapeSize, 1);
+  SmallVector<int> defaultLaneLayout(resShapeSize, 1);
+  SmallVector<int> defaultLaneData(resShapeSize, 1);
 
-  auto context = vectorTy.getContext();
+  defaultInstData[resShapeSize - 1] = subgroupSize;
+  defaultLaneLayout[resShapeSize - 1] = subgroupSize;
 
   switch (layoutKind) {
   case xegpu::LayoutKind::Subgroup:
-    requiredLayout = consumerLayout;
+    assert(true &&
+           "subgroup layout assignment not supported for insertStridedSlice.");
     break;
   case xegpu::LayoutKind::InstData:
-    requiredLayout = xegpu::LayoutAttr::get(context, instData);
+    requiredResLayout = xegpu::LayoutAttr::get(context, defaultInstData);
     break;
   case xegpu::LayoutKind::Lane:
-    requiredLayout =
-        xegpu::LayoutAttr::get(context, {1, uArch->getSubgroupSize()}, {1, 1});
+    requiredResLayout =
+        xegpu::LayoutAttr::get(context, defaultLaneLayout, defaultLaneData);
     break;
   default:
     llvm_unreachable("unsupported layout kind");
   }
-  return requiredLayout;
+  return requiredResLayout;
+}
+
+xegpu::DistributeLayoutAttr xegpu::inferInsertStridedSliceSourceLayout(
+    xegpu::DistributeLayoutAttr resLayout, ArrayRef<int64_t> resShape,
+    ArrayRef<int64_t> srcShape) {
+
+  const int subgroupSize = 16; // assuming 16 lanes per subgroup
+  int srcShapeSize = srcShape.size();
+  int resShapeSize = resShape.size();
+  int dimDiff = resShapeSize - srcShapeSize;
+
+  // assert resLayout must be a plain layout
+  assert(isa<xegpu::LayoutAttr>(resLayout) &&
+         "insertStridedSlice result layout must be plain layout");
+  auto context = resLayout.getContext();
+  auto resInstData = resLayout.getEffectiveInstDataAsInt();
+  auto resLaneLayout = resLayout.getEffectiveLaneLayoutAsInt();
+  auto resLaneData = resLayout.getEffectiveLaneDataAsInt();
+
+  if (resInstData.size() != 0) {
+    SmallVector<int> inferredInstData(srcShapeSize);
+    // remove the initial dims in resInstData to match srcShapeSize
+    for (int i = 0; i < srcShapeSize; i++)
+      inferredInstData[i] = resInstData[i + dimDiff];
+    return xegpu::LayoutAttr::get(context, inferredInstData);
+  }
+
+  if (resLaneLayout.size() != 0) {
+    // construct source lane_layout like [1, ..., 1, subgroupSize]
+    SmallVector<int> inferredLaneLayout(srcShapeSize);
+    SmallVector<int> inferredLaneData(srcShapeSize);
+    // remove the initial dims in resInstData to match srcShapeSize
+    for (int i = 0; i < srcShapeSize; i++) {
+      inferredLaneLayout[i] = resLaneLayout[i + dimDiff];
+      inferredLaneData[i] = resLaneData[i + dimDiff];
+    }
+    return xegpu::LayoutAttr::get(context, inferredLaneLayout,
+                                  inferredLaneData);
+  }
+  return nullptr;
 }
 
 static xegpu::DistributeLayoutAttr

>From 1c50dd3514c61ff927279032b4120d199cea65f7 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Thu, 29 Jan 2026 00:14:27 +0000
Subject: [PATCH 16/35] add sg distribution for create_mem_desc, able to
 distribute

---
 .../XeGPU/Transforms/XeGPUBlocking.cpp        |  4 +-
 .../Transforms/XeGPUSubgroupDistribute.cpp    | 73 ++++++++++++++++---
 .../Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp  |  4 +-
 3 files changed, 66 insertions(+), 15 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
index 29cd74ad461dd..7af7622375ef4 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
@@ -162,8 +162,8 @@ XeGPUBlockingPass::getTileShape(const T &operandOrResult) const {
           ownerOp &&
           (isa<xegpu::CreateNdDescOp, xegpu::DpasOp, xegpu::ConvertLayoutOp,
                xegpu::LoadMatrixOp, xegpu::StoreMatrixOp, xegpu::AtomicRMWOp,
-               xegpu::LoadGatherOp, xegpu::LoadNdOp, xegpu::StoreNdOp,
-               xegpu::PrefetchNdOp, vector::TransposeOp, vector::ShapeCastOp,
+               xegpu::LoadNdOp, xegpu::StoreNdOp, xegpu::PrefetchNdOp,
+               vector::TransposeOp, vector::ShapeCastOp,
                vector::MultiDimReductionOp, vector::BroadcastOp>(ownerOp));
       if (!skipLeadingUnitDimRemoval) {
         auto it = llvm::find_if(instData, [](auto val) { return val != 1; });
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 8898be5f13dab..80fd5feea9f97 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -1619,14 +1619,16 @@ struct VectorShapeCastDistribution : public gpu::WarpDistributionPattern {
     // must be a slice of higher rank layout.
     int64_t sourceRank = shapeCastOp.getSourceVectorType().getRank();
     int64_t resultRank = shapeCastOp.getResultVectorType().getRank();
-    if (sourceRank < resultRank && !sourceLayout.isSliceOf(resultLayout))
-      return rewriter.notifyMatchFailure(
-          warpOp, "shape_cast is rank reducing but source layout is not a "
-                  "slice of result layout");
-    if (sourceRank > resultRank && !resultLayout.isSliceOf(sourceLayout))
-      return rewriter.notifyMatchFailure(
-          warpOp, "shape_cast is rank increasing but result layout is not a "
-                  "slice of source layout");
+    // if (sourceRank < resultRank && !sourceLayout.isSliceOf(resultLayout)) {
+    //   return rewriter.notifyMatchFailure(
+    //       warpOp, "shape_cast is rank reducing but source layout is not a "
+    //               "slice of result layout");
+    // }
+    // if (sourceRank > resultRank && !resultLayout.isSliceOf(sourceLayout)) {
+    //   return rewriter.notifyMatchFailure(
+    //       warpOp, "shape_cast is rank increasing but result layout is not a "
+    //               "slice of source layout");
+    // }
 
     FailureOr<VectorType> sourceDistTypeOrFailure =
         getDistVecTypeBasedOnLaneLayout(sourceLayout,
@@ -1906,8 +1908,56 @@ struct MemrefExtractAlignedPointerAsIndexDistribution final
     auto newExtractOp = memref::ExtractAlignedPointerAsIndexOp::create(
         rewriter, newWarpOp.getLoc(), extractOp.getType(),
         newWarpOp.getResult(newRetIndices[0]));
-    Value distributedVal = newWarpOp.getResult(operandIdx);
-    rewriter.replaceAllUsesWith(distributedVal, newExtractOp.getResult());
+    Value resultVal = newWarpOp.getResult(operandIdx);
+    rewriter.replaceAllUsesWith(resultVal, newExtractOp.getResult());
+    return success();
+  }
+};
+
+struct MemrefAllocaDistribution final : public gpu::WarpDistributionPattern {
+  using gpu::WarpDistributionPattern::WarpDistributionPattern;
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<memref::AllocaOp>);
+    if (!operand)
+      return rewriter.notifyMatchFailure(
+          warpOp, "warp result is not a memref::Alloca op");
+    auto allocaOp = operand->get().getDefiningOp<memref::AllocaOp>();
+    unsigned operandIdx = operand->getOperandNumber();
+    SmallVector<size_t> newRetIndices;
+    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, ValueRange{}, TypeRange{}, newRetIndices);
+    rewriter.setInsertionPointAfter(newWarpOp);
+    auto newAllocaOp = memref::AllocaOp::create(rewriter, newWarpOp.getLoc(),
+                                                allocaOp.getType(), nullptr);
+    Value resultVal = newWarpOp.getResult(operandIdx);
+    rewriter.replaceAllUsesWith(resultVal, newAllocaOp.getResult());
+    return success();
+  }
+};
+
+struct CreateMemDescDistribution final : public gpu::WarpDistributionPattern {
+  using gpu::WarpDistributionPattern::WarpDistributionPattern;
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    OpOperand *operand =
+        getWarpResult(warpOp, llvm::IsaPred<xegpu::CreateMemDescOp>);
+    if (!operand)
+      return rewriter.notifyMatchFailure(
+          warpOp, "warp result is not a xegpu::CreateMemDesc op");
+    auto createMemDescOp =
+        operand->get().getDefiningOp<xegpu::CreateMemDescOp>();
+    unsigned operandIdx = operand->getOperandNumber();
+    SmallVector<size_t> newRetIndices;
+    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, createMemDescOp.getSource(),
+        TypeRange{createMemDescOp.getSource().getType()}, newRetIndices);
+    rewriter.setInsertionPointAfter(newWarpOp);
+    auto newCreateMemDescOp = xegpu::CreateMemDescOp::create(
+        rewriter, newWarpOp.getLoc(), createMemDescOp.getType(),
+        newWarpOp.getResult(newRetIndices[0]));
+    Value resultVal = newWarpOp.getResult(operandIdx);
+    rewriter.replaceAllUsesWith(resultVal, newCreateMemDescOp.getResult());
     return success();
   }
 };
@@ -2035,7 +2085,8 @@ void xegpu::populateXeGPUSubgroupDistributePatterns(
                LoadDistribution, StoreDistribution, VectorTransposeDistribution,
                VectorBitcastDistribution, LoadMatrixDistribution,
                StoreMatrixDistribution,
-               MemrefExtractAlignedPointerAsIndexDistribution>(
+               MemrefExtractAlignedPointerAsIndexDistribution,
+               MemrefAllocaDistribution, CreateMemDescDistribution>(
       patterns.getContext(),
       /*pattern benefit=*/regularPatternBenefit);
   // For following patterns, we need to override the regular vector distribution
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp
index 9cfcad834a924..225a3874fb2ae 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp
@@ -887,7 +887,7 @@ xegpu::setupLoadMatrixAnchorLayout(xegpu::LayoutKind layoutKind,
     break;
   case xegpu::LayoutKind::Lane:
     requiredLayout =
-        xegpu::LayoutAttr::get(context, defaultLaneData, defaultLaneLayout);
+        xegpu::LayoutAttr::get(context, defaultLaneLayout, defaultLaneData);
     break;
   default:
     llvm_unreachable("unsupported layout kind");
@@ -917,7 +917,7 @@ xegpu::setupStoreMatrixAnchorLayout(xegpu::LayoutKind layoutKind,
     break;
   case xegpu::LayoutKind::Lane:
     requiredLayout =
-        xegpu::LayoutAttr::get(context, defaultLaneData, defaultLaneLayout);
+        xegpu::LayoutAttr::get(context, defaultLaneLayout, defaultLaneData);
 
     break;
   default:

>From a7038d260f86b170d4310cba38fa51403c3a0ee3 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Thu, 29 Jan 2026 01:39:05 +0000
Subject: [PATCH 17/35] xegpu to xevm type convert fix

---
 mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 8a06271eadd84..8efbb0702f0d3 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -1190,7 +1190,7 @@ struct ConvertXeGPUToXeVMPass
         return {};
       auto input = inputs.front();
       if (auto vecTy = dyn_cast<VectorType>(input.getType())) {
-        if (vecTy.getNumElements() == 1) {
+        if (vecTy.getRank() == 1 && vecTy.getNumElements() == 1) {
           // If the vector has a single element, return the element type.
           Value cast =
               vector::ExtractOp::create(builder, loc, input, 0).getResult();

>From 53bc42bf29d4e0b2f3e54ce64305a5d0e7f4e13b Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Thu, 29 Jan 2026 21:48:40 +0000
Subject: [PATCH 18/35] fix bitcast, reduction propagation rules, passing lane
 layout tests

---
 .../XeGPU/Transforms/XeGPUPropagateLayout.cpp |  28 ++-
 .../Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp  | 188 ++++++++++++------
 mlir/test/Dialect/XeGPU/propagate-layout.mlir |  46 ++---
 3 files changed, 173 insertions(+), 89 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 5cb0e259c503b..4966c657c6eca 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -1714,24 +1714,40 @@ void XeGPUPropagateLayoutPass::runOnOperation() {
       Operation *defOp = opResult.getDefiningOp();
       LLVM_DEBUG(DBGS() << "Try op: " << *defOp << "\n");
       if (auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(defOp)) {
-        // inject debug print here
         LLVM_DEBUG(DBGS() << "AnchorLayoutInterface found for op: " << *defOp
                           << "\n");
-        // print the anchor layout
-        LLVM_DEBUG(DBGS() << "Anchor layout: " << anchorOp.getAnchorLayout()
+        auto anchorLayout = anchorOp.getAnchorLayout();
+        LLVM_DEBUG(DBGS() << "Anchor layout: " << anchorLayout << "\n");
+        LLVM_DEBUG(DBGS() << "Anchor layout is null: "
+                          << (anchorLayout == nullptr ? "true" : "false")
                           << "\n");
-        return anchorOp.getAnchorLayout();
+        if (anchorLayout != nullptr) {
+          LLVM_DEBUG(DBGS()
+                     << "Returning anchor layout: " << anchorLayout << "\n");
+          return anchorLayout;
+        }
+        LLVM_DEBUG(DBGS() << "Anchor layout is null, continuing...\n");
       }
 
       xegpu::DistributeLayoutAttr requiredResLayoutAttr =
           xegpu::getTemporaryLayout(opResult);
-      if (requiredResLayoutAttr != nullptr)
+      LLVM_DEBUG(DBGS() << "Temporary layout for value: " << val << " is "
+                        << requiredResLayoutAttr << "\n");
+      if (requiredResLayoutAttr != nullptr) {
+        LLVM_DEBUG(DBGS() << "Returning temporary layout: "
+                          << requiredResLayoutAttr << "\n");
         return requiredResLayoutAttr;
+      }
     }
     xegpu::DistributeLayoutAttr layoutAttr =
         cast<xegpu::DistributeLayoutAttr>(layout.get());
-    if (layout.isSliceLayout())
+    LLVM_DEBUG(DBGS() << "Layout attr for value: " << val << " is "
+                      << layoutAttr << "\n");
+    if (layout.isSliceLayout()) {
+      LLVM_DEBUG(DBGS() << "Returning slice layout: " << layoutAttr << "\n");
       return cast<xegpu::SliceAttr>(layoutAttr);
+    }
+    LLVM_DEBUG(DBGS() << "Returning layout attr: " << layoutAttr << "\n");
     return cast<xegpu::LayoutAttr>(layoutAttr);
   };
 
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp
index 225a3874fb2ae..86c328616b92a 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp
@@ -389,28 +389,63 @@ xegpu::inferBitCastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
   // only adjust the sg_data, inst_data, lane_data accordingly
   // based on the bitwidth ratio between source and result element type
 
+  llvm::dbgs() << "inferBitCastSourceLayout: resElemTyBitWidth="
+               << resElemTyBitWidth
+               << ", srcElemTyBitWidth=" << srcElemTyBitWidth << "\n";
+
   SmallVector<int64_t> sgData = resLayout.getEffectiveSgDataAsInt();
   SmallVector<int64_t> instData = resLayout.getEffectiveInstDataAsInt();
   SmallVector<int64_t> laneData = resLayout.getEffectiveLaneDataAsInt();
-  size_t dim = sgData.size() - 1;
-  int64_t sgDataValue, instDataValue, laneDataValue;
-
-  if (srcElemTyBitWidth >= resElemTyBitWidth) {
-    int bitWidthRatio = srcElemTyBitWidth / resElemTyBitWidth;
-    sgDataValue = (dim < sgData.size()) ? sgData[dim] * bitWidthRatio : -1;
-    instDataValue =
-        (dim < instData.size()) ? instData[dim] * bitWidthRatio : -1;
-    laneDataValue =
-        (dim < laneData.size()) ? laneData[dim] * bitWidthRatio : -1;
-  } else {
+  size_t sgDataSize = sgData.size();
+  size_t instDataSize = instData.size();
+  size_t laneDataSize = laneData.size();
+  int64_t sgDataValue = -1;
+  int64_t instDataValue = -1;
+  int64_t laneDataValue = -1;
+  int64_t dim = resLayout.getRank() - 1;
+
+  llvm::dbgs() << "inferBitCastSourceLayout: dim=" << dim
+               << ", sgData.size()=" << sgData.size()
+               << ", instData.size()=" << instData.size()
+               << ", laneData.size()=" << laneData.size() << "\n";
+
+  if (srcElemTyBitWidth <= resElemTyBitWidth) {
     int bitWidthRatio = resElemTyBitWidth / srcElemTyBitWidth;
-    assert((laneData[dim] % bitWidthRatio) == 0 &&
-           "laneData not divisible by bitWidthRatio");
-    sgDataValue = (dim < sgData.size()) ? sgData[dim] / bitWidthRatio : -1;
-    instDataValue =
-        (dim < instData.size()) ? instData[dim] / bitWidthRatio : -1;
-    laneDataValue =
-        (dim < laneData.size()) ? laneData[dim] / bitWidthRatio : -1;
+    llvm::dbgs() << "inferBitCastSourceLayout: srcElemTyBitWidth >= "
+                    "resElemTyBitWidth, bitWidthRatio="
+                 << bitWidthRatio << "\n";
+    if (sgDataSize)
+      sgDataValue = sgData[sgDataSize - 1] * bitWidthRatio;
+    if (instDataSize)
+      instDataValue = instData[instDataSize - 1] * bitWidthRatio;
+    if (laneDataSize)
+      laneDataValue = laneData[laneDataSize - 1] * bitWidthRatio;
+    llvm::dbgs() << "inferBitCastSourceLayout: sgDataValue=" << sgDataValue
+                 << ", instDataValue=" << instDataValue
+                 << ", laneDataValue=" << laneDataValue << "\n";
+  } else {
+    int bitWidthRatio = srcElemTyBitWidth / resElemTyBitWidth;
+    llvm::dbgs() << "inferBitCastSourceLayout: srcElemTyBitWidth < "
+                    "resElemTyBitWidth, bitWidthRatio="
+                 << bitWidthRatio << "\n";
+    if (sgDataSize) {
+      assert((sgData[sgDataSize - 1] % bitWidthRatio) == 0 &&
+             "sgData not divisible by bitWidthRatio");
+      sgDataValue = sgData[sgDataSize - 1] / bitWidthRatio;
+    }
+    if (instDataSize) {
+      assert((instData[instDataSize - 1] % bitWidthRatio) == 0 &&
+             "instData not divisible by bitWidthRatio");
+      instDataValue = instData[instDataSize - 1] / bitWidthRatio;
+    }
+    if (laneDataSize) {
+      assert((laneData[laneDataSize - 1] % bitWidthRatio) == 0 &&
+             "laneData not divisible by bitWidthRatio");
+      laneDataValue = laneData[laneDataSize - 1] / bitWidthRatio;
+    }
+    llvm::dbgs() << "inferBitCastSourceLayout: sgDataValue=" << sgDataValue
+                 << ", instDataValue=" << instDataValue
+                 << ", laneDataValue=" << laneDataValue << "\n";
   }
 
   // Now set only instData and laneData, preserving sgData
@@ -418,6 +453,7 @@ xegpu::inferBitCastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
   finalSrcLayout =
       resLayout.setDimData(dim, sgDataValue, instDataValue, laneDataValue);
 
+  llvm::dbgs() << "inferBitCastSourceLayout: returning finalSrcLayout\n";
   return finalSrcLayout;
 }
 
@@ -608,16 +644,18 @@ xegpu::SliceAttr xegpu::setupMultiReductionResultLayout(
 
   const int subgroupSize = uArch->getSubgroupSize();
 
-  const int vectorSize = 16; // vector size from SPRIV vector restriction
+  const int vectorSize = 1; // vector size from SPRIV vector restriction
 
   SmallVector<int64_t> defaultInstData(srcShapeSize, 1);
-  defaultInstData[srcShapeSize - 1] = subgroupSize;
-  defaultInstData[srcShapeSize - 2] =
-      vectorSize; // This will be adjusted based on actual data distribution
 
   SmallVector<int64_t> defaultLaneLayout(srcShapeSize, 1);
-  defaultLaneLayout[srcShapeSize - 1] = subgroupSize;
+
   SmallVector<int64_t> defaultLaneData(srcShapeSize, 1);
+  defaultInstData[srcShapeSize - 2] = vectorSize;
+  defaultInstData[srcShapeSize - 1] = subgroupSize;
+  defaultLaneLayout[srcShapeSize - 1] = subgroupSize;
+  defaultLaneData[srcShapeSize - 2] = vectorSize;
+  defaultLaneData[srcShapeSize - 1] = 1;
 
   // Strategy 1: Try to preserve the consumer's slice layout structure
   // If the consumer already expects a slice layout with the same reduction
@@ -743,28 +781,34 @@ xegpu::SliceAttr xegpu::setupMultiReductionResultLayout(
       }
       break;
     case xegpu::LayoutKind::Lane:
-      consumerLaneId = consumerLaneLayout.size() - 1;
-      // For non-reduction dimensions, try to match consumer's lane_layout
-      // This ensures the result after reduction has the expected distribution
-      for (int i = 0; i < srcShapeSize; i++)
-        if (!llvm::is_contained(reductionDims, i) && consumerLaneId >= 0) {
-          laneLayout[i] = consumerLaneLayout[consumerLaneId];
-          assert((srcShape[i] % laneLayout[i] == 0) &&
-                 "source shape not divisible by consumer lane_layout");
-          laneData[i] = srcShape[i] / laneLayout[i];
-          remainingLaneCount /= laneLayout[i];
-          consumerLaneId--;
-        }
-      for (int i = 0; i < srcShapeSize; i++)
-        if (llvm::is_contained(reductionDims, i)) {
-          laneLayout[i] = std::min((srcShape[i] / defaultLaneLayout[i]),
-                                   static_cast<int64_t>(remainingLaneCount));
-          assert((srcShape[i] % laneLayout[i] == 0) &&
-                 "source shape not divisible by consumer lane_layout");
-          laneData[i] = srcShape[i] / laneLayout[i];
-          remainingLaneCount /= laneLayout[i];
-        }
-      assert(remainingLaneCount == 1 && "not all lanes have been distributed");
+      laneLayout = defaultLaneLayout;
+      laneData = defaultLaneData;
+      //  consumerLaneId = consumerLaneLayout.size() - 1;
+      // // For non-reduction dimensions, try to match consumer's lane_layout
+      // // This ensures the result after reduction has the expected
+      // distribution for (int i = 0; i < srcShapeSize; i++)
+      //   if (!llvm::is_contained(reductionDims, i) && consumerLaneId >= 0) {
+      //     laneLayout[i] = consumerLaneLayout[consumerLaneId];
+      //     assert((srcShape[i] % laneLayout[i] == 0) &&
+      //            "source shape not divisible by consumer lane_layout");
+      //     laneData[i] = srcShape[i] / laneLayout[i];
+      //     remainingLaneCount /= laneLayout[i];
+      //     consumerLaneId--;
+      //   }
+      // for (int i = 0; i < srcShapeSize; i++)
+      //   if (llvm::is_contained(reductionDims, i)) {
+      //     laneLayout[i] =
+      //         std::min(srcShape[i],
+      //         static_cast<int64_t>(remainingLaneCount));
+      //     assert((srcShape[i] % laneLayout[i] == 0) &&
+      //            "source shape not divisible by consumer lane_layout");
+      //     laneData[i] = srcShape[i] / laneLayout[i];
+      //     laneData[i] = std::min(laneData[i],
+      //     static_cast<int64_t>(vectorSize)); remainingLaneCount /=
+      //     laneLayout[i];
+      //   }
+      // assert(remainingLaneCount == 1 && "not all lanes have been
+      // distributed");
       break;
     default:
       llvm_unreachable("unsupported layout kind");
@@ -806,17 +850,39 @@ xegpu::DistributeLayoutAttr xegpu::setupBitCastResultLayout(
     xegpu::LayoutKind layoutKind, VectorType srcVecTy, VectorType resVecTy,
     DistributeLayoutAttr consumerLayout, const xegpu::uArch::uArch *uArch) {
 
+  // print out consumerLayout for debugging
+  llvm::dbgs() << "setupBitCastResultLayout: consumerLayout=";
+  consumerLayout.print(llvm::dbgs());
+  llvm::dbgs() << "\n";
+
+  llvm::dbgs() << "setupBitCastResultLayout: layoutKind="
+               << static_cast<int>(layoutKind) << "\n";
+
   int srcElemTyBitWidth = srcVecTy.getElementType().getIntOrFloatBitWidth();
   int resElemTyBitWidth = resVecTy.getElementType().getIntOrFloatBitWidth();
 
+  llvm::dbgs() << "setupBitCastResultLayout: srcElemTyBitWidth="
+               << srcElemTyBitWidth
+               << ", resElemTyBitWidth=" << resElemTyBitWidth << "\n";
+
   ArrayRef<int64_t> srcShape = srcVecTy.getShape();
   SmallVector<int64_t> sgData = consumerLayout.getEffectiveSgDataAsInt();
   SmallVector<int64_t> instData = consumerLayout.getEffectiveInstDataAsInt();
   SmallVector<int64_t> laneData = consumerLayout.getEffectiveLaneDataAsInt();
   size_t dim = srcShape.size() - 1;
-  int64_t sgDataValue, instDataValue, laneDataValue;
+  int64_t sgDataValue = -1;
+  int64_t instDataValue = -1;
+  int64_t laneDataValue = -1;
+
+  llvm::dbgs() << "setupBitCastResultLayout: srcShape.size()="
+               << srcShape.size() << ", dim=" << dim << "\n";
+  llvm::dbgs() << "setupBitCastResultLayout: sgData.size()=" << sgData.size()
+               << ", instData.size()=" << instData.size()
+               << ", laneData.size()=" << laneData.size() << "\n";
 
   const int subgroupSize = uArch->getSubgroupSize();
+  llvm::dbgs() << "setupBitCastResultLayout: subgroupSize=" << subgroupSize
+               << "\n";
 
   if (srcElemTyBitWidth > resElemTyBitWidth) {
     // When casting to a smaller bitwidth, multiply the result layout
@@ -853,17 +919,19 @@ xegpu::DistributeLayoutAttr xegpu::setupBitCastResultLayout(
     default:
       llvm_unreachable("unsupported layout kind");
     }
-  } else {
-    sgDataValue = sgData[dim];
-    instDataValue = instData[dim];
-    laneDataValue = laneData[dim];
+    // Now set only instData and laneData, preserving sgData
+    xegpu::DistributeLayoutAttr resLayout;
+    llvm::dbgs() << "setupBitCastResultLayout: Setting dimension data - dim="
+                 << dim << ", sgDataValue=" << sgDataValue
+                 << ", instDataValue=" << instDataValue
+                 << ", laneDataValue=" << laneDataValue << "\n";
+    resLayout = consumerLayout.setDimData(dim, sgDataValue, instDataValue,
+                                          laneDataValue);
+    llvm::dbgs()
+        << "setupBitCastResultLayout: resLayout created successfully\n";
+    return resLayout;
   }
-
-  // Now set only instData and laneData, preserving sgData
-  xegpu::DistributeLayoutAttr resLayout;
-  resLayout =
-      consumerLayout.setDimData(dim, sgDataValue, instDataValue, laneDataValue);
-  return resLayout;
+  return consumerLayout;
 }
 
 xegpu::DistributeLayoutAttr
@@ -1096,8 +1164,12 @@ xegpu::DistributeLayoutAttr xegpu::setupLoadGatherAnchorLayout(
       assert(resShape[1] <= static_cast<int64_t>(
                                 uArchInstruction->getMaxLaneLoadStoreSize()) &&
              "StoreScatterOp lane size exceeds max lane load/store size.");
+      llvm::dbgs() << "setupLoadGatherAnchorLayout: Creating LayoutAttr with "
+                   << "laneLayout=[" << subgroupSize << ", 1], "
+                   << "laneData=[1, " << resShape[1] << "]\n";
       requiredLayout = xegpu::LayoutAttr::get(
           context, {subgroupSize, 1}, {1, static_cast<int>(resShape[1])});
+      llvm::dbgs() << "setupLoadGatherAnchorLayout: Created requiredLayout\n";
     }
     break;
   default:
@@ -1176,8 +1248,8 @@ xegpu::setupStoreScatterAnchorLayout(LayoutKind layoutKind, VectorType srcVecTy,
       requiredLayout =
           getDefaultLaneLayoutAttr(context, srcVecTy.getRank(), uArch);
     else {
-      assert((srcVecTy.getRank() > 2) && "StoreScatterOp can access 2D tensor "
-                                         "tile at maximum at subgroup level.");
+      assert((srcVecTy.getRank() <= 2) && "StoreScatterOp can access 2D tensor "
+                                          "tile at maximum at subgroup level.");
       assert(srcShape[1] <= static_cast<int64_t>(
                                 uArchInstruction->getMaxLaneLoadStoreSize()) &&
              "StoreScatterOp lane size exceeds max lane load/store size.");
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout.mlir b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
index acfd2e34c805c..a8eccfab53c37 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
@@ -104,21 +104,18 @@ func.func @extf_truncf(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor
 gpu.module @test {
 // CHECK-LABEL: func.func @load_gather_with_chunksize(
 // CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<8x16xf16>, %[[ARG1:[0-9a-zA-Z]+]]: memref<256xf16>, %[[ARG2:[0-9a-zA-Z]+]]: memref<8x16xf32>) {
-// CHECK: %[[CST:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
+// CHECK: %[[OFFSET:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
 // CHECK-SAME:  dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
-// CHECK-NEXT: %[[CST0:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<16xi1>
-// CHECK-NEXT: %[[T2:.*]] = xegpu.create_tdesc %[[ARG1]], %[[CST]] : memref<256xf16>, vector<16xindex> ->
-// CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16 : i64>, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>>
-// CHECK-NEXT: %{{.*}} = xegpu.load %[[T2]], %[[CST0]]  <{layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>}>
-// CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16 : i64>, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>>, vector<16xi1> -> vector<16x16xf16>
+// CHECK-NEXT: %[[MASK:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<16xi1>
+// CHECK-NEXT: %{{.*}} = xegpu.load %arg1[%[[OFFSET]]], %[[MASK]] <{chunk_size = 16 : i64, layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 16]>}> : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x16xf16>
 func.func @load_gather_with_chunksize(%arg0: memref<8x16xf16>, %arg1: memref<256xf16>, %arg2: memref<8x16xf32>) {
   %c0 = arith.constant 0 : index
   %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
   %1 = xegpu.load_nd %0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
-  %cst = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
-  %cst_0 = arith.constant dense<true> : vector<16xi1>
-  %2 = xegpu.create_tdesc %arg1, %cst : memref<256xf16>, vector<16xindex> -> !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16 : i64>>
-  %3 = xegpu.load %2, %cst_0 : !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16 : i64>>, vector<16xi1> -> vector<16x16xf16>
+  %offset = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
+  %mask = arith.constant dense<true> : vector<16xi1>
+  %3 = xegpu.load %arg1[%offset], %mask <{chunk_size=16}>
+      : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x16xf16>
   %4 = vector.transpose %3, [1, 0] : vector<16x16xf16> to vector<16x16xf16>
   %5 = xegpu.dpas %1, %4 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
   %6 = xegpu.create_nd_tdesc %arg2[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
@@ -151,16 +148,15 @@ func.func @load_gather_1d(%arg0: memref<256xf32>, %arg1: !xegpu.tensor_desc<16xf
 gpu.module @test {
 // CHECK-LABEL: func.func @store_scatter_with_chunksize(
 // CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<128xf32>) {
-// CHECK: %[[T0:.*]] = xegpu.create_tdesc %[[ARG0]], %{{.*}} : memref<128xf32>, vector<16xindex> ->
-// CHECK-SAME: !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size = 8 : i64>, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>>
-// CHECK-NEXT: xegpu.store %{{.*}}, %[[T0]], %{{.*}} : vector<16x8xf32>, !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size = 8 : i64>,
-// CHECK-SAME: #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>>, vector<16xi1>
+// CHECK-NEXT: %[[CST:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 8]>} dense<1.000000e+00> : vector<16x8xf32>
+// CHECK-NEXT: %[[CST_0:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<16xi1>
+// CHECK-NEXT: %[[CST_1:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
+// CHECK-NEXT: xegpu.store %[[CST]], %[[ARG0]][%[[CST_1]]], %[[CST_0]] <{chunk_size = 8 : i64, layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 8]>}> : vector<16x8xf32>, memref<128xf32>, vector<16xindex>, vector<16xi1>
 func.func @store_scatter_with_chunksize(%arg0: memref<128xf32>) {
-  %cst = arith.constant dense<1.000000e+00> : vector<16x8xf32>
-  %cst_0 = arith.constant dense<true> : vector<16xi1>
-  %cst_1 = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
-  %0 = xegpu.create_tdesc %arg0, %cst_1 : memref<128xf32>, vector<16xindex> -> !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size = 8 : i64>>
-  xegpu.store %cst, %0, %cst_0 : vector<16x8xf32>, !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size = 8 : i64>>, vector<16xi1>
+  %val = arith.constant dense<1.000000e+00> : vector<16x8xf32>
+  %mask = arith.constant dense<true> : vector<16xi1>
+  %offset = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
+  xegpu.store %val, %arg0[%offset], %mask <{chunk_size = 8}>: vector<16x8xf32>, memref<128xf32>, vector<16xindex>, vector<16xi1>
   return
 }
 }
@@ -184,9 +180,9 @@ gpu.module @test {
 // CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<256xf16>) {
 // CHECK: %[[MASK:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<16xi1>
 // CHECK: %[[OFFSETS:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<12> : vector<16xindex>
-// CHECK: %[[LOAD_VEC:.*]] = xegpu.load %[[ARG0]][%[[OFFSETS]]], %[[MASK]] <{chunk_size = 8 : i64, layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>}>
+// CHECK: %[[LOAD_VEC:.*]] = xegpu.load %[[ARG0]][%[[OFFSETS]]], %[[MASK]] <{chunk_size = 8 : i64, layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 8]>}>
 // CHECK-SAME: memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
-// CHECK: xegpu.store %[[LOAD_VEC]], %[[ARG0]][%[[OFFSETS]]], %[[MASK]]  <{chunk_size = 8 : i64, layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>}> : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
+// CHECK: xegpu.store %[[LOAD_VEC]], %[[ARG0]][%[[OFFSETS]]], %[[MASK]]  <{chunk_size = 8 : i64, layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 8]>}> : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
 func.func @scatter_ops_chunksize(%src: memref<256xf16>) {
   %1 = arith.constant dense<1>: vector<16xi1>
   %offset = arith.constant dense<12> : vector<16xindex>
@@ -320,8 +316,9 @@ func.func @vector_bitcast_i16_to_i32(%arg0: memref<8x32xi16>, %arg1: memref<8x16
 // -----
 gpu.module @test {
 // CHECK-LABEL: func.func @vector_bitcast_require_cross_lane_shuffle(
-// CHECK:     %[[LOAD:.*]] = xegpu.load_nd %{{.*}} : !xegpu.tensor_desc<8x16xi32> -> vector<8x16xi32>
-// CHECK:     %{{.*}} = vector.bitcast %[[LOAD]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+// CHECK:     %[[LOAD:.*]] = xegpu.load_nd %{{.*}} <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}> 
+// CHECK-SAME:     !xegpu.tensor_desc<8x16xi32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+// CHECK:     %{{.*}} = vector.bitcast %[[LOAD]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>}
 // CHECK-SAME:     vector<8x16xi32> to vector<8x32xi16>
 func.func @vector_bitcast_require_cross_lane_shuffle(%arg0: memref<8x16xi32>, %arg1: memref<8x32xi16>) {
   %c0 = arith.constant 0 : index
@@ -703,8 +700,7 @@ func.func @vector_broadcast_scalar_to_vector(%arg0: !xegpu.tensor_desc<16x16xf16
 gpu.module @test {
 // CHECK-LABEL: func.func @store_matrix(
 // CHECK:         %[[CST:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} dense<0.000000e+00> : vector<16x16xf16>
-// CHECK-NEXT:     xegpu.store_matrix %[[CST]], %arg0[8, 8] : vector<16x16xf16>, !xegpu.mem_desc<16x64xf16>
-
+// CHECK-NEXT:     xegpu.store_matrix %[[CST]], %arg0[8, 8] <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}>
 func.func @store_matrix(%arg0: !xegpu.mem_desc<16x64xf16>) {
   %cst = arith.constant dense<0.0000> : vector<16x16xf16>
   xegpu.store_matrix %cst, %arg0[8, 8]: vector<16x16xf16>, !xegpu.mem_desc<16x64xf16>

>From 84dc02a8d420e4a323eb39a98f86568eeb03a3d6 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Thu, 29 Jan 2026 23:16:59 +0000
Subject: [PATCH 19/35] passing all tests

---
 .../XeGPU/Transforms/XeGPUSubgroupDistribute.cpp | 10 ----------
 .../XeGPU/propagate-layout-inst-data.mlir        |  2 +-
 .../Dialect/XeGPU/subgroup-distribute-unit.mlir  | 16 ++++++++--------
 3 files changed, 9 insertions(+), 19 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 80fd5feea9f97..4de0905a9b7e2 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -1619,16 +1619,6 @@ struct VectorShapeCastDistribution : public gpu::WarpDistributionPattern {
     // must be a slice of higher rank layout.
     int64_t sourceRank = shapeCastOp.getSourceVectorType().getRank();
     int64_t resultRank = shapeCastOp.getResultVectorType().getRank();
-    // if (sourceRank < resultRank && !sourceLayout.isSliceOf(resultLayout)) {
-    //   return rewriter.notifyMatchFailure(
-    //       warpOp, "shape_cast is rank reducing but source layout is not a "
-    //               "slice of result layout");
-    // }
-    // if (sourceRank > resultRank && !resultLayout.isSliceOf(sourceLayout)) {
-    //   return rewriter.notifyMatchFailure(
-    //       warpOp, "shape_cast is rank increasing but result layout is not a "
-    //               "slice of source layout");
-    // }
 
     FailureOr<VectorType> sourceDistTypeOrFailure =
         getDistVecTypeBasedOnLaneLayout(sourceLayout,
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir b/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
index 5aad0f592abed..3bc80e9b1596d 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
@@ -217,7 +217,7 @@ gpu.module @test {
 // CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<1024xf32>) {
 // CHECK: %{{.*}} = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [16]>} dense<true> : vector<16xi1>
 // CHECK: %{{.*}} = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [16]>} dense<12> : vector<16xindex>
-// CHECK: %[[LOADED:.*]] = xegpu.load %[[ARG0]][%{{.*}}], %{{.*}} <{layout = #xegpu.slice<#xegpu.layout<inst_data = [16, 16]>, dims = [0]>}> :
+// CHECK: %[[LOADED:.*]] = xegpu.load %[[ARG0]][%{{.*}}], %{{.*}} <{layout = #xegpu.layout<inst_data = [16]>}> :
 // CHECK-SAME: memref<1024xf32>, vector<16xindex>, vector<16xi1> -> vector<16xf32>
 // CHECK: %[[BCAST:.*]] = vector.broadcast %[[LOADED]] {layout_result_0 = #xegpu.layout<inst_data = [16, 16]>} : vector<16xf32> to vector<16x16xf32>
 // CHECK: xegpu.store %[[BCAST]], %[[ARG0]][%{{.*}}], %{{.*}} <{chunk_size = 16 : i64, layout = #xegpu.layout<inst_data = [16, 16]>}> :
diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir
index b136c89925682..3a978f68ad9c0 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir
@@ -579,16 +579,15 @@ gpu.func @vector_shapecast_rank_reducing(%laneid: index) {
 }
 
 
-// NOTE: Layouts are still valid, but distribution still requires a slice layout for the operand.
-//
-// CHECK-LABEL:  gpu.func @vector_shapecast_unsupported
-// CHECK:          %[[W:.*]] = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<1x1xf32>) {
-// CHECK:            %[[T1:.*]] = vector.shape_cast %{{.*}} : vector<16xf32> to vector<1x16xf32>
-// CHECK:            gpu.yield %[[T1]] : vector<1x16xf32>
+// CHECK-LABEL:  gpu.func @vector_shapecast_rank_increasing_without_slicing_layout
+// CHECK:          %[[W:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<1x1xf32>, vector<1xf32>) {
+// CHECK:            %[[T1:.*]] = vector.shape_cast %{{.*}} {layout_operand_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>, layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16xf32> to vector<1x16xf32>
+// CHECK:            gpu.yield %[[T1]], %{{.*}} : vector<1x16xf32>, vector<16xf32>
 // CHECK:          }
-// CHECK:          "some_user_op"(%[[W]]) : (vector<1x1xf32>) -> ()
+// CHECK:          %{{.*}} = vector.shape_cast %[[W]]#1 : vector<1xf32> to vector<1x1xf32>
 // CHECK:          gpu.return
-gpu.func @vector_shapecast_unsupported(%laneid: index) {
+gpu.module @xevm_module{
+gpu.func @vector_shapecast_rank_increasing_without_slicing_layout(%laneid: index) {
   %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1x1xf32>) {
     %cst = "some_op"()
       {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]> }
@@ -604,6 +603,7 @@ gpu.func @vector_shapecast_unsupported(%laneid: index) {
   "some_user_op"(%r) : (vector<1x1xf32>) -> ()
   gpu.return
 }
+}
 
 
 // CHECK-LABEL:  gpu.func @vector_extract_strided_slice_distributed_dim_fully_extracted

>From ed1e3a02459f09f25476b13928c73c5ac8bb760a Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Thu, 29 Jan 2026 23:58:38 +0000
Subject: [PATCH 20/35] move XeGPULayoutUtils to Transforms directory

---
 .../XeGPULayoutImpls.h}                       |   0
 .../mlir/Dialect/XeGPU/Utils/XeGPUUtils.h     |  20 +-
 .../Dialect/XeGPU/Transforms/CMakeLists.txt   |   1 +
 .../XeGPU/Transforms/XeGPUBlocking.cpp        |   2 +-
 .../XeGPULayoutImpls.cpp}                     | 202 +-----------------
 .../Transforms/XeGPUPeepHoleOptimizer.cpp     |   2 +-
 .../XeGPU/Transforms/XeGPUPropagateLayout.cpp |   2 +-
 .../Transforms/XeGPUSubgroupDistribute.cpp    |   2 +-
 .../Dialect/XeGPU/Transforms/XeGPUUnroll.cpp  |   2 +-
 .../Transforms/XeGPUWgToSgDistribute.cpp      |   2 +-
 mlir/lib/Dialect/XeGPU/Utils/CMakeLists.txt   |   1 -
 mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp   | 200 +++++++++++++++++
 12 files changed, 218 insertions(+), 218 deletions(-)
 rename mlir/include/mlir/Dialect/XeGPU/{Utils/XeGPULayoutUtils.h => Transforms/XeGPULayoutImpls.h} (100%)
 rename mlir/lib/Dialect/XeGPU/{Utils/XeGPULayoutUtils.cpp => Transforms/XeGPULayoutImpls.cpp} (86%)

diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h b/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpls.h
similarity index 100%
rename from mlir/include/mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h
rename to mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpls.h
diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
index 0f1ca8e38c873..572c174bccbf1 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
@@ -120,16 +120,6 @@ template <typename T>
 int getLargestDivisor(T dim, ArrayRef<T> candidates,
                       ArrayRef<T> candidateMultiples = {});
 
-/// [to-be-deprecated] Sets the DistributeLayoutAttr for a given OpResult
-/// user should use setAnchorLayout instead
-void setDistributeLayoutAttr(const OpResult &Result,
-                             const DistributeLayoutAttr layout);
-
-/// [to-be-deprecated] Sets the DistributeLayoutAttr for a given OpOperand
-/// user should use setAnchorLayout instead
-void setDistributeLayoutAttr(const OpOperand &opr,
-                             const DistributeLayoutAttr layout);
-
 /// Retrieves the DistributeLayoutAttr associated with a given Value. For
 /// TensorDescType values, the DistributeLayoutAttr is extracted from the
 /// TensorDescType itself. For other values, it is obtained from the attributes
@@ -142,6 +132,16 @@ DistributeLayoutAttr getDistributeLayoutAttr(const Value value);
 /// found, it will check the operand itself and its defining op.
 DistributeLayoutAttr getDistributeLayoutAttr(const OpOperand &opr);
 
+/// [to-be-deprecated] Sets the DistributeLayoutAttr for a given OpResult
+/// user should use setAnchorLayout instead
+void setDistributeLayoutAttr(const OpResult &Result,
+                             const DistributeLayoutAttr layout);
+
+/// [to-be-deprecated] Sets the DistributeLayoutAttr for a given OpOperand
+/// user should use setAnchorLayout instead
+void setDistributeLayoutAttr(const OpOperand &opr,
+                             const DistributeLayoutAttr layout);
+
 /// Return the attribute name for the OpOperand to attach DistributeLayoutAttr
 std::string getTemporaryLayoutName(const OpOperand &operand);
 
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
index 15d31eadcb6df..3492bf78bf785 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
@@ -7,6 +7,7 @@ add_mlir_dialect_library(MLIRXeGPUTransforms
   XeGPUPropagateLayout.cpp
   XeGPUVectorLinearize.cpp
   XeGPUPeepHoleOptimizer.cpp
+  XeGPULayoutImpls.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/XeGPU
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
index 7af7622375ef4..c62b8e237596e 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
@@ -12,7 +12,7 @@
 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
 #include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
-#include "mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h"
+#include "mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpls.h"
 #include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
 #include "mlir/Interfaces/LoopLikeInterface.h"
 #include "mlir/Pass/PassManager.h"
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpls.cpp
similarity index 86%
rename from mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp
rename to mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpls.cpp
index 86c328616b92a..1b0ec6f11b5f9 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPULayoutUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpls.cpp
@@ -10,7 +10,7 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h"
+#include "mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpls.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/LLVMIR/XeVMDialect.h"
@@ -28,206 +28,6 @@
 
 using namespace mlir;
 
-std::string xegpu::getTemporaryLayoutName(const OpOperand &operand) {
-  const StringRef prefix("layout_operand_");
-  unsigned idx = const_cast<OpOperand &>(operand).getOperandNumber();
-  return llvm::formatv("{0}{1}", prefix, idx).str();
-}
-
-std::string xegpu::getTemporaryLayoutName(const OpResult result) {
-  const StringRef prefix = "layout_result_";
-  return llvm::formatv("{0}{1}", prefix, result.getResultNumber()).str();
-}
-
-xegpu::DistributeLayoutAttr xegpu::getDistributeLayoutAttr(const Value value) {
-  if (!value)
-    return nullptr;
-
-  if (auto tdescTy =
-          dyn_cast_if_present<xegpu::TensorDescType>(value.getType()))
-    return tdescTy.getLayoutAttr();
-
-  if (auto result = dyn_cast<OpResult>(value)) {
-    Operation *defOp = result.getDefiningOp();
-    assert(defOp && "result must have a defining op");
-
-    if (auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(defOp)) {
-      auto layout = anchorOp.getAnchorLayout();
-      return layout;
-    }
-
-    std::string layoutName = getTemporaryLayoutName(result);
-    if (defOp->hasAttr(layoutName)) {
-      auto layout =
-          defOp->getAttrOfType<xegpu::DistributeLayoutAttr>(layoutName);
-      return layout;
-    }
-  }
-
-  if (auto arg = dyn_cast<BlockArgument>(value)) {
-    auto *parentOp = arg.getOwner()->getParentOp();
-    if (auto loop = dyn_cast<LoopLikeOpInterface>(parentOp)) {
-      OpOperand *tiedInit = loop.getTiedLoopInit(arg);
-      if (tiedInit)
-        return getDistributeLayoutAttr(tiedInit->get());
-    }
-  }
-
-  return nullptr;
-}
-xegpu::DistributeLayoutAttr
-xegpu::getDistributeLayoutAttr(const OpOperand &opr) {
-  Operation *op = opr.getOwner();
-  unsigned idx = const_cast<OpOperand &>(opr).getOperandNumber();
-
-  if (auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(op)) {
-    if (auto dpasOp = dyn_cast<xegpu::DpasOp>(op)) {
-      if (idx == 0) {
-        return dpasOp.getLayoutAAttr();
-      } else if (idx == 1) {
-        return dpasOp.getLayoutBAttr();
-      } else if (idx == 2) {
-        return dpasOp.getLayoutCdAttr();
-      }
-    }
-    if (auto convertOp = dyn_cast<xegpu::ConvertLayoutOp>(op)) {
-      return convertOp.getInputLayoutAttr();
-    }
-    auto layout = anchorOp.getAnchorLayout();
-
-    if (idx == 0)
-      return layout;
-
-    // For store operations (StoreScatterOp, StoreNdOp, StoreMatrixOp),
-    // the layout is valid for the first two operands: value and memref/tdesc.
-    // For other operations, the layout applies to the first operand only.
-    if (isa<xegpu::StoreScatterOp, xegpu::StoreNdOp, xegpu::StoreMatrixOp>(
-            op) &&
-        (idx < 2))
-      return layout;
-  }
-
-  std::string layoutName = xegpu::getTemporaryLayoutName(opr);
-  if (op->hasAttr(layoutName)) {
-    auto layout = op->getAttrOfType<xegpu::DistributeLayoutAttr>(layoutName);
-    return layout;
-  }
-
-  auto layout = getDistributeLayoutAttr(opr.get());
-  return layout;
-}
-
-// TODO-LayoutRefactor: Remove this function after replacing use
-//  with setTemporaryLayout or setAnchorLayout
-void xegpu::setDistributeLayoutAttr(
-    const mlir::OpResult &result,
-    const mlir::xegpu::DistributeLayoutAttr layout) {
-  Operation *owner = result.getOwner();
-
-  if (auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(owner)) {
-    if (anchorOp.getAnchorLayout() == layout)
-      return;
-    anchorOp.setAnchorLayout(layout);
-    return;
-  }
-
-  std::string name = xegpu::getTemporaryLayoutName(result);
-  if (owner->hasAttrOfType<DistributeLayoutAttr>(name)) {
-    return;
-  }
-  if (layout) {
-    owner->setAttr(name, layout);
-  }
-}
-
-// TODO-LayoutRefactor: Remove this function after replacing use
-//  with setTemporaryLayout or setAnchorLayout
-void xegpu::setDistributeLayoutAttr(const OpOperand &operand,
-                                    const DistributeLayoutAttr layout) {
-  Operation *owner = operand.getOwner();
-  unsigned idx = const_cast<OpOperand &>(operand).getOperandNumber();
-
-  if (!layout) {
-    return;
-  }
-  if (auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(owner)) {
-    if (auto dpasOp = dyn_cast<xegpu::DpasOp>(owner)) {
-      if (idx == 0) {
-        return dpasOp.setLayoutAAttr(layout);
-      } else if (idx == 1) {
-        return dpasOp.setLayoutBAttr(layout);
-      } else if (idx == 2) {
-        return dpasOp.setLayoutCdAttr(layout);
-      }
-    }
-    if (auto convertOp = dyn_cast<xegpu::ConvertLayoutOp>(owner)) {
-      return convertOp.setInputLayoutAttr(layout);
-    }
-
-    // For store operations (StoreScatterOp, StoreNdOp, StoreMatrixOp),
-    // the layout is valid for the first two operands: value and memref/tdesc.
-    // For other operations, the layout applies to the first operand only.
-    if (isa<xegpu::StoreScatterOp, xegpu::StoreNdOp, xegpu::StoreMatrixOp>(
-            owner)) {
-      if (idx < 2) {
-        anchorOp.setAnchorLayout(layout);
-      }
-    } else {
-      if (idx == 0) {
-        anchorOp.setAnchorLayout(layout);
-      }
-    }
-  }
-
-  std::string name = xegpu::getTemporaryLayoutName(operand);
-  if (owner->hasAttrOfType<DistributeLayoutAttr>(name)) {
-    return;
-  }
-  if (layout) {
-    owner->setAttr(name, layout);
-  }
-}
-
-template <typename T, typename>
-xegpu::DistributeLayoutAttr
-xegpu::getTemporaryLayout(const T &operandOrResult) {
-  Operation *op = operandOrResult.getOwner();
-
-  std::string layoutName = xegpu::getTemporaryLayoutName(operandOrResult);
-  if (op->hasAttr(layoutName)) {
-    auto layout = op->getAttrOfType<xegpu::DistributeLayoutAttr>(layoutName);
-    return layout;
-  }
-
-  return nullptr;
-}
-
-template xegpu::DistributeLayoutAttr
-xegpu::getTemporaryLayout<mlir::OpResult>(const OpResult &result);
-template xegpu::DistributeLayoutAttr
-xegpu::getTemporaryLayout<mlir::OpOperand>(const OpOperand &operand);
-
-template <typename T, typename>
-void xegpu::setTemporaryLayout(const T &operandOrResult,
-                               const xegpu::DistributeLayoutAttr layout) {
-  Operation *owner = operandOrResult.getOwner();
-  std::string name = xegpu::getTemporaryLayoutName(operandOrResult);
-  if (owner->hasAttrOfType<xegpu::DistributeLayoutAttr>(name)) {
-    return;
-  }
-  if (layout) {
-    owner->setAttr(name, layout);
-  }
-}
-
-template void xegpu::setTemporaryLayout<mlir::OpResult>(
-    const mlir::OpResult &result,
-    const mlir::xegpu::DistributeLayoutAttr layout);
-
-template void xegpu::setTemporaryLayout<mlir::OpOperand>(
-    const mlir::OpOperand &operand,
-    const mlir::xegpu::DistributeLayoutAttr layout);
-
 void xegpu::recoverTemporaryLayoutsDeprecated(Operation *op) {
   op->walk([&](Operation *nestOp) {
     for (OpOperand &opr : nestOp->getOpOperands()) {
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPeepHoleOptimizer.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPeepHoleOptimizer.cpp
index 9a8925d357d25..c1d658b0d73e7 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPeepHoleOptimizer.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPeepHoleOptimizer.cpp
@@ -16,7 +16,7 @@
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
 #include "mlir/Dialect/XeGPU/Transforms/Passes.h"
 #include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
-#include "mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h"
+#include "mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpls.h"
 #include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
 #include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h"
 #include "mlir/Dialect/XeGPU/uArch/uArchBase.h"
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 4966c657c6eca..c148ba459527e 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -15,7 +15,7 @@
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
 #include "mlir/Dialect/XeGPU/Transforms/Passes.h"
-#include "mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h"
+#include "mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpls.h"
 #include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Builders.h"
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 4de0905a9b7e2..121810215913c 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -14,7 +14,7 @@
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
 #include "mlir/Dialect/XeGPU/Transforms/Passes.h"
 #include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
-#include "mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h"
+#include "mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpls.h"
 #include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
 #include "mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h"
 #include "mlir/IR/AffineMap.h"
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index 6bafdc955c9d3..725d72de94043 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -15,7 +15,7 @@
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
 #include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
-#include "mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h"
+#include "mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpls.h"
 #include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/Support/DebugLog.h"
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 88dc670e0ed6a..c0487cd709c3d 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -19,7 +19,7 @@
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
 #include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
-#include "mlir/Dialect/XeGPU/Utils/XeGPULayoutUtils.h"
+#include "mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpls.h"
 #include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include <optional>
diff --git a/mlir/lib/Dialect/XeGPU/Utils/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/Utils/CMakeLists.txt
index bde8324aab5fb..d9bf4a1461c27 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/CMakeLists.txt
+++ b/mlir/lib/Dialect/XeGPU/Utils/CMakeLists.txt
@@ -1,6 +1,5 @@
 add_mlir_dialect_library(MLIRXeGPUUtils
   XeGPUUtils.cpp
-  XeGPULayoutUtils.cpp  
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/XeGPU/Utils
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index 181b7e9673fef..cc6ac76ec89fd 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -392,3 +392,203 @@ template int xegpu::getLargestDivisor<int>(int dim, ArrayRef<int> candidates,
 template int
 xegpu::getLargestDivisor<unsigned>(unsigned dim, ArrayRef<unsigned> candidates,
                                    ArrayRef<unsigned> candidateMultiples);
+
+std::string xegpu::getTemporaryLayoutName(const OpOperand &operand) {
+  // Produces "layout_operand_<N>", N being the operand's index in its owner.
+  unsigned operandNo = const_cast<OpOperand &>(operand).getOperandNumber();
+  return llvm::formatv("{0}{1}", StringRef("layout_operand_"), operandNo).str();
+}
+
+std::string xegpu::getTemporaryLayoutName(const OpResult result) {
+  unsigned resultNo = result.getResultNumber(); // "layout_result_<N>"
+  return llvm::formatv("{0}{1}", StringRef("layout_result_"), resultNo).str();
+}
+
+/// Returns the layout attached to \p value, or null if none is found.
+xegpu::DistributeLayoutAttr xegpu::getDistributeLayoutAttr(const Value value) {
+  if (!value)
+    return nullptr;
+
+  // Tensor descriptors carry their layout in the type itself.
+  if (auto tdescTy =
+          dyn_cast_if_present<xegpu::TensorDescType>(value.getType()))
+    return tdescTy.getLayoutAttr();
+
+  if (auto result = dyn_cast<OpResult>(value)) {
+    Operation *defOp = result.getDefiningOp();
+    assert(defOp && "result must have a defining op");
+
+    // An anchor op's own layout is returned as-is (possibly null).
+    if (auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(defOp))
+      return anchorOp.getAnchorLayout();
+
+    // Otherwise use the temporary "layout_result_<N>" attribute; null when it
+    // is absent or is not a DistributeLayoutAttr.
+    return defOp->getAttrOfType<xegpu::DistributeLayoutAttr>(
+        getTemporaryLayoutName(result));
+  }
+
+  // Loop-carried block arguments take the layout of their tied init value.
+  // A detached block has no parent op, so the cast must tolerate null.
+  if (auto arg = dyn_cast<BlockArgument>(value)) {
+    Operation *parentOp = arg.getOwner()->getParentOp();
+    if (auto loop = dyn_cast_if_present<LoopLikeOpInterface>(parentOp)) {
+      if (OpOperand *tiedInit = loop.getTiedLoopInit(arg))
+        return getDistributeLayoutAttr(tiedInit->get());
+    }
+  }
+
+  return nullptr;
+}
+/// Returns the layout expected for \p opr: the owner's anchor layout when it
+/// covers this operand, else a temporary attribute, else the value's layout.
+xegpu::DistributeLayoutAttr
+xegpu::getDistributeLayoutAttr(const OpOperand &opr) {
+  Operation *owner = opr.getOwner();
+  unsigned operandIdx = const_cast<OpOperand &>(opr).getOperandNumber();
+
+  if (auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(owner)) {
+    if (auto dpasOp = dyn_cast<xegpu::DpasOp>(owner)) {
+      switch (operandIdx) {
+      case 0:
+        return dpasOp.getLayoutAAttr();
+      case 1:
+        return dpasOp.getLayoutBAttr();
+      case 2:
+        return dpasOp.getLayoutCdAttr();
+      default:
+        break;
+      }
+    }
+    // ConvertLayoutOp exposes the layout of its input.
+    if (auto convertOp = dyn_cast<xegpu::ConvertLayoutOp>(owner))
+      return convertOp.getInputLayoutAttr();
+
+    // Stores (StoreScatterOp, StoreNdOp, StoreMatrixOp) share the anchor
+    // layout between operand 0 (value) and operand 1 (memref/tdesc); any
+    // other anchor op attaches it to operand 0 only.
+    bool isStore =
+        isa<xegpu::StoreScatterOp, xegpu::StoreNdOp, xegpu::StoreMatrixOp>(
+            owner);
+    if (operandIdx < (isStore ? 2u : 1u))
+      return anchorOp.getAnchorLayout();
+  }
+
+  // An attribute under the temporary name is authoritative: if present but not
+  // a DistributeLayoutAttr, the result is null rather than the value's layout.
+  std::string name = xegpu::getTemporaryLayoutName(opr);
+  if (owner->hasAttr(name))
+    return owner->getAttrOfType<xegpu::DistributeLayoutAttr>(name);
+  return getDistributeLayoutAttr(opr.get());
+}
+
+// TODO-LayoutRefactor: Remove this function after replacing use
+//  with setTemporaryLayout or setAnchorLayout
+void xegpu::setDistributeLayoutAttr(
+    const mlir::OpResult &result,
+    const mlir::xegpu::DistributeLayoutAttr layout) {
+  // Ignore null layouts, as the OpOperand overload does, so a null never
+  // clears an anchor layout.
+  if (!layout)
+    return;
+
+  Operation *owner = result.getOwner();
+  // Anchor ops own the layout; only write it when it actually changes.
+  if (auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(owner)) {
+    if (anchorOp.getAnchorLayout() != layout)
+      anchorOp.setAnchorLayout(layout);
+    return;
+  }
+  // Other ops get a temporary attribute; an existing one is kept as-is.
+  std::string name = xegpu::getTemporaryLayoutName(result);
+  if (!owner->hasAttrOfType<DistributeLayoutAttr>(name))
+    owner->setAttr(name, layout);
+}
+
+// TODO-LayoutRefactor: Remove this function after replacing use
+//  with setTemporaryLayout or setAnchorLayout
+void xegpu::setDistributeLayoutAttr(const OpOperand &operand,
+                                    const DistributeLayoutAttr layout) {
+  // A null layout is never recorded.
+  if (!layout)
+    return;
+
+  Operation *owner = operand.getOwner();
+  unsigned operandIdx = const_cast<OpOperand &>(operand).getOperandNumber();
+
+  if (auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(owner)) {
+    // DPAS stores one layout per A/B/C operand and returns right away.
+    if (auto dpasOp = dyn_cast<xegpu::DpasOp>(owner)) {
+      switch (operandIdx) {
+      case 0:
+        return dpasOp.setLayoutAAttr(layout);
+      case 1:
+        return dpasOp.setLayoutBAttr(layout);
+      case 2:
+        return dpasOp.setLayoutCdAttr(layout);
+      default:
+        break;
+      }
+    }
+
+    // ConvertLayoutOp records the layout of its input and returns.
+    if (auto convertOp = dyn_cast<xegpu::ConvertLayoutOp>(owner))
+      return convertOp.setInputLayoutAttr(layout);
+
+    // Stores (StoreScatterOp, StoreNdOp, StoreMatrixOp) share the anchor
+    // layout between operand 0 (value) and operand 1 (memref/tdesc); any
+    // other anchor op attaches it to operand 0 only.
+    bool isStore =
+        isa<xegpu::StoreScatterOp, xegpu::StoreNdOp, xegpu::StoreMatrixOp>(
+            owner);
+    unsigned numAnchored = isStore ? 2 : 1;
+    if (operandIdx < numAnchored)
+      anchorOp.setAnchorLayout(layout);
+  }
+
+  // All remaining paths (anchor ops included) also record a temporary
+  // attribute unless a DistributeLayoutAttr already exists under that name.
+  std::string name = xegpu::getTemporaryLayoutName(operand);
+  if (!owner->hasAttrOfType<DistributeLayoutAttr>(name))
+    owner->setAttr(name, layout);
+}
+
+template <typename T, typename>
+xegpu::DistributeLayoutAttr
+xegpu::getTemporaryLayout(const T &operandOrResult) {
+  // Reads the attribute written by setTemporaryLayout, named by
+  // getTemporaryLayoutName(operandOrResult), from the owning op.
+  Operation *owner = operandOrResult.getOwner();
+  std::string attrName = xegpu::getTemporaryLayoutName(operandOrResult);
+
+  // getAttrOfType yields null when the attribute is missing and also when
+  // it exists but is not a DistributeLayoutAttr.
+  return owner->getAttrOfType<xegpu::DistributeLayoutAttr>(attrName);
+}
+
+// Explicit instantiations for the two supported handle types.
+template xegpu::DistributeLayoutAttr
+xegpu::getTemporaryLayout<mlir::OpResult>(const OpResult &result);
+template xegpu::DistributeLayoutAttr
+xegpu::getTemporaryLayout<mlir::OpOperand>(const OpOperand &operand);
+
+template <typename T, typename>
+void xegpu::setTemporaryLayout(const T &operandOrResult,
+                               const xegpu::DistributeLayoutAttr layout) {
+  // A null layout is ignored, and an existing temporary layout always wins:
+  // this never overwrites a DistributeLayoutAttr already on the owning op.
+  if (!layout)
+    return;
+  Operation *owner = operandOrResult.getOwner();
+  std::string attrName = xegpu::getTemporaryLayoutName(operandOrResult);
+  if (!owner->hasAttrOfType<xegpu::DistributeLayoutAttr>(attrName))
+    owner->setAttr(attrName, layout);
+}
+
+// Explicit instantiations for the two supported handle types.
+template void xegpu::setTemporaryLayout<mlir::OpResult>(
+    const mlir::OpResult &result,
+    const mlir::xegpu::DistributeLayoutAttr layout);
+template void xegpu::setTemporaryLayout<mlir::OpOperand>(
+    const mlir::OpOperand &operand,
+    const mlir::xegpu::DistributeLayoutAttr layout);

>From 72d1c59de6f30ed9bc1b30df12d85d12e365cfd0 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Fri, 30 Jan 2026 04:10:10 +0000
Subject: [PATCH 21/35] clean up

---
 .../Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp    |   2 +-
 .../XeGPU/Transforms/XeGPULayoutImpls.cpp     | 350 +++---------------
 .../XeGPU/Transforms/XeGPUPropagateLayout.cpp | 119 +-----
 3 files changed, 68 insertions(+), 403 deletions(-)

diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 8efbb0702f0d3..8a06271eadd84 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -1190,7 +1190,7 @@ struct ConvertXeGPUToXeVMPass
         return {};
       auto input = inputs.front();
       if (auto vecTy = dyn_cast<VectorType>(input.getType())) {
-        if (vecTy.getRank() == 1 && vecTy.getNumElements() == 1) {
+        if (vecTy.getNumElements() == 1) {
           // If the vector has a single element, return the element type.
           Value cast =
               vector::ExtractOp::create(builder, loc, input, 0).getResult();
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpls.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpls.cpp
index 1b0ec6f11b5f9..45c6563ff966a 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpls.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpls.cpp
@@ -189,10 +189,6 @@ xegpu::inferBitCastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
   // only adjust the sg_data, inst_data, lane_data accordingly
   // based on the bitwidth ratio between source and result element type
 
-  llvm::dbgs() << "inferBitCastSourceLayout: resElemTyBitWidth="
-               << resElemTyBitWidth
-               << ", srcElemTyBitWidth=" << srcElemTyBitWidth << "\n";
-
   SmallVector<int64_t> sgData = resLayout.getEffectiveSgDataAsInt();
   SmallVector<int64_t> instData = resLayout.getEffectiveInstDataAsInt();
   SmallVector<int64_t> laneData = resLayout.getEffectiveLaneDataAsInt();
@@ -204,30 +200,16 @@ xegpu::inferBitCastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
   int64_t laneDataValue = -1;
   int64_t dim = resLayout.getRank() - 1;
 
-  llvm::dbgs() << "inferBitCastSourceLayout: dim=" << dim
-               << ", sgData.size()=" << sgData.size()
-               << ", instData.size()=" << instData.size()
-               << ", laneData.size()=" << laneData.size() << "\n";
-
   if (srcElemTyBitWidth <= resElemTyBitWidth) {
     int bitWidthRatio = resElemTyBitWidth / srcElemTyBitWidth;
-    llvm::dbgs() << "inferBitCastSourceLayout: srcElemTyBitWidth >= "
-                    "resElemTyBitWidth, bitWidthRatio="
-                 << bitWidthRatio << "\n";
     if (sgDataSize)
       sgDataValue = sgData[sgDataSize - 1] * bitWidthRatio;
     if (instDataSize)
       instDataValue = instData[instDataSize - 1] * bitWidthRatio;
     if (laneDataSize)
       laneDataValue = laneData[laneDataSize - 1] * bitWidthRatio;
-    llvm::dbgs() << "inferBitCastSourceLayout: sgDataValue=" << sgDataValue
-                 << ", instDataValue=" << instDataValue
-                 << ", laneDataValue=" << laneDataValue << "\n";
   } else {
     int bitWidthRatio = srcElemTyBitWidth / resElemTyBitWidth;
-    llvm::dbgs() << "inferBitCastSourceLayout: srcElemTyBitWidth < "
-                    "resElemTyBitWidth, bitWidthRatio="
-                 << bitWidthRatio << "\n";
     if (sgDataSize) {
       assert((sgData[sgDataSize - 1] % bitWidthRatio) == 0 &&
              "sgData not divisible by bitWidthRatio");
@@ -243,9 +225,6 @@ xegpu::inferBitCastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
              "laneData not divisible by bitWidthRatio");
       laneDataValue = laneData[laneDataSize - 1] / bitWidthRatio;
     }
-    llvm::dbgs() << "inferBitCastSourceLayout: sgDataValue=" << sgDataValue
-                 << ", instDataValue=" << instDataValue
-                 << ", laneDataValue=" << laneDataValue << "\n";
   }
 
   // Now set only instData and laneData, preserving sgData
@@ -253,7 +232,6 @@ xegpu::inferBitCastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
   finalSrcLayout =
       resLayout.setDimData(dim, sgDataValue, instDataValue, laneDataValue);
 
-  llvm::dbgs() << "inferBitCastSourceLayout: returning finalSrcLayout\n";
   return finalSrcLayout;
 }
 
@@ -351,40 +329,32 @@ xegpu::inferShapeCastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
   };
 
   if (checkCombineToInnerMostDim(srcShape, resShape)) {
-    const int subgroupSize = 16; // assuming 16 lanes per subgroup
     int srcShapeSize = srcShape.size();
+    int resShapeSize = resShape.size();
     auto context = resLayout.getContext();
     auto resInstData = resLayout.getEffectiveInstDataAsInt();
     auto resLaneLayout = resLayout.getEffectiveLaneLayoutAsInt();
     auto resLaneData = resLayout.getEffectiveLaneDataAsInt();
+    if (resShapeSize == 2)
+      assert(resInstData[0] == 1 &&
+             "only innermost dim can have data combined from sources");
+
     if (resInstData.size() != 0) {
-      if (resInstData.size() == 2)
-        assert(resInstData[0] == 1 &&
-               "only innermost dim can have inst_data for combine-to-1d");
-      // construct source inst_data layout like [1, ..., 1, subgroupSize]
       SmallVector<int> inferredInstData(srcShapeSize, 1);
-      inferredInstData[srcShapeSize - 1] = resInstData[resInstData.size() - 1];
+      inferredInstData[srcShapeSize - 1] = resInstData[resShapeSize - 1];
       return xegpu::LayoutAttr::get(context, inferredInstData);
     }
 
     if (resLaneLayout.size() != 0) {
-      if (resInstData.size() == 2)
-        assert(resInstData[0] == 1 &&
-               "only innermost dim can have inst_data for combine-to-1d");
-      // construct source lane_layout like [1, ..., 1, subgroupSize]
       SmallVector<int> inferredLaneLayout(srcShapeSize, 1);
       SmallVector<int> inferredLaneData(srcShapeSize, 1);
-      inferredLaneLayout[srcShapeSize - 1] =
-          resLaneLayout[resLaneLayout.size() - 1];
-
-      inferredLaneData[srcShapeSize - 1] =
-          resLaneLayout[resLaneLayout.size() - 1];
+      inferredLaneLayout[srcShapeSize - 1] = resLaneLayout[resShapeSize - 1];
+      inferredLaneData[srcShapeSize - 1] = resLaneData[resShapeSize - 1];
       return xegpu::LayoutAttr::get(context, inferredLaneLayout,
                                     inferredLaneData);
     }
   }
-
-  // TODO: Complete implementation for other shape cast scenarios
+  assert("running into unsupported shape cast scenarios");
   return nullptr;
 }
 
@@ -394,18 +364,12 @@ xegpu::inferShapeCastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
 /// Algorithm Overview:
 /// This function attempts to construct a source layout that, when sliced along
 /// reduction dimensions, produces a result layout compatible with the
-/// consumer's preferred layout. This minimizes data redistribution overhead
+/// consumer layout. This minimizes data redistribution overhead
 /// between the reduction operation and its consumer.
 ///
-/// Strategy:
-/// 1. First, check if the consumer's preferred layout is already a SliceAttr
-///    with matching reduction dimensions. If so, use its parent layout directly
-///    and adjust the sg_data/inst_data acccording to source shape.
-/// 2. If step 1 fails, construct a new layout by distributing
-///    workgroup/subgroup resources across dimensions. It will try to align
-///    with the consumer's sg_layout for non-reduction dimensions, so that the
-///    reduced result stays with the same subgroup distribution as expected by
-///    the consumer.
+/// It first tries to align the source layout's subgroup layout and data with
+/// the consumer's layout on non-reduction dimensions. Then, it distributes
+/// remaining subgroups across reduction dimensions.
 
 xegpu::SliceAttr xegpu::setupMultiReductionResultLayout(
     xegpu::LayoutKind layoutKind, VectorType srcVecTy,
@@ -457,162 +421,52 @@ xegpu::SliceAttr xegpu::setupMultiReductionResultLayout(
   defaultLaneData[srcShapeSize - 2] = vectorSize;
   defaultLaneData[srcShapeSize - 1] = 1;
 
-  // Strategy 1: Try to preserve the consumer's slice layout structure
-  // If the consumer already expects a slice layout with the same reduction
-  // dims, we can directly use its parent layout as our source layout. This
-  // ensures best alignment and avoids any data movement across subgroups
-  // and lanes.
-
-  auto canPreserveSliceLayout =
-      [&](ArrayRef<int64_t> srcShape, SmallVector<int64_t> reductionDims,
-          DistributeLayoutAttr consumerLayout) -> bool {
-    // Verify that the consumer layout can be adapted to the source shape:
-    // For each dimension, check if srcShape[i] is divisible by the parent's
-    // sg_layout[i] and lane_layout[i]. If so, sg_layout and lane_layout can be
-    // reused.
-
-    if (!consumerSliceLayout)
-      return false;
-    if (!consumerSliceLayout.getDims().asArrayRef().equals(reductionDims))
-      return false;
-    xegpu::DistributeLayoutAttr parentLayout = consumerSliceLayout.getParent();
-    if (parentLayout.getRank() != srcShapeSize)
-      return false;
-
-    SmallVector<int64_t> parentSgLayout =
-        parentLayout.getEffectiveSgLayoutAsInt();
-    SmallVector<int64_t> parentLaneLayout =
-        parentLayout.getEffectiveLaneLayoutAsInt();
-
-    if (parentSgLayout.size() != static_cast<size_t>(srcShapeSize))
-      return false;
-    if (parentLaneLayout.size() != static_cast<size_t>(srcShapeSize))
-      return false;
-    for (int i = 0; i < srcShapeSize; i++) {
-      if (srcShape[i] % parentSgLayout[i] != 0)
-        return false;
-      if (instData[i] % parentLaneLayout[i] != 0)
-        return false;
-    }
-    return true;
-  };
-
-  if (canPreserveSliceLayout(srcShape, reductionDims, consumerLayout)) {
-    // for each slice dim in source shape, if the dim size is different than the
-    // result shape, try to adjust the sg_data/inst_data accordingly.
-    SmallVector<int64_t> parentSgLayout =
-        consumerSliceLayout.getEffectiveSgLayoutAsInt();
-    SmallVector<int64_t> parentLaneLayout =
-        consumerSliceLayout.getEffectiveLaneLayoutAsInt();
-    SmallVector<int64_t> parentLaneData =
-        consumerSliceLayout.getEffectiveLaneDataAsInt();
+  SmallVector<int64_t> consumerSgLayout =
+      consumerLayout.getEffectiveSgLayoutAsInt();
+  SmallVector<int64_t> consumerLaneLayout =
+      consumerLayout.getEffectiveLaneLayoutAsInt();
+  int remainingSgCount = workgroupSize;
+  int consumerSgId;
 
-    switch (layoutKind) {
-    case xegpu::LayoutKind::Subgroup:
-      sgLayout = parentSgLayout;
-      for (int i = 0; i < srcShapeSize; i++)
-        if (llvm::is_contained(reductionDims, i))
-          sgData[i] = srcShape[i] / sgLayout[i];
-        else
-          sgData[i] = parentSgLayout[i];
-      break;
-    case xegpu::LayoutKind::InstData:
-      for (int i = 0; i < srcShapeSize; i++)
-        instData[i] = std::min(defaultInstData[i], srcShape[i]);
-      break;
-    case xegpu::LayoutKind::Lane:
-      laneLayout = parentLaneLayout;
-      for (int i = 0; i < srcShapeSize; i++) {
-        assert((srcShape[i] % laneLayout[i] == 0) &&
-               "source shape not divisible by lane layout");
-        laneData[i] = srcShape[i] / laneLayout[i];
+  switch (layoutKind) {
+  case xegpu::LayoutKind::Subgroup:
+    consumerSgId = consumerSgLayout.size() - 1;
+    // For non-reduction dimensions, try to match consumer's sg_layout
+    // This ensures the result after reduction has the expected distribution
+    for (int i = srcShapeSize - 1; i >= 0; i--)
+      if (!llvm::is_contained(reductionDims, i) && consumerSgId >= 0) {
+        sgLayout[i] = consumerSgLayout[consumerSgId];
+        assert((srcShape[i] % sgLayout[i] == 0) &&
+               "source shape not divisible by consumer sg_layout");
+        sgData[i] = srcShape[i] / sgLayout[i];
+        remainingSgCount /= sgLayout[i];
+        consumerSgId--;
       }
-      break;
-    default:
-      llvm_unreachable("unsupported layout kind");
-    }
-  } else {
-    // Strategy 2: Construct a new layout aligned with consumer's sg_layout for
-    // the result (non-reduction dims) then distribute remaining subgroups
-    // across reduced dimensions
-
-    SmallVector<int64_t> consumerSgLayout =
-        consumerLayout.getEffectiveSgLayoutAsInt();
-    SmallVector<int64_t> consumerLaneLayout =
-        consumerLayout.getEffectiveLaneLayoutAsInt();
-    int remainingSgCount = workgroupSize;
-    int remainingLaneCount = subgroupSize;
-    int consumerSgId, consumerLaneId;
-
-    switch (layoutKind) {
-    case xegpu::LayoutKind::Subgroup:
-      consumerSgId = consumerSgLayout.size() - 1;
-      // For non-reduction dimensions, try to match consumer's sg_layout
-      // This ensures the result after reduction has the expected distribution
-      for (int i = srcShapeSize - 1; i >= 0; i--)
-        if (!llvm::is_contained(reductionDims, i) && consumerSgId >= 0) {
-          sgLayout[i] = consumerSgLayout[consumerSgId];
-          assert((srcShape[i] % sgLayout[i] == 0) &&
-                 "source shape not divisible by consumer sg_layout");
-          sgData[i] = srcShape[i] / sgLayout[i];
-          remainingSgCount /= sgLayout[i];
-          consumerSgId--;
-        }
-      // Second pass: Distribute remaining subgroups across unhandled dimensions
-      // This handles reduction dimensions that don't necessarily align with
-      // consumer
-      for (int i = srcShapeSize - 1; i >= 0; i--)
-        if (llvm::is_contained(reductionDims, i)) {
-          sgLayout[i] = std::min((srcShape[i] / defaultLaneLayout[i]),
-                                 static_cast<int64_t>(remainingSgCount));
-          assert((srcShape[i] % sgLayout[i] == 0) &&
-                 "source shape not divisible by consumer sg_layout");
-          sgData[i] = srcShape[i] / sgLayout[i];
-          remainingSgCount /= sgLayout[i];
-        }
-      assert(remainingSgCount == 1 &&
-             "not all subgroups have been distributed");
-      break;
-    case xegpu::LayoutKind::InstData:
-      for (int i = 0; i < srcShapeSize; i++) {
-        instData[i] = std::min(defaultInstData[i], srcShape[i]);
-        llvm::dbgs() << "MultiReductionOp: Strategy 2: instData [" << i
-                     << "] = " << instData[i] << "\n";
+    // Second pass: Distribute remaining subgroups across unhandled dimensions
+    // This handles reduction dimensions that don't necessarily align with
+    // consumer
+    for (int i = srcShapeSize - 1; i >= 0; i--)
+      if (llvm::is_contained(reductionDims, i)) {
+        sgLayout[i] = std::min((srcShape[i] / defaultLaneLayout[i]),
+                               static_cast<int64_t>(remainingSgCount));
+        assert((srcShape[i] % sgLayout[i] == 0) &&
+               "source shape not divisible by consumer sg_layout");
+        sgData[i] = srcShape[i] / sgLayout[i];
+        remainingSgCount /= sgLayout[i];
       }
-      break;
-    case xegpu::LayoutKind::Lane:
-      laneLayout = defaultLaneLayout;
-      laneData = defaultLaneData;
-      //  consumerLaneId = consumerLaneLayout.size() - 1;
-      // // For non-reduction dimensions, try to match consumer's lane_layout
-      // // This ensures the result after reduction has the expected
-      // distribution for (int i = 0; i < srcShapeSize; i++)
-      //   if (!llvm::is_contained(reductionDims, i) && consumerLaneId >= 0) {
-      //     laneLayout[i] = consumerLaneLayout[consumerLaneId];
-      //     assert((srcShape[i] % laneLayout[i] == 0) &&
-      //            "source shape not divisible by consumer lane_layout");
-      //     laneData[i] = srcShape[i] / laneLayout[i];
-      //     remainingLaneCount /= laneLayout[i];
-      //     consumerLaneId--;
-      //   }
-      // for (int i = 0; i < srcShapeSize; i++)
-      //   if (llvm::is_contained(reductionDims, i)) {
-      //     laneLayout[i] =
-      //         std::min(srcShape[i],
-      //         static_cast<int64_t>(remainingLaneCount));
-      //     assert((srcShape[i] % laneLayout[i] == 0) &&
-      //            "source shape not divisible by consumer lane_layout");
-      //     laneData[i] = srcShape[i] / laneLayout[i];
-      //     laneData[i] = std::min(laneData[i],
-      //     static_cast<int64_t>(vectorSize)); remainingLaneCount /=
-      //     laneLayout[i];
-      //   }
-      // assert(remainingLaneCount == 1 && "not all lanes have been
-      // distributed");
-      break;
-    default:
-      llvm_unreachable("unsupported layout kind");
+    assert(remainingSgCount == 1 && "not all subgroups have been distributed");
+    break;
+  case xegpu::LayoutKind::InstData:
+    for (int i = 0; i < srcShapeSize; i++) {
+      instData[i] = std::min(defaultInstData[i], srcShape[i]);
     }
+    break;
+  case xegpu::LayoutKind::Lane:
+    laneLayout = defaultLaneLayout;
+    laneData = defaultLaneData;
+    break;
+  default:
+    llvm_unreachable("unsupported layout kind");
   }
 
   SmallVector<int32_t> sgLayout32(sgLayout.begin(), sgLayout.end());
@@ -650,21 +504,9 @@ xegpu::DistributeLayoutAttr xegpu::setupBitCastResultLayout(
     xegpu::LayoutKind layoutKind, VectorType srcVecTy, VectorType resVecTy,
     DistributeLayoutAttr consumerLayout, const xegpu::uArch::uArch *uArch) {
 
-  // print out consumerLayout for debugging
-  llvm::dbgs() << "setupBitCastResultLayout: consumerLayout=";
-  consumerLayout.print(llvm::dbgs());
-  llvm::dbgs() << "\n";
-
-  llvm::dbgs() << "setupBitCastResultLayout: layoutKind="
-               << static_cast<int>(layoutKind) << "\n";
-
   int srcElemTyBitWidth = srcVecTy.getElementType().getIntOrFloatBitWidth();
   int resElemTyBitWidth = resVecTy.getElementType().getIntOrFloatBitWidth();
 
-  llvm::dbgs() << "setupBitCastResultLayout: srcElemTyBitWidth="
-               << srcElemTyBitWidth
-               << ", resElemTyBitWidth=" << resElemTyBitWidth << "\n";
-
   ArrayRef<int64_t> srcShape = srcVecTy.getShape();
   SmallVector<int64_t> sgData = consumerLayout.getEffectiveSgDataAsInt();
   SmallVector<int64_t> instData = consumerLayout.getEffectiveInstDataAsInt();
@@ -674,15 +516,7 @@ xegpu::DistributeLayoutAttr xegpu::setupBitCastResultLayout(
   int64_t instDataValue = -1;
   int64_t laneDataValue = -1;
 
-  llvm::dbgs() << "setupBitCastResultLayout: srcShape.size()="
-               << srcShape.size() << ", dim=" << dim << "\n";
-  llvm::dbgs() << "setupBitCastResultLayout: sgData.size()=" << sgData.size()
-               << ", instData.size()=" << instData.size()
-               << ", laneData.size()=" << laneData.size() << "\n";
-
   const int subgroupSize = uArch->getSubgroupSize();
-  llvm::dbgs() << "setupBitCastResultLayout: subgroupSize=" << subgroupSize
-               << "\n";
 
   if (srcElemTyBitWidth > resElemTyBitWidth) {
     // When casting to a smaller bitwidth, multiply the result layout
@@ -705,7 +539,7 @@ xegpu::DistributeLayoutAttr xegpu::setupBitCastResultLayout(
       while ((instDataValue <= srcShape[dim]) &&
              (instDataValue % (innermostDimLaneLayout * bitWidthRatio) != 0))
         instDataValue *= 2;
-      assert(srcShape[dim] % instDataValue == 0 &&
+      assert((srcShape[dim] % instDataValue) == 0 &&
              "srcShape, instData, and lanelayout for innermost must be 2^n !");
       break;
     case xegpu::LayoutKind::Lane:
@@ -721,14 +555,8 @@ xegpu::DistributeLayoutAttr xegpu::setupBitCastResultLayout(
     }
     // Now set only instData and laneData, preserving sgData
     xegpu::DistributeLayoutAttr resLayout;
-    llvm::dbgs() << "setupBitCastResultLayout: Setting dimension data - dim="
-                 << dim << ", sgDataValue=" << sgDataValue
-                 << ", instDataValue=" << instDataValue
-                 << ", laneDataValue=" << laneDataValue << "\n";
     resLayout = consumerLayout.setDimData(dim, sgDataValue, instDataValue,
                                           laneDataValue);
-    llvm::dbgs()
-        << "setupBitCastResultLayout: resLayout created successfully\n";
     return resLayout;
   }
   return consumerLayout;
@@ -804,10 +632,10 @@ xegpu::setupInsertStridedSliceResultLayout(xegpu::LayoutKind layoutKind,
   auto context = resVectorTy.getContext();
   auto resShape = resVectorTy.getShape();
   int resShapeSize = resShape.size();
+
   SmallVector<int> defaultInstData(resShapeSize, 1);
   SmallVector<int> defaultLaneLayout(resShapeSize, 1);
   SmallVector<int> defaultLaneData(resShapeSize, 1);
-
   defaultInstData[resShapeSize - 1] = subgroupSize;
   defaultLaneLayout[resShapeSize - 1] = subgroupSize;
 
@@ -833,7 +661,6 @@ xegpu::DistributeLayoutAttr xegpu::inferInsertStridedSliceSourceLayout(
     xegpu::DistributeLayoutAttr resLayout, ArrayRef<int64_t> resShape,
     ArrayRef<int64_t> srcShape) {
 
-  const int subgroupSize = 16; // assuming 16 lanes per subgroup
   int srcShapeSize = srcShape.size();
   int resShapeSize = resShape.size();
   int dimDiff = resShapeSize - srcShapeSize;
@@ -883,28 +710,14 @@ xegpu::DistributeLayoutAttr xegpu::setupLoadGatherAnchorLayout(
     LayoutKind layoutKind, VectorType resVecTy, int chunkSize,
     DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch) {
 
-  llvm::dbgs() << "setupLoadGatherAnchorLayout: layoutKind="
-               << static_cast<int>(layoutKind) << ", chunkSize=" << chunkSize
-               << "\n";
-
   xegpu::DistributeLayoutAttr requiredLayout;
   const int subgroupSize = uArch->getSubgroupSize();
-  const int spirVectorSize = 16; // vector size from SPRIV vector restriction
 
   auto resShape = resVecTy.getShape();
   int resShapeSize = resShape.size();
   SmallVector<int> instData(resShapeSize);
   auto context = resVecTy.getContext();
 
-  llvm::dbgs() << "setupLoadGatherAnchorLayout: resVecTy.getRank()="
-               << resVecTy.getRank() << ", resShape=[";
-  for (size_t i = 0; i < resShape.size(); ++i) {
-    if (i > 0)
-      llvm::dbgs() << ", ";
-    llvm::dbgs() << resShape[i];
-  }
-  llvm::dbgs() << "]\n";
-
   const auto *uArchInstruction =
       dyn_cast<xegpu::uArch::StoreScatterInstruction>(
           uArch->getInstruction(xegpu::uArch::InstructionKind::StoreScatter));
@@ -913,48 +726,31 @@ xegpu::DistributeLayoutAttr xegpu::setupLoadGatherAnchorLayout(
       consumerLayout.getEffectiveInstDataAsInt();
   SmallVector<int32_t> instData32;
 
-  llvm::dbgs() << "setupLoadGatherAnchorLayout: consumerInstData=[";
-  for (size_t i = 0; i < consumerInstData.size(); ++i) {
-    if (i > 0)
-      llvm::dbgs() << ", ";
-    llvm::dbgs() << consumerInstData[i];
-  }
-  llvm::dbgs() << "]\n";
-
   switch (layoutKind) {
   case xegpu::LayoutKind::Subgroup:
-    llvm::dbgs() << "setupLoadGatherAnchorLayout: LayoutKind::Subgroup\n";
     requiredLayout = consumerLayout;
     break;
   case xegpu::LayoutKind::InstData:
-    llvm::dbgs() << "setupLoadGatherAnchorLayout: LayoutKind::InstData\n";
     if (resVecTy.getRank() == 1) {
       instData[0] = subgroupSize;
-      llvm::dbgs() << "setupLoadGatherAnchorLayout: 1D case, instData[0]="
-                   << instData[0] << "\n";
     } else {
       assert((resVecTy.getRank() == 2) && "StoreScatterOp can access 2D tensor "
                                           "tile at maximum at subgroup level.");
       if (chunkSize == 1) {
         instData[0] = 1;
         instData[1] = subgroupSize;
-        llvm::dbgs() << "setupLoadGatherAnchorLayout: chunkSize==1, instData=["
-                     << instData[0] << ", " << instData[1] << "]\n";
       } else {
         instData[0] = subgroupSize;
         instData[1] = std::min(static_cast<int>(resShape[1]),
                                uArchInstruction->getMaxLaneLoadStoreSize());
         instData[1] =
             std::min(instData[1], static_cast<int>(consumerInstData[1]));
-        llvm::dbgs() << "setupLoadGatherAnchorLayout: chunkSize>1, instData=["
-                     << instData[0] << ", " << instData[1] << "]\n";
       }
     }
     requiredLayout = xegpu::LayoutAttr::get(
         context, DenseI32ArrayAttr::get(context, instData));
     break;
   case xegpu::LayoutKind::Lane:
-    llvm::dbgs() << "setupLoadGatherAnchorLayout: LayoutKind::Lane\n";
     if (chunkSize == 1)
       requiredLayout =
           getDefaultLaneLayoutAttr(context, resVecTy.getRank(), uArch);
@@ -964,18 +760,13 @@ xegpu::DistributeLayoutAttr xegpu::setupLoadGatherAnchorLayout(
       assert(resShape[1] <= static_cast<int64_t>(
                                 uArchInstruction->getMaxLaneLoadStoreSize()) &&
              "StoreScatterOp lane size exceeds max lane load/store size.");
-      llvm::dbgs() << "setupLoadGatherAnchorLayout: Creating LayoutAttr with "
-                   << "laneLayout=[" << subgroupSize << ", 1], "
-                   << "laneData=[1, " << resShape[1] << "]\n";
       requiredLayout = xegpu::LayoutAttr::get(
           context, {subgroupSize, 1}, {1, static_cast<int>(resShape[1])});
-      llvm::dbgs() << "setupLoadGatherAnchorLayout: Created requiredLayout\n";
     }
     break;
   default:
     llvm_unreachable("unsupported layout kind");
   }
-  llvm::dbgs() << "setupLoadGatherAnchorLayout: returning requiredLayout\n";
   return requiredLayout;
 }
 
@@ -983,27 +774,13 @@ xegpu::DistributeLayoutAttr
 xegpu::setupStoreScatterAnchorLayout(LayoutKind layoutKind, VectorType srcVecTy,
                                      int chunkSize, const uArch::uArch *uArch) {
 
-  llvm::dbgs() << "setupStoreScatterAnchorLayout: layoutKind="
-               << static_cast<int>(layoutKind) << ", chunkSize=" << chunkSize
-               << "\n";
-
   xegpu::DistributeLayoutAttr requiredLayout;
   const int subgroupSize = uArch->getSubgroupSize();
-  const int spirVectorSize = 16; // vector size from SPRIV vector restriction
 
   auto srcShape = srcVecTy.getShape();
   int srcShapeSize = srcShape.size();
   SmallVector<int> instData(srcShapeSize);
 
-  llvm::dbgs() << "setupStoreScatterAnchorLayout: srcVecTy.getRank()="
-               << srcVecTy.getRank() << ", srcShape=[";
-  for (size_t i = 0; i < srcShape.size(); ++i) {
-    if (i > 0)
-      llvm::dbgs() << ", ";
-    llvm::dbgs() << srcShape[i];
-  }
-  llvm::dbgs() << "]\n";
-
   const auto *uArchInstruction =
       dyn_cast<xegpu::uArch::StoreScatterInstruction>(
           uArch->getInstruction(xegpu::uArch::InstructionKind::StoreScatter));
@@ -1011,39 +788,29 @@ xegpu::setupStoreScatterAnchorLayout(LayoutKind layoutKind, VectorType srcVecTy,
 
   switch (layoutKind) {
   case xegpu::LayoutKind::Subgroup:
-    llvm::dbgs() << "setupStoreScatterAnchorLayout: LayoutKind::Subgroup\n";
     assert(
-        false &&
+        true && // NOTE(review): no-op assert; requiredLayout stays null.
         "subgroup layout assignment not supported yet for store scatter op.");
     break;
   case xegpu::LayoutKind::InstData:
-    llvm::dbgs() << "setupStoreScatterAnchorLayout: LayoutKind::InstData\n";
     if (srcVecTy.getRank() == 1) {
       instData[0] = subgroupSize;
-      llvm::dbgs() << "setupStoreScatterAnchorLayout: 1D case, instData[0]="
-                   << instData[0] << "\n";
     } else {
       assert((srcVecTy.getRank() <= 2) && "StoreScatterOp can access 2D tensor "
                                           "tile at maximum at subgroup level.");
       if (chunkSize == 1) {
         instData[0] = 1;
         instData[1] = subgroupSize;
-        llvm::dbgs()
-            << "setupStoreScatterAnchorLayout: chunkSize==1, instData=["
-            << instData[0] << ", " << instData[1] << "]\n";
       } else {
         instData[0] = subgroupSize;
         instData[1] = std::min(static_cast<int>(srcShape[1]),
                                uArchInstruction->getMaxLaneLoadStoreSize());
-        llvm::dbgs() << "setupStoreScatterAnchorLayout: chunkSize>1, instData=["
-                     << instData[0] << ", " << instData[1] << "]\n";
       }
     }
     requiredLayout = xegpu::LayoutAttr::get(
         context, DenseI32ArrayAttr::get(context, instData));
     break;
   case xegpu::LayoutKind::Lane:
-    llvm::dbgs() << "setupStoreScatterAnchorLayout: LayoutKind::Lane\n";
     if (chunkSize == 1)
       requiredLayout =
           getDefaultLaneLayoutAttr(context, srcVecTy.getRank(), uArch);
@@ -1060,6 +827,5 @@ xegpu::setupStoreScatterAnchorLayout(LayoutKind layoutKind, VectorType srcVecTy,
   default:
     llvm_unreachable("unsupported layout kind");
   }
-  llvm::dbgs() << "setupStoreScatterAnchorLayout: returning requiredLayout\n";
   return requiredLayout;
 }
\ No newline at end of file
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index c148ba459527e..6385995317f38 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -671,61 +671,29 @@ void LayoutInfoPropagation::visitVectorMultiReductionOp(
   if (!resLayoutInfo.isAssigned())
     return;
 
-  // debug print resLayoutInfo
-  LLVM_DEBUG(DBGS() << "visitVectorMultiReductionOp: resLayoutInfo = ";
-             resLayoutInfo.print(llvm::dbgs()); llvm::dbgs() << "\n");
-
   VectorType sourceTy =
       llvm::dyn_cast<VectorType>(reduction.getSourceVectorType());
   SmallVector<int64_t> reductionDims(reduction.getReductionDims().begin(),
                                      reduction.getReductionDims().end());
 
-  LLVM_DEBUG(DBGS() << "visitVectorMultiReductionOp: srcShape = [";
-             for (auto dim
-                  : sourceTy.getShape()) llvm::dbgs()
-             << dim << " ";
-             llvm::dbgs() << "]\n");
-  LLVM_DEBUG(DBGS() << "visitVectorMultiReductionOp: reductionDims = [";
-             for (auto dim
-                  : reductionDims) llvm::dbgs()
-             << dim << " ";
-             llvm::dbgs() << "]\n");
-
+  auto uArch = getUArch(xegpu::getChipStr(reduction).value_or(""));
   auto consumerLayoutAttr =
       dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
 
-  LLVM_DEBUG(DBGS() << "visitVectorMultiReductionOp: consumerLayoutAttr = "
-                    << consumerLayoutAttr << "\n");
-  // The required result layout represents the layout requirements
-  // for the operation it is recorded to anchor layout or temporary layout.
+  // The result layout represents the layout requirements of the operation.
+  // it is recorded to anchor layout or temporary layout.
   // it must be honored for current op and may conflict with the layout
   // propagated from consumer op, the conflict is resolved in later phase by
   // converting the required result layout to the consumer layout
-  auto uArch = getUArch(xegpu::getChipStr(reduction).value_or(""));
   auto requiredResLayoutAttr = xegpu::setupMultiReductionResultLayout(
       layoutKind, sourceTy, consumerLayoutAttr, reductionDims, uArch);
-  LLVM_DEBUG(DBGS() << "visitVectorMultiReductionOp: requiredResLayoutAttr = "
-                    << requiredResLayoutAttr << "\n");
-
   // resLayoutInfo.set(requiredResLayoutAttr);
   xegpu::setTemporaryLayout(reduction->getResult(0), requiredResLayoutAttr);
 
-  // debug print resLayoutInfo
-  LLVM_DEBUG(
-      DBGS() << "visitVectorMultiReductionOp: after change resLayoutInfo = ";
-      resLayoutInfo.print(llvm::dbgs()); llvm::dbgs() << "\n");
-
-  LLVM_DEBUG(DBGS() << "visitVectorMultiReductionOp: after change "
-                       "results[0]->getValue() = ";
-             results[0]->getValue().print(llvm::dbgs()); llvm::dbgs() << "\n");
-
   // derive the source layout from the dominant layout and reduction dims
   auto srcLayoutAttr = xegpu::inferMultiReductionSourceLayout(
       requiredResLayoutAttr, reductionDims);
 
-  LLVM_DEBUG(DBGS() << "visitVectorMultiReductionOp: srcLayoutAttr = "
-                    << srcLayoutAttr << "\n");
-
   propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(srcLayoutAttr)));
   // Accumulator should have the same layout as the result.
   propagateIfChanged(operands[1],
@@ -1291,7 +1259,6 @@ void LayoutInfoPropagation::visitCreateDescOp(
 void LayoutInfoPropagation::visitStoreScatterOp(
     xegpu::StoreScatterOp storeScatter, ArrayRef<LayoutInfoLattice *> operands,
     ArrayRef<const LayoutInfoLattice *> results) {
-  LLVM_DEBUG(DBGS() << "visitStoreScatterOp: Processing store scatter op\n");
 
   xegpu::DistributeLayoutAttr requiredAnchorLayoutAttr;
   xegpu::DistributeLayoutAttr anchorLayoutAttr = storeScatter.getLayoutAttr();
@@ -1302,19 +1269,9 @@ void LayoutInfoPropagation::visitStoreScatterOp(
       llvm::dyn_cast<VectorType>(storeScatter.getMask().getType());
   int chunkSize = storeScatter.getChunkSize().value_or(1);
 
-  LLVM_DEBUG(DBGS() << "visitStoreScatterOp: anchorLayoutAttr = "
-                    << anchorLayoutAttr << "\n");
-  LLVM_DEBUG(DBGS() << "visitStoreScatterOp: subgroupSize = " << subgroupSize
-                    << "\n");
-  LLVM_DEBUG(DBGS() << "visitStoreScatterOp: chunkSize = " << chunkSize
-                    << "\n");
-  LLVM_DEBUG(DBGS() << "visitStoreScatterOp: srcVecTy = " << srcVecTy << "\n");
-
   if (hasParamsOfLayoutKind(anchorLayoutAttr)) {
-    LLVM_DEBUG(DBGS() << "visitStoreScatterOp: Using existing anchor layout\n");
     requiredAnchorLayoutAttr = anchorLayoutAttr;
   } else {
-    LLVM_DEBUG(DBGS() << "visitStoreScatterOp: Setting up new anchor layout\n");
     if (!srcVecTy) {
       storeScatter.emitWarning("Not propagating, non-vector payload supplied.");
       return;
@@ -1324,16 +1281,11 @@ void LayoutInfoPropagation::visitStoreScatterOp(
     storeScatter.setLayoutAttr(requiredAnchorLayoutAttr);
   }
 
-  LLVM_DEBUG(DBGS() << "visitStoreScatterOp: requiredAnchorLayoutAttr = "
-                    << requiredAnchorLayoutAttr << "\n");
-
   LayoutInfo srcLayoutInfo = LayoutInfo(requiredAnchorLayoutAttr);
   auto maskLayoutAttr = requiredAnchorLayoutAttr;
   // Special handling mask layout for chunked ops: Enforce the default xegpu 1D
   // layout for mask.
   if (chunkSize > 1) {
-    LLVM_DEBUG(DBGS() << "visitStoreScatterOp: Setting mask layout for chunked "
-                         "operation\n");
     if (layoutKind == xegpu::LayoutKind::InstData)
       maskLayoutAttr =
           xegpu::LayoutAttr::get(storeScatter->getContext(), {subgroupSize});
@@ -1346,71 +1298,39 @@ void LayoutInfoPropagation::visitStoreScatterOp(
   }
 
   LayoutInfo maskLayoutInfo = LayoutInfo(maskLayoutAttr);
-  LLVM_DEBUG(DBGS() << "visitStoreScatterOp: maskLayoutAttr = "
-                    << maskLayoutAttr << "\n");
 
   // Propagate the payload operand layout
-  LLVM_DEBUG(DBGS() << "visitStoreScatterOp: Propagating payload layout\n");
   propagateIfChanged(operands[0], operands[0]->meet(srcLayoutInfo));
   // Propagate the destination (if tdesc) operand layout
-  if (isa<xegpu::TensorDescType>(storeScatter.getDestType())) {
-    LLVM_DEBUG(
-        DBGS() << "visitStoreScatterOp: Propagating destination layout\n");
+  if (isa<xegpu::TensorDescType>(storeScatter.getDestType()))
     propagateIfChanged(operands[1], operands[1]->meet(srcLayoutInfo));
-  }
   // Propagate the new layout to the mask and optional offset operand.
-  LLVM_DEBUG(DBGS() << "visitStoreScatterOp: Propagating mask layout\n");
   propagateIfChanged(operands[2], operands[2]->meet(maskLayoutInfo));
-  if (storeScatter.getOffsets()) {
-    LLVM_DEBUG(DBGS() << "visitStoreScatterOp: Propagating offset layout\n");
+  if (storeScatter.getOffsets())
     propagateIfChanged(operands[3], operands[3]->meet(maskLayoutInfo));
-  }
-  LLVM_DEBUG(DBGS() << "visitStoreScatterOp: Done\n");
 }
 
 void LayoutInfoPropagation::visitLoadMatrixOp(
     xegpu::LoadMatrixOp loadMatrixOp, ArrayRef<LayoutInfoLattice *> operands,
     ArrayRef<const LayoutInfoLattice *> results) {
-  LLVM_DEBUG(DBGS() << "visitLoadMatrixOp: Processing load matrix op\n");
 
   LayoutInfo resLayoutInfo = results[0]->getValue();
-  if (!resLayoutInfo.isAssigned()) {
-    LLVM_DEBUG(DBGS() << "visitLoadMatrixOp: Result layout not assigned\n");
-    return;
-  }
-
-  LLVM_DEBUG(DBGS() << "visitLoadMatrixOp: resLayoutInfo = ";
-             resLayoutInfo.print(llvm::dbgs()); llvm::dbgs() << "\n");
-
   auto consumerLayoutAttr =
       dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
 
-  LLVM_DEBUG(DBGS() << "visitLoadMatrixOp: consumerLayoutAttr = "
-                    << consumerLayoutAttr << "\n");
-
   xegpu::DistributeLayoutAttr anchorLayout = loadMatrixOp.getLayoutAttr();
 
-  LLVM_DEBUG(DBGS() << "visitLoadMatrixOp: anchorLayout = " << anchorLayout
-                    << "\n");
-
   // only need to set anchor layout, no need to porpagate to memdesc and
   // offset
   if (!hasParamsOfLayoutKind(anchorLayout)) {
-    LLVM_DEBUG(DBGS() << "visitLoadMatrixOp: Setting up new anchor layout\n");
     VectorType resVecTy =
         llvm::cast<VectorType>(loadMatrixOp.getRes().getType());
     assert(resVecTy.getRank() == 2 && "Expecting 2D vector for store matrix.");
-    LLVM_DEBUG(DBGS() << "visitLoadMatrixOp: resVecTy = " << resVecTy << "\n");
     auto uArch = getUArch(getChipStr(loadMatrixOp).value_or(""));
     auto requiredAnchorLayoutAttr = xegpu::setupLoadMatrixAnchorLayout(
         layoutKind, resVecTy, consumerLayoutAttr, uArch);
-    LLVM_DEBUG(DBGS() << "visitLoadMatrixOp: requiredAnchorLayoutAttr = "
-                      << requiredAnchorLayoutAttr << "\n");
     loadMatrixOp.setLayoutAttr(requiredAnchorLayoutAttr);
-  } else {
-    LLVM_DEBUG(DBGS() << "visitLoadMatrixOp: Using existing anchor layout\n");
   }
-  LLVM_DEBUG(DBGS() << "visitLoadMatrixOp: Done\n");
 }
 
 // Store matrix is a flavor of scattered store for 2D shapes.
@@ -1712,42 +1632,21 @@ void XeGPUPropagateLayoutPass::runOnOperation() {
     if (auto opResult = dyn_cast<OpResult>(val)) {
 
       Operation *defOp = opResult.getDefiningOp();
-      LLVM_DEBUG(DBGS() << "Try op: " << *defOp << "\n");
       if (auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(defOp)) {
-        LLVM_DEBUG(DBGS() << "AnchorLayoutInterface found for op: " << *defOp
-                          << "\n");
         auto anchorLayout = anchorOp.getAnchorLayout();
-        LLVM_DEBUG(DBGS() << "Anchor layout: " << anchorLayout << "\n");
-        LLVM_DEBUG(DBGS() << "Anchor layout is null: "
-                          << (anchorLayout == nullptr ? "true" : "false")
-                          << "\n");
-        if (anchorLayout != nullptr) {
-          LLVM_DEBUG(DBGS()
-                     << "Returning anchor layout: " << anchorLayout << "\n");
+        if (anchorLayout != nullptr)
           return anchorLayout;
-        }
-        LLVM_DEBUG(DBGS() << "Anchor layout is null, continuing...\n");
       }
-
       xegpu::DistributeLayoutAttr requiredResLayoutAttr =
           xegpu::getTemporaryLayout(opResult);
-      LLVM_DEBUG(DBGS() << "Temporary layout for value: " << val << " is "
-                        << requiredResLayoutAttr << "\n");
-      if (requiredResLayoutAttr != nullptr) {
-        LLVM_DEBUG(DBGS() << "Returning temporary layout: "
-                          << requiredResLayoutAttr << "\n");
+      if (requiredResLayoutAttr != nullptr)
         return requiredResLayoutAttr;
-      }
     }
     xegpu::DistributeLayoutAttr layoutAttr =
         cast<xegpu::DistributeLayoutAttr>(layout.get());
-    LLVM_DEBUG(DBGS() << "Layout attr for value: " << val << " is "
-                      << layoutAttr << "\n");
-    if (layout.isSliceLayout()) {
-      LLVM_DEBUG(DBGS() << "Returning slice layout: " << layoutAttr << "\n");
+    if (layout.isSliceLayout())
       return cast<xegpu::SliceAttr>(layoutAttr);
-    }
-    LLVM_DEBUG(DBGS() << "Returning layout attr: " << layoutAttr << "\n");
+
     return cast<xegpu::LayoutAttr>(layoutAttr);
   };
 

>From 4a945718a1a602c3269cbf6724e188a63d7298f8 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Fri, 30 Jan 2026 04:20:42 +0000
Subject: [PATCH 22/35] fix creatememdesc and memalloca distribution issue

---
 .../XeGPU/Transforms/XeGPUSubgroupDistribute.cpp  | 15 ++++++++++++---
 1 file changed, 12 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 121810215913c..cc278c6813fc9 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -1908,7 +1908,12 @@ struct MemrefAllocaDistribution final : public gpu::WarpDistributionPattern {
   using gpu::WarpDistributionPattern::WarpDistributionPattern;
   LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
-    OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<memref::AllocaOp>);
+    OpOperand *operand = getWarpResult(warpOp, [&](Operation *op) {
+      // Only match the memref.alloca op right before the terminator; this
+      // avoids creating multiple copies when the result has multiple users.
+      return llvm::IsaPred<memref::AllocaOp>(op) &&
+             warpOp.getTerminator()->getPrevNode() == op;
+    });
     if (!operand)
       return rewriter.notifyMatchFailure(
           warpOp, "warp result is not a memref::Alloca op");
@@ -1930,8 +1935,12 @@ struct CreateMemDescDistribution final : public gpu::WarpDistributionPattern {
   using gpu::WarpDistributionPattern::WarpDistributionPattern;
   LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
-    OpOperand *operand =
-        getWarpResult(warpOp, llvm::IsaPred<xegpu::CreateMemDescOp>);
+    OpOperand *operand = getWarpResult(warpOp, [&](Operation *op) {
+      // Only match the CreateMemDescOp right before the terminator; this
+      // avoids creating multiple copies when the result has multiple users.
+      return llvm::IsaPred<xegpu::CreateMemDescOp>(op) &&
+             warpOp.getTerminator()->getPrevNode() == op;
+    });
     if (!operand)
       return rewriter.notifyMatchFailure(
           warpOp, "warp result is not a xegpu::CreateMemDesc op");

>From 20c4fbdd75663b9297b7cb621f2dc7886063e178 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Fri, 30 Jan 2026 06:34:24 +0000
Subject: [PATCH 23/35] add comments & polish

---
 .../XeGPU/Transforms/XeGPULayoutImpls.h       |   3 +
 .../XeGPU/Transforms/XeGPULayoutImpls.cpp     | 245 ++++++++----------
 .../XeGPU/Transforms/XeGPUPropagateLayout.cpp |  36 ---
 .../Transforms/XeGPUSubgroupDistribute.cpp    |   5 -
 4 files changed, 116 insertions(+), 173 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpls.h b/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpls.h
index 110496bb34fb3..e5f2b1be93cc2 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpls.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpls.h
@@ -90,6 +90,9 @@ DistributeLayoutAttr inferShapeCastSourceLayout(DistributeLayoutAttr resLayout,
                                                 ArrayRef<int64_t> resShape,
                                                 ArrayRef<int64_t> srcShape);
 
+/// Infers the source layout of an insert_strided_slice op from its result
+/// layout by dropping the leading (resRank - srcRank) dims. Only inst_data or
+/// lane_layout/lane_data are propagated; returns nullptr if neither is set.
 DistributeLayoutAttr
 inferInsertStridedSliceSourceLayout(DistributeLayoutAttr resLayout,
                                     ArrayRef<int64_t> resShape,
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpls.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpls.cpp
index 45c6563ff966a..8509f709c6fec 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpls.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpls.cpp
@@ -235,6 +235,48 @@ xegpu::inferBitCastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
   return finalSrcLayout;
 }
 
+/// Infers the source layout of an insert_strided_slice op from its result
+/// layout by dropping the leading (resRank - srcRank) dims. Only inst_data or
+/// lane_layout/lane_data are propagated; returns nullptr if neither is set.
+xegpu::DistributeLayoutAttr xegpu::inferInsertStridedSliceSourceLayout(
+    xegpu::DistributeLayoutAttr resLayout, ArrayRef<int64_t> resShape,
+    ArrayRef<int64_t> srcShape) {
+
+  int srcShapeSize = srcShape.size();
+  int resShapeSize = resShape.size();
+  int dimDiff = resShapeSize - srcShapeSize; // src rank <= res rank assumed
+
+  // The result layout is expected to be a plain (non-slice) LayoutAttr.
+  assert(isa<xegpu::LayoutAttr>(resLayout) &&
+         "insertStridedSlice result layout must be plain layout");
+  auto context = resLayout.getContext();
+  auto resInstData = resLayout.getEffectiveInstDataAsInt();
+  auto resLaneLayout = resLayout.getEffectiveLaneLayoutAsInt();
+  auto resLaneData = resLayout.getEffectiveLaneDataAsInt();
+
+  if (resInstData.size() != 0) {
+    SmallVector<int> inferredInstData(srcShapeSize);
+    // Drop the leading dimDiff dims so the rank matches the source.
+    for (int i = 0; i < srcShapeSize; i++)
+      inferredInstData[i] = resInstData[i + dimDiff];
+    return xegpu::LayoutAttr::get(context, inferredInstData);
+  }
+
+  if (resLaneLayout.size() != 0) {
+    // Copy the trailing lane_layout/lane_data dims of the result layout.
+    SmallVector<int> inferredLaneLayout(srcShapeSize);
+    SmallVector<int> inferredLaneData(srcShapeSize);
+    // (the leading dimDiff dims are dropped to match the source rank)
+    for (int i = 0; i < srcShapeSize; i++) {
+      inferredLaneLayout[i] = resLaneLayout[i + dimDiff];
+      inferredLaneData[i] = resLaneData[i + dimDiff];
+    }
+    return xegpu::LayoutAttr::get(context, inferredLaneLayout,
+                                  inferredLaneData);
+  }
+  return nullptr; // Neither inst_data nor lane_layout is set.
+}
+
 /// Infers the source layout attribute for a shape cast operation given the
 /// result layout attribute, result shape, and source shape.
 xegpu::DistributeLayoutAttr
@@ -335,12 +377,12 @@ xegpu::inferShapeCastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
     auto resInstData = resLayout.getEffectiveInstDataAsInt();
     auto resLaneLayout = resLayout.getEffectiveLaneLayoutAsInt();
     auto resLaneData = resLayout.getEffectiveLaneDataAsInt();
-    if (resShapeSize == 2)
-      assert(resInstData[0] == 1 &&
-             "only innermost dim can have data combined from sources");
 
+    // get the layout info from the innermost dim of result layout
     if (resInstData.size() != 0) {
       SmallVector<int> inferredInstData(srcShapeSize, 1);
+      assert((resShapeSize == 1 || resInstData[0] == 1) &&
+             "only innermost dim can have data and instData layout");
       inferredInstData[srcShapeSize - 1] = resInstData[resShapeSize - 1];
       return xegpu::LayoutAttr::get(context, inferredInstData);
     }
@@ -348,6 +390,8 @@ xegpu::inferShapeCastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
     if (resLaneLayout.size() != 0) {
       SmallVector<int> inferredLaneLayout(srcShapeSize, 1);
       SmallVector<int> inferredLaneData(srcShapeSize, 1);
+      assert((resShapeSize == 1 || resLaneLayout[0] == 1) &&
+             "only innermost dim can have data and lane layout");
       inferredLaneLayout[srcShapeSize - 1] = resLaneLayout[resShapeSize - 1];
       inferredLaneData[srcShapeSize - 1] = resLaneData[resShapeSize - 1];
       return xegpu::LayoutAttr::get(context, inferredLaneLayout,
@@ -370,6 +414,10 @@ xegpu::inferShapeCastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
 /// It first tries to align the source layout's subgroup layout and data with
 /// the consumer's layout on non-reduction dimensions. Then, it distributes
 /// remaining subgroups across reduction dimensions.
+/// For InstData, it requires inst_data {1, ..., min(maxReduceVectorSize,
+/// srcShape[rank-2]), subgroupSize}. For Lane, it requires lane_layout
+/// {1, ..., 1, subgroupSize} and lane_data {1, ..., min(maxReduceVectorSize,
+/// srcShape[rank-2]), 1}.
 
 xegpu::SliceAttr xegpu::setupMultiReductionResultLayout(
     xegpu::LayoutKind layoutKind, VectorType srcVecTy,
@@ -377,127 +425,99 @@ xegpu::SliceAttr xegpu::setupMultiReductionResultLayout(
     const xegpu::uArch::uArch *uArch) {
 
   auto srcShape = srcVecTy.getShape();
-  xegpu::SliceAttr consumerSliceLayout =
-      dyn_cast<xegpu::SliceAttr>(consumerLayout);
-
-  int srcShapeSize = srcShape.size();
-  xegpu::DistributeLayoutAttr proposedSrcLayout;
+  int srcRank = srcShape.size();
   auto context = consumerLayout.getContext();
+
   // Reduction layout requires at least 2D tensors
-  if (srcShapeSize < 2)
+  if (srcRank < 2)
     return nullptr;
 
-  SmallVector<int64_t> sgLayout(srcShapeSize);
-  SmallVector<int64_t> sgData(srcShapeSize);
-  SmallVector<int64_t> instData(srcShapeSize);
-  SmallVector<int64_t> laneLayout(srcShapeSize);
-  SmallVector<int64_t> laneData(srcShapeSize);
-
-  // recover workgroup and subgroup size from consumer layout
-  DistributeLayoutAttr origPlainLayout;
-  if (consumerSliceLayout) {
-    origPlainLayout = consumerSliceLayout.flatten().getParent();
-  } else {
-    origPlainLayout = consumerLayout;
-  }
+  // Helper lambda to convert int64 vectors to int32 DenseArrayAttr
+  auto toInt32Attr = [&](ArrayRef<int64_t> vec) {
+    SmallVector<int32_t> vec32(vec.begin(), vec.end());
+    return DenseI32ArrayAttr::get(context, vec32);
+  };
 
-  const int workgroupSize =
-      std::accumulate(origPlainLayout.getEffectiveSgLayoutAsInt().begin(),
-                      origPlainLayout.getEffectiveSgLayoutAsInt().end(), 1,
-                      std::multiplies<int64_t>());
+  // Extract original plain layout for workgroup/subgroup size recovery
+  xegpu::SliceAttr consumerSliceLayout =
+      dyn_cast<xegpu::SliceAttr>(consumerLayout);
+  DistributeLayoutAttr plainLayout =
+      consumerSliceLayout ? consumerSliceLayout.flatten().getParent()
+                          : consumerLayout;
 
+  auto sgLayoutVec = plainLayout.getEffectiveSgLayoutAsInt();
+  const int workgroupSize = std::accumulate(
+      sgLayoutVec.begin(), sgLayoutVec.end(), 1, std::multiplies<int64_t>());
   const int subgroupSize = uArch->getSubgroupSize();
+  int64_t maxReduceVectorSize = 1; // could extend to spirv vector Size
 
-  const int vectorSize = 1; // vector size from SPRIV vector restriction
-
-  SmallVector<int64_t> defaultInstData(srcShapeSize, 1);
-
-  SmallVector<int64_t> defaultLaneLayout(srcShapeSize, 1);
-
-  SmallVector<int64_t> defaultLaneData(srcShapeSize, 1);
-  defaultInstData[srcShapeSize - 2] = vectorSize;
-  defaultInstData[srcShapeSize - 1] = subgroupSize;
-  defaultLaneLayout[srcShapeSize - 1] = subgroupSize;
-  defaultLaneData[srcShapeSize - 2] = vectorSize;
-  defaultLaneData[srcShapeSize - 1] = 1;
-
-  SmallVector<int64_t> consumerSgLayout =
-      consumerLayout.getEffectiveSgLayoutAsInt();
-  SmallVector<int64_t> consumerLaneLayout =
-      consumerLayout.getEffectiveLaneLayoutAsInt();
-  int remainingSgCount = workgroupSize;
-  int consumerSgId;
+  xegpu::DistributeLayoutAttr srcLayout;
 
   switch (layoutKind) {
-  case xegpu::LayoutKind::Subgroup:
-    consumerSgId = consumerSgLayout.size() - 1;
-    // For non-reduction dimensions, try to match consumer's sg_layout
-    // This ensures the result after reduction has the expected distribution
-    for (int i = srcShapeSize - 1; i >= 0; i--)
-      if (!llvm::is_contained(reductionDims, i) && consumerSgId >= 0) {
-        sgLayout[i] = consumerSgLayout[consumerSgId];
+  case xegpu::LayoutKind::Subgroup: {
+    SmallVector<int64_t> sgLayout(srcRank), sgData(srcRank);
+    SmallVector<int64_t> consumerSgLayout =
+        consumerLayout.getEffectiveSgLayoutAsInt();
+    int remainingSgCount = workgroupSize;
+    int consumerIdx = consumerSgLayout.size() - 1;
+
+    // First pass: Match consumer's layout on non-reduction dimensions
+    for (int i = srcRank - 1; i >= 0; i--) {
+      if (!llvm::is_contained(reductionDims, i) && consumerIdx >= 0) {
+        sgLayout[i] = consumerSgLayout[consumerIdx];
         assert((srcShape[i] % sgLayout[i] == 0) &&
                "source shape not divisible by consumer sg_layout");
         sgData[i] = srcShape[i] / sgLayout[i];
         remainingSgCount /= sgLayout[i];
-        consumerSgId--;
+        consumerIdx--;
       }
-    // Second pass: Distribute remaining subgroups across unhandled dimensions
-    // This handles reduction dimensions that don't necessarily align with
-    // consumer
-    for (int i = srcShapeSize - 1; i >= 0; i--)
+    }
+
+    // Second pass: Distribute remaining subgroups across reduction dimensions
+    for (int i = srcRank - 1; i >= 0; i--) {
       if (llvm::is_contained(reductionDims, i)) {
-        sgLayout[i] = std::min((srcShape[i] / defaultLaneLayout[i]),
+        sgLayout[i] = std::min(srcShape[i] / subgroupSize,
                                static_cast<int64_t>(remainingSgCount));
         assert((srcShape[i] % sgLayout[i] == 0) &&
-               "source shape not divisible by consumer sg_layout");
+               "source shape not divisible by sg_layout");
         sgData[i] = srcShape[i] / sgLayout[i];
         remainingSgCount /= sgLayout[i];
       }
-    assert(remainingSgCount == 1 && "not all subgroups have been distributed");
-    break;
-  case xegpu::LayoutKind::InstData:
-    for (int i = 0; i < srcShapeSize; i++) {
-      instData[i] = std::min(defaultInstData[i], srcShape[i]);
     }
+
+    assert(remainingSgCount == 1 && "not all subgroups distributed");
+    srcLayout =
+        xegpu::LayoutAttr::get(context, toInt32Attr(sgLayout),
+                               toInt32Attr(sgData), consumerLayout.getOrder());
     break;
-  case xegpu::LayoutKind::Lane:
-    laneLayout = defaultLaneLayout;
-    laneData = defaultLaneData;
-    break;
-  default:
-    llvm_unreachable("unsupported layout kind");
   }
 
-  SmallVector<int32_t> sgLayout32(sgLayout.begin(), sgLayout.end());
-  SmallVector<int32_t> sgData32(sgData.begin(), sgData.end());
-  SmallVector<int32_t> instData32(instData.begin(), instData.end());
-  SmallVector<int32_t> laneLayout32(laneLayout.begin(), laneLayout.end());
-  SmallVector<int32_t> laneData32(laneData.begin(), laneData.end());
-
-  switch (layoutKind) {
-  case xegpu::LayoutKind::Subgroup:
-    proposedSrcLayout = xegpu::LayoutAttr::get(
-        context, DenseI32ArrayAttr::get(context, sgLayout32),
-        DenseI32ArrayAttr::get(context, sgData32), consumerLayout.getOrder());
+  case xegpu::LayoutKind::InstData: {
+    SmallVector<int64_t> instData(srcRank, 1);
+    instData[srcRank - 2] =
+        std::min(maxReduceVectorSize, srcShape[srcRank - 2]);
+    instData[srcRank - 1] = subgroupSize;
+    srcLayout = xegpu::LayoutAttr::get(context, toInt32Attr(instData));
     break;
-  case xegpu::LayoutKind::InstData:
-    proposedSrcLayout = xegpu::LayoutAttr::get(
-        context, DenseI32ArrayAttr::get(context, instData32));
-    break;
-  case xegpu::LayoutKind::Lane:
-    proposedSrcLayout = xegpu::LayoutAttr::get(
-        context, DenseI32ArrayAttr::get(context, laneLayout32),
-        DenseI32ArrayAttr::get(context, laneData32), consumerLayout.getOrder());
+  }
+
+  case xegpu::LayoutKind::Lane: {
+    SmallVector<int64_t> laneLayout(srcRank, 1), laneData(srcRank, 1);
+    laneLayout[srcRank - 1] = subgroupSize;
+    laneData[srcRank - 2] =
+        std::min(maxReduceVectorSize, srcShape[srcRank - 2]);
+    srcLayout = xegpu::LayoutAttr::get(context, toInt32Attr(laneLayout),
+                                       toInt32Attr(laneData),
+                                       consumerLayout.getOrder());
     break;
+  }
+
   default:
     llvm_unreachable("unsupported layout kind");
   }
 
-  xegpu::SliceAttr resLayout =
-      xegpu::SliceAttr::get(context, proposedSrcLayout,
-                            DenseI64ArrayAttr::get(context, reductionDims));
-  return resLayout;
+  return xegpu::SliceAttr::get(context, srcLayout,
+                               DenseI64ArrayAttr::get(context, reductionDims));
 }
 
 xegpu::DistributeLayoutAttr xegpu::setupBitCastResultLayout(
@@ -657,45 +677,6 @@ xegpu::setupInsertStridedSliceResultLayout(xegpu::LayoutKind layoutKind,
   return requiredResLayout;
 }
 
-xegpu::DistributeLayoutAttr xegpu::inferInsertStridedSliceSourceLayout(
-    xegpu::DistributeLayoutAttr resLayout, ArrayRef<int64_t> resShape,
-    ArrayRef<int64_t> srcShape) {
-
-  int srcShapeSize = srcShape.size();
-  int resShapeSize = resShape.size();
-  int dimDiff = resShapeSize - srcShapeSize;
-
-  // assert resLayout must be a plain layout
-  assert(isa<xegpu::LayoutAttr>(resLayout) &&
-         "insertStridedSlice result layout must be plain layout");
-  auto context = resLayout.getContext();
-  auto resInstData = resLayout.getEffectiveInstDataAsInt();
-  auto resLaneLayout = resLayout.getEffectiveLaneLayoutAsInt();
-  auto resLaneData = resLayout.getEffectiveLaneDataAsInt();
-
-  if (resInstData.size() != 0) {
-    SmallVector<int> inferredInstData(srcShapeSize);
-    // remove the initial dims in resInstData to match srcShapeSize
-    for (int i = 0; i < srcShapeSize; i++)
-      inferredInstData[i] = resInstData[i + dimDiff];
-    return xegpu::LayoutAttr::get(context, inferredInstData);
-  }
-
-  if (resLaneLayout.size() != 0) {
-    // construct source lane_layout like [1, ..., 1, subgroupSize]
-    SmallVector<int> inferredLaneLayout(srcShapeSize);
-    SmallVector<int> inferredLaneData(srcShapeSize);
-    // remove the initial dims in resInstData to match srcShapeSize
-    for (int i = 0; i < srcShapeSize; i++) {
-      inferredLaneLayout[i] = resLaneLayout[i + dimDiff];
-      inferredLaneData[i] = resLaneData[i + dimDiff];
-    }
-    return xegpu::LayoutAttr::get(context, inferredLaneLayout,
-                                  inferredLaneData);
-  }
-  return nullptr;
-}
-
 static xegpu::DistributeLayoutAttr
 getDefaultLaneLayoutAttr(mlir::MLIRContext *ctx, unsigned rank,
                          const xegpu::uArch::uArch *uArch) {
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 6385995317f38..951c6e092daf2 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -1147,40 +1147,6 @@ void LayoutInfoPropagation::visitInsertStridedSliceOp(
   return;
 }
 
-// /// For vector::BitCastOp, the lane_data of the source layout is changed
-// based
-// /// on the bit width of the source and result types.
-// void LayoutInfoPropagation::visitVectorInsertStridedSliceOp(
-//     vector::BitCastOp bitcast, ArrayRef<LayoutInfoLattice *> operands,
-//     ArrayRef<const LayoutInfoLattice *> results) {
-//   // Need the layout of bitcast result to propagate to the operands.
-//   LayoutInfo resLayoutInfo = results[0]->getValue();
-//   if (!resLayoutInfo.isAssigned())
-//     return;
-
-//   auto srcVecType = bitcast.getSourceVectorType();
-//   auto resVecType = bitcast.getResultVectorType();
-
-//   auto consumerLayoutAttr =
-//       dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
-//   auto uArch = getUArch(xegpu::getChipStr(bitcast).value_or(""));
-//   auto requiredResLayoutAttr = setupInsertStridedSliceResultLayout(
-//       layoutKind, srcVecType, resVecType, consumerLayoutAttr, uArch);
-
-//   xegpu::setTemporaryLayout(bitcast->getResult(0), requiredResLayoutAttr);
-
-//   int inElemTyBitWidth = srcVecType.getElementType().getIntOrFloatBitWidth();
-//   int outElemTyBitWidth =
-//   resVecType.getElementType().getIntOrFloatBitWidth();
-
-//   // derive the source layout from the dominant layout and reduction dims
-//   auto srcLayoutAttr = xegpu::inferBitCastSourceLayout(
-//       requiredResLayoutAttr, outElemTyBitWidth, inElemTyBitWidth);
-
-//   propagateIfChanged(operands[0],
-//   operands[0]->meet(LayoutInfo(srcLayoutAttr)));
-// }
-
 /// Propagate the layout of the result to the tensor descriptor, mask and offset
 /// operands in LoadGatherOp.
 void LayoutInfoPropagation::visitLoadGatherOp(
@@ -1265,8 +1231,6 @@ void LayoutInfoPropagation::visitStoreScatterOp(
   auto uArch = getUArch(getChipStr(storeScatter).value_or(""));
   auto subgroupSize = uArch->getSubgroupSize();
   VectorType srcVecTy = storeScatter.getValueType();
-  VectorType maskTy =
-      llvm::dyn_cast<VectorType>(storeScatter.getMask().getType());
   int chunkSize = storeScatter.getChunkSize().value_or(1);
 
   if (hasParamsOfLayoutKind(anchorLayoutAttr)) {
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index cc278c6813fc9..7d06ff0ff0fb3 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -1615,11 +1615,6 @@ struct VectorShapeCastDistribution : public gpu::WarpDistributionPattern {
           warpOp,
           "the source or result of shape_cast op lacks distribution layout");
 
-    // For rank reducing or increasing shape_cast ops, the lower rank layout
-    // must be a slice of higher rank layout.
-    int64_t sourceRank = shapeCastOp.getSourceVectorType().getRank();
-    int64_t resultRank = shapeCastOp.getResultVectorType().getRank();
-
     FailureOr<VectorType> sourceDistTypeOrFailure =
         getDistVecTypeBasedOnLaneLayout(sourceLayout,
                                         shapeCastOp.getSourceVectorType());

>From 97e9211a376ecb634a41b7e506a7a8cce8a60865 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Fri, 30 Jan 2026 18:02:56 +0000
Subject: [PATCH 24/35] improve InsertStridedSlice result layout setup

---
 .../XeGPU/Transforms/XeGPULayoutImpls.h       |  3 +-
 .../XeGPU/Transforms/XeGPULayoutImpls.cpp     | 69 +++++++++++++------
 .../XeGPU/Transforms/XeGPUPropagateLayout.cpp |  4 +-
 3 files changed, 52 insertions(+), 24 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpls.h b/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpls.h
index e5f2b1be93cc2..27325987c64e7 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpls.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpls.h
@@ -126,7 +126,8 @@ DistributeLayoutAttr setupBitCastResultLayout(
     DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch);
 
 DistributeLayoutAttr setupInsertStridedSliceResultLayout(
-    LayoutKind layoutKind, VectorType resVectorTy, const uArch::uArch *uArch);
+    LayoutKind layoutKind, VectorType resVectorTy,
+    DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch);
 
 DistributeLayoutAttr
 setupLoadMatrixAnchorLayout(LayoutKind layoutKind, VectorType vectorTy,
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpls.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpls.cpp
index 8509f709c6fec..1165d9d2cefab 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpls.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpls.cpp
@@ -408,16 +408,17 @@ xegpu::inferShapeCastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
 /// Algorithm Overview:
 /// This function attempts to construct a source layout that, when sliced along
 /// reduction dimensions, produces a result layout compatible with the
-/// consumer layout. This minimizes data redistribution overhead
-/// between the reduction operation and its consumer.
+/// consumer layout.
 ///
-/// It first tries to align the source layout's subgroup layout and data with
-/// the consumer's layout on non-reduction dimensions. Then, it distributes
-/// remaining subgroups across reduction dimensions.
-/// For InstData, it requries {1, ..., min(maxReduceVectorSize,
-/// srcshape),subgroupSize} For lane layout, it requires {1, ..., 1,
-/// subgroupSize} For lane data, it requires {1, ..., min(maxReduceVectorSize,
-/// srcshape), 1}
+/// For subgroup layouts, it first tries to align the source layout's subgroup
+/// layout and data with the consumer's layout on non-reduction dimensions.
+/// Then, it distributes the remaining subgroups across reduction dimensions.
+/// This avoids subgroup data redistribution overhead between the reduced
+/// result and its consumer.
+///
+/// InstData requires {1, ..., min(maxReduceVectorSize, srcShape), subgroupSize}
+/// Lane layout requires {1, ..., 1, subgroupSize}
+/// Lane data requires {1, ..., min(maxReduceVectorSize, srcShape), 1}
 
 xegpu::SliceAttr xegpu::setupMultiReductionResultLayout(
     xegpu::LayoutKind layoutKind, VectorType srcVecTy,
@@ -520,6 +521,11 @@ xegpu::SliceAttr xegpu::setupMultiReductionResultLayout(
                                DenseI64ArrayAttr::get(context, reductionDims));
 }
 
+/// Sets up the result layout for a bitcast operation.
+/// When casting to a smaller bitwidth, adjusts the layout dimensions (sgData,
+/// instData, or laneData) by multiplying by the bitwidth ratio to ensure the
+/// result layout can be correctly divided back to the source layout during
+/// inference.
 xegpu::DistributeLayoutAttr xegpu::setupBitCastResultLayout(
     xegpu::LayoutKind layoutKind, VectorType srcVecTy, VectorType resVecTy,
     DistributeLayoutAttr consumerLayout, const xegpu::uArch::uArch *uArch) {
@@ -642,22 +648,37 @@ xegpu::setupStoreMatrixAnchorLayout(xegpu::LayoutKind layoutKind,
   return requiredLayout;
 }
 
-xegpu::DistributeLayoutAttr
-xegpu::setupInsertStridedSliceResultLayout(xegpu::LayoutKind layoutKind,
-                                           VectorType resVectorTy,
-                                           const xegpu::uArch::uArch *uArch) {
+/// Sets up the result layout for an insert strided slice operation.
+/// Creates a default layout based on the specified layout kind (InstData or
+/// Lane).
+/// Subgroup layout is currently not supported for this operation.
+/// InstData layout requires {1, ..., subgroupSize} by default.
+/// Lane layout requires {1, ..., subgroupSize} with lane data {1, ..., 1}.
+/// The innermost instData/laneData is widened to the packed size when the
+/// consumer layout's innermost dimension already uses that packed size.
+xegpu::DistributeLayoutAttr xegpu::setupInsertStridedSliceResultLayout(
+    xegpu::LayoutKind layoutKind, VectorType resVectorTy,
+    xegpu::DistributeLayoutAttr consumerLayout,
+    const xegpu::uArch::uArch *uArch) {
 
   xegpu::DistributeLayoutAttr requiredResLayout;
   auto subgroupSize = uArch->getSubgroupSize();
   auto context = resVectorTy.getContext();
   auto resShape = resVectorTy.getShape();
   int resShapeSize = resShape.size();
+  SmallVector<int64_t> consumerInstData =
+      consumerLayout.getEffectiveInstDataAsInt();
+  SmallVector<int64_t> consumerLaneData =
+      consumerLayout.getEffectiveLaneDataAsInt();
+
+  SmallVector<int> instData(resShapeSize, 1);
+  SmallVector<int> laneLayout(resShapeSize, 1);
+  SmallVector<int> laneData(resShapeSize, 1);
 
-  SmallVector<int> defaultInstData(resShapeSize, 1);
-  SmallVector<int> defaultLaneLayout(resShapeSize, 1);
-  SmallVector<int> defaultLaneData(resShapeSize, 1);
-  defaultInstData[resShapeSize - 1] = subgroupSize;
-  defaultLaneLayout[resShapeSize - 1] = subgroupSize;
+  const unsigned packingSize{uArch->getGeneralPackedFormatBitSize()};
+  unsigned bitwidth = resVectorTy.getElementType().getIntOrFloatBitWidth();
+  int packingFactor = bitwidth < packingSize ? packingSize / bitwidth : 1;
+  int packedDataSize = subgroupSize * packingFactor;
 
   switch (layoutKind) {
   case xegpu::LayoutKind::Subgroup:
@@ -665,11 +686,17 @@ xegpu::setupInsertStridedSliceResultLayout(xegpu::LayoutKind layoutKind,
            "subgroup layout assignment not supported for insertStridedSlice.");
     break;
   case xegpu::LayoutKind::InstData:
-    requiredResLayout = xegpu::LayoutAttr::get(context, defaultInstData);
+    instData[resShapeSize - 1] = subgroupSize;
+    if (consumerInstData[resShapeSize - 1] == packedDataSize)
+      instData[resShapeSize - 1] = packedDataSize;
+    requiredResLayout = xegpu::LayoutAttr::get(context, instData);
     break;
   case xegpu::LayoutKind::Lane:
-    requiredResLayout =
-        xegpu::LayoutAttr::get(context, defaultLaneLayout, defaultLaneData);
+    laneLayout[resShapeSize - 1] = subgroupSize;
+    laneData[resShapeSize - 1] = 1;
+    if (consumerLaneData[resShapeSize - 1] == packingFactor)
+      laneData[resShapeSize - 1] = packingFactor;
+    requiredResLayout = xegpu::LayoutAttr::get(context, laneLayout, laneData);
     break;
   default:
     llvm_unreachable("unsupported layout kind");
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 951c6e092daf2..15547b33765fd 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -1132,8 +1132,8 @@ void LayoutInfoPropagation::visitInsertStridedSliceOp(
       dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
   auto uArch = getUArch(xegpu::getChipStr(insertStridedSlice).value_or(""));
 
-  auto requiredResLayoutAttr =
-      xegpu::setupInsertStridedSliceResultLayout(layoutKind, resVecType, uArch);
+  auto requiredResLayoutAttr = xegpu::setupInsertStridedSliceResultLayout(
+      layoutKind, resVecType, consumerLayoutAttr, uArch);
 
   xegpu::setTemporaryLayout(insertStridedSlice->getResult(0),
                             requiredResLayoutAttr);

>From 44bd04fc50637023ac44fa8a1762f7b4f388811c Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Fri, 30 Jan 2026 19:08:22 +0000
Subject: [PATCH 25/35] polish loadgather layout setup

---
 .../XeGPU/Transforms/XeGPULayoutImpls.cpp     | 89 ++++++++++---------
 mlir/test/Dialect/XeGPU/propagate-layout.mlir |  2 +-
 2 files changed, 48 insertions(+), 43 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpls.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpls.cpp
index 1165d9d2cefab..664edbf8c0003 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpls.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpls.cpp
@@ -714,6 +714,17 @@ getDefaultLaneLayoutAttr(mlir::MLIRContext *ctx, unsigned rank,
   return xegpu::LayoutAttr::get(ctx, {1, uArch->getSubgroupSize()}, {1, 1});
 }
 
+/// Sets up the result layout for a load gather operation.
+/// For Subgroup layout, uses the consumer layout directly.
+/// Non-chunked loads:
+///   InstData = {1, ..., min(consumer,
+///                           maxLaneLoadStoreSize * subgroupSize)}
+///   LaneLayout = {1, ..., subgroupSize}
+///   LaneData = {1, ..., min(consumer, maxLaneLoadStoreSize)}
+/// Chunked loads:
+///   InstData = {subgroupSize, min(consumer, maxLaneLoadStoreSize)}
+///   LaneLayout = {subgroupSize, 1}
+///   LaneData = {1, min(consumer, maxLaneLoadStoreSize)}
 xegpu::DistributeLayoutAttr xegpu::setupLoadGatherAnchorLayout(
     LayoutKind layoutKind, VectorType resVecTy, int chunkSize,
     DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch) {
@@ -721,9 +732,7 @@ xegpu::DistributeLayoutAttr xegpu::setupLoadGatherAnchorLayout(
   xegpu::DistributeLayoutAttr requiredLayout;
   const int subgroupSize = uArch->getSubgroupSize();
 
-  auto resShape = resVecTy.getShape();
-  int resShapeSize = resShape.size();
-  SmallVector<int> instData(resShapeSize);
+  int resShapeSize = resVecTy.getShape().size();
   auto context = resVecTy.getContext();
 
   const auto *uArchInstruction =
@@ -732,48 +741,44 @@ xegpu::DistributeLayoutAttr xegpu::setupLoadGatherAnchorLayout(
 
   SmallVector<int64_t> consumerInstData =
       consumerLayout.getEffectiveInstDataAsInt();
-  SmallVector<int32_t> instData32;
+  SmallVector<int64_t> consumerLaneData =
+      consumerLayout.getEffectiveLaneDataAsInt();
 
-  switch (layoutKind) {
-  case xegpu::LayoutKind::Subgroup:
-    requiredLayout = consumerLayout;
-    break;
-  case xegpu::LayoutKind::InstData:
-    if (resVecTy.getRank() == 1) {
-      instData[0] = subgroupSize;
-    } else {
-      assert((resVecTy.getRank() == 2) && "StoreScatterOp can access 2D tensor "
-                                          "tile at maximum at subgroup level.");
-      if (chunkSize == 1) {
-        instData[0] = 1;
-        instData[1] = subgroupSize;
-      } else {
-        instData[0] = subgroupSize;
-        instData[1] = std::min(static_cast<int>(resShape[1]),
-                               uArchInstruction->getMaxLaneLoadStoreSize());
-        instData[1] =
-            std::min(instData[1], static_cast<int>(consumerInstData[1]));
-      }
+  SmallVector<int> instData(resShapeSize, 1);
+  SmallVector<int> laneLayout(resShapeSize, 1);
+  SmallVector<int> laneData(resShapeSize, 1);
+
+  if (layoutKind == xegpu::LayoutKind::Subgroup) {
+    return consumerLayout;
+  }
+
+  if (chunkSize == 1) {
+    if (layoutKind == xegpu::LayoutKind::InstData) {
+      instData[resShapeSize - 1] =
+          std::min(static_cast<int>(consumerInstData[resShapeSize - 1]),
+                   uArchInstruction->getMaxLaneLoadStoreSize() * subgroupSize);
+      requiredLayout = xegpu::LayoutAttr::get(context, instData);
+    } else if (layoutKind == xegpu::LayoutKind::Lane) {
+      laneLayout[resShapeSize - 1] = subgroupSize;
+      laneData[resShapeSize - 1] =
+          std::min(static_cast<int>(consumerLaneData[resShapeSize - 1]),
+                   uArchInstruction->getMaxLaneLoadStoreSize());
+      requiredLayout = xegpu::LayoutAttr::get(context, laneLayout, laneData);
     }
-    requiredLayout = xegpu::LayoutAttr::get(
-        context, DenseI32ArrayAttr::get(context, instData));
-    break;
-  case xegpu::LayoutKind::Lane:
-    if (chunkSize == 1)
-      requiredLayout =
-          getDefaultLaneLayoutAttr(context, resVecTy.getRank(), uArch);
-    else {
-      assert((resVecTy.getRank() <= 2) && "StoreScatterOp can access 2D tensor "
-                                          "tile at maximum at subgroup level.");
-      assert(resShape[1] <= static_cast<int64_t>(
-                                uArchInstruction->getMaxLaneLoadStoreSize()) &&
-             "StoreScatterOp lane size exceeds max lane load/store size.");
-      requiredLayout = xegpu::LayoutAttr::get(
-          context, {subgroupSize, 1}, {1, static_cast<int>(resShape[1])});
+  } else {
+    assert(resVecTy.getRank() == 2 &&
+           "Chunked Store must access 2D tensor tile.");
+    if (layoutKind == xegpu::LayoutKind::InstData) {
+      instData[0] = subgroupSize;
+      instData[1] = std::min(static_cast<int>(consumerInstData[1]),
+                             uArchInstruction->getMaxLaneLoadStoreSize());
+      requiredLayout = xegpu::LayoutAttr::get(context, instData);
+    } else if (layoutKind == xegpu::LayoutKind::Lane) {
+      laneLayout[0] = subgroupSize;
+      laneData[1] = std::min(static_cast<int>(consumerLaneData[1]),
+                             uArchInstruction->getMaxLaneLoadStoreSize());
+      requiredLayout = xegpu::LayoutAttr::get(context, laneLayout, laneData);
     }
-    break;
-  default:
-    llvm_unreachable("unsupported layout kind");
   }
   return requiredLayout;
 }
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout.mlir b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
index a8eccfab53c37..85b9b82179e57 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
@@ -107,7 +107,7 @@ gpu.module @test {
 // CHECK: %[[OFFSET:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
 // CHECK-SAME:  dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
 // CHECK-NEXT: %[[MASK:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<16xi1>
-// CHECK-NEXT: %{{.*}} = xegpu.load %arg1[%[[OFFSET]]], %[[MASK]] <{chunk_size = 16 : i64, layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 16]>}> : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x16xf16>
+// CHECK-NEXT: %{{.*}} = xegpu.load %arg1[%[[OFFSET]]], %[[MASK]] <{chunk_size = 16 : i64, layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>}> : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x16xf16>
 func.func @load_gather_with_chunksize(%arg0: memref<8x16xf16>, %arg1: memref<256xf16>, %arg2: memref<8x16xf32>) {
   %c0 = arith.constant 0 : index
   %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>

>From 6feb7c824ed3fec6b598de5fd3522bd4e36be01e Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Fri, 30 Jan 2026 23:33:15 +0000
Subject: [PATCH 26/35] polish load store layout setup

---
 .../mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h    |  30 +-
 .../mlir/Dialect/XeGPU/uArch/uArchBase.h      |   8 +-
 .../XeGPU/Transforms/XeGPULayoutImpls.cpp     | 296 +++++++++---------
 .../XeGPU/Transforms/XeGPUPropagateLayout.cpp |  21 --
 4 files changed, 175 insertions(+), 180 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
index 29e75b57f4a5f..0f138b9defb5c 100644
--- a/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
@@ -237,6 +237,28 @@ struct LoadGatherInstruction : public Instruction {
   int32_t getMaxLaneLoadStoreSize() const { return 16; }
 };
 
+struct StoreMatrixInstruction : public Instruction {
+  StoreMatrixInstruction()
+      : Instruction(InstructionKind::StoreMatrix, InstructionScope::Lane) {}
+  static bool classof(const Instruction *B) {
+    return B->getInstructionKind() == InstructionKind::StoreMatrix;
+  }
+
+  // SPIRV restricts vector size
+  int32_t getMaxLaneLoadStoreSize() const { return 16; }
+};
+
+struct LoadMatrixInstruction : public Instruction {
+  LoadMatrixInstruction()
+      : Instruction(InstructionKind::LoadMatrix, InstructionScope::Lane) {}
+  static bool classof(const Instruction *B) {
+    return B->getInstructionKind() == InstructionKind::LoadMatrix;
+  }
+
+  // SPIRV restricts vector size
+  int32_t getMaxLaneLoadStoreSize() const { return 16; }
+};
+
 //===----------------------------------------------------------------------===//
 // uArch instances
 //===----------------------------------------------------------------------===//
@@ -249,9 +271,11 @@ struct PVCuArch final : public Xe2Plus {
     static const Subgroup2DBlockPrefetchInstruction prefetchNdInst;
     static const StoreScatterInstruction storeScatterInst;
     static const LoadGatherInstruction loadGatherInst;
-    static const Instruction *arr[] = {&dpasInst,         &loadNdInst,
-                                       &storeNdInst,      &prefetchNdInst,
-                                       &storeScatterInst, &loadGatherInst};
+    static const StoreMatrixInstruction storeMatrixInst;
+    static const LoadMatrixInstruction loadMatrixInst;
+    static const Instruction *arr[] = {
+        &dpasInst,         &loadNdInst,     &storeNdInst,     &prefetchNdInst,
+        &storeScatterInst, &loadGatherInst, &storeMatrixInst, &loadMatrixInst};
     return arr;
   }
 
diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
index db1984b2edb1d..75a483baa116d 100644
--- a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
@@ -40,7 +40,9 @@ enum class InstructionKind {
   Subgroup2DBlockLoad,       // Subgroup-level 2D block load instruction
   Subgroup2DBlockPrefetch,   // Subgroup-level 2D block prefetch instruction
   StoreScatter,              // Lane-level store (scalar, vector)
-  LoadGather                 // Lane-level load (scalar, vector)
+  LoadGather,                // Lane-level load (scalar, vector)
+  StoreMatrix,               // Lane-level matrix store to slm
+  LoadMatrix                 // Lane-level matrix load from slm
   // @TODO: Add more instructions as needed
 };
 
@@ -71,6 +73,10 @@ struct Instruction {
       return "store";
     case InstructionKind::LoadGather:
       return "load";
+    case InstructionKind::StoreMatrix:
+      return "store_matrix";
+    case InstructionKind::LoadMatrix:
+      return "load_matrix";
     }
     llvm_unreachable("Unknown InstructionKind");
   }
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpls.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpls.cpp
index 664edbf8c0003..a9428be4cadeb 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpls.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpls.cpp
@@ -588,66 +588,6 @@ xegpu::DistributeLayoutAttr xegpu::setupBitCastResultLayout(
   return consumerLayout;
 }
 
-xegpu::DistributeLayoutAttr
-xegpu::setupLoadMatrixAnchorLayout(xegpu::LayoutKind layoutKind,
-                                   VectorType resVectorTy,
-                                   xegpu::DistributeLayoutAttr consumerLayout,
-                                   const xegpu::uArch::uArch *uArch) {
-  xegpu::DistributeLayoutAttr requiredLayout;
-  auto subgroupSize = uArch->getSubgroupSize();
-  SmallVector<int> defaultInstData = {1, subgroupSize};
-  SmallVector<int> defaultLaneLayout = {1, subgroupSize};
-  SmallVector<int> defaultLaneData = {1, 1};
-  auto context = resVectorTy.getContext();
-
-  switch (layoutKind) {
-  case xegpu::LayoutKind::Subgroup:
-    requiredLayout = consumerLayout;
-    break;
-  case xegpu::LayoutKind::InstData:
-    requiredLayout = xegpu::LayoutAttr::get(context, defaultInstData);
-    break;
-  case xegpu::LayoutKind::Lane:
-    requiredLayout =
-        xegpu::LayoutAttr::get(context, defaultLaneLayout, defaultLaneData);
-    break;
-  default:
-    llvm_unreachable("unsupported layout kind");
-  }
-  return requiredLayout;
-}
-
-xegpu::DistributeLayoutAttr
-xegpu::setupStoreMatrixAnchorLayout(xegpu::LayoutKind layoutKind,
-                                    VectorType srcVectorTy,
-                                    const xegpu::uArch::uArch *uArch) {
-
-  xegpu::DistributeLayoutAttr requiredLayout;
-  auto subgroupSize = uArch->getSubgroupSize();
-  SmallVector<int> defaultInstData = {1, subgroupSize};
-  SmallVector<int> defaultLaneLayout = {1, subgroupSize};
-  SmallVector<int> defaultLaneData = {1, 1};
-  auto context = srcVectorTy.getContext();
-
-  switch (layoutKind) {
-  case xegpu::LayoutKind::Subgroup:
-    assert(true &&
-           "subgroup layout assignment not supported yet for storeMatrix.");
-    break;
-  case xegpu::LayoutKind::InstData:
-    requiredLayout = xegpu::LayoutAttr::get(context, defaultInstData);
-    break;
-  case xegpu::LayoutKind::Lane:
-    requiredLayout =
-        xegpu::LayoutAttr::get(context, defaultLaneLayout, defaultLaneData);
-
-    break;
-  default:
-    llvm_unreachable("unsupported layout kind");
-  }
-  return requiredLayout;
-}
-
 /// Sets up the result layout for an insert strided slice operation.
 /// Creates a default layout based on the specified layout kind (InstData or
 /// Lane).
@@ -704,141 +644,187 @@ xegpu::DistributeLayoutAttr xegpu::setupInsertStridedSliceResultLayout(
   return requiredResLayout;
 }
 
-static xegpu::DistributeLayoutAttr
-getDefaultLaneLayoutAttr(mlir::MLIRContext *ctx, unsigned rank,
-                         const xegpu::uArch::uArch *uArch) {
-  assert((rank == 1 || rank == 2) && "Expected 1D or 2D vector.");
-  if (rank == 1) {
-    return xegpu::LayoutAttr::get(ctx, {uArch->getSubgroupSize()}, {1});
-  }
-  return xegpu::LayoutAttr::get(ctx, {1, uArch->getSubgroupSize()}, {1, 1});
-}
-
-/// Sets up the result layout for a load gather operation.
+/// Sets up the anchor layout for load gather and load matrix operations.
+/// Load matrix lowers to load gather and 1D block loads, so all of them share
+/// the same layout setup logic.
 /// For Subgroup layout, uses the consumer layout directly.
 /// non-chunked loads:
-///   InstData = {1, ..., min(consumer, maxLaneLoadStoreSize *
-///   subgroupSize)}
+///   InstData = {1, ..., min(consumer, maxLaneLoadStoreSize * subgroupSize)}
 ///   LaneLayout = {1, ..., subgroupSize}
 ///   lane_data = {1, ..., min(consumer, maxLaneLoadStoreSize)}
 /// chunked loads:
 ///   InstData = {subgroupSize, min(consumer, maxLaneLoadStoreSize)}
 ///   LaneLayout = {subgroupSize, 1}
 ///   lane_data={1,min(consumer, maxLaneLoadStoreSize)}
+static xegpu::DistributeLayoutAttr setupGenericLoadAnchorLayout(
+    xegpu::LayoutKind layoutKind, mlir::MLIRContext *context,
+    xegpu::DistributeLayoutAttr consumerLayout, bool isChunkedLoad,
+    int maxChunkSize, int valShapeSize, int subgroupSize) {
+  SmallVector<int64_t> consumerInstData =
+      consumerLayout.getEffectiveInstDataAsInt();
+  SmallVector<int64_t> consumerLaneData =
+      consumerLayout.getEffectiveLaneDataAsInt();
+
+  SmallVector<int> instData(valShapeSize, 1);
+  SmallVector<int> laneLayout(valShapeSize, 1);
+  SmallVector<int> laneData(valShapeSize, 1);
+
+  if (layoutKind == xegpu::LayoutKind::Subgroup) {
+    return consumerLayout;
+  }
+
+  if (!isChunkedLoad) {
+    if (layoutKind == xegpu::LayoutKind::InstData) {
+      instData[valShapeSize - 1] =
+          std::min(static_cast<int>(consumerInstData[valShapeSize - 1]),
+                   maxChunkSize * subgroupSize);
+      return xegpu::LayoutAttr::get(context, instData);
+    } else if (layoutKind == xegpu::LayoutKind::Lane) {
+      laneLayout[valShapeSize - 1] = subgroupSize;
+      laneData[valShapeSize - 1] = std::min(
+          static_cast<int>(consumerLaneData[valShapeSize - 1]), maxChunkSize);
+      return xegpu::LayoutAttr::get(context, laneLayout, laneData);
+    }
+  } else {
+    assert(valShapeSize == 2 && "Chunked Store must access 2D tensor tile.");
+    if (layoutKind == xegpu::LayoutKind::InstData) {
+      instData[0] = subgroupSize;
+      instData[1] =
+          std::min(static_cast<int>(consumerInstData[1]), maxChunkSize);
+      return xegpu::LayoutAttr::get(context, instData);
+    } else if (layoutKind == xegpu::LayoutKind::Lane) {
+      laneLayout[0] = subgroupSize;
+      laneData[1] =
+          std::min(static_cast<int>(consumerLaneData[1]), maxChunkSize);
+      return xegpu::LayoutAttr::get(context, laneLayout, laneData);
+    }
+  }
+  return nullptr;
+}
+
+/// Sets up the anchor layout for a load gather operation.
 xegpu::DistributeLayoutAttr xegpu::setupLoadGatherAnchorLayout(
-    LayoutKind layoutKind, VectorType resVecTy, int chunkSize,
-    DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch) {
+    xegpu::LayoutKind layoutKind, VectorType resVecTy, int chunkSize,
+    xegpu::DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch) {
 
-  xegpu::DistributeLayoutAttr requiredLayout;
   const int subgroupSize = uArch->getSubgroupSize();
+  int resShapeSize = resVecTy.getShape().size();
+  auto context = resVecTy.getContext();
+
+  const auto *uArchInstruction = dyn_cast<xegpu::uArch::LoadGatherInstruction>(
+      uArch->getInstruction(xegpu::uArch::InstructionKind::LoadGather));
+  int maxChunkSize = uArchInstruction->getMaxLaneLoadStoreSize();
+
+  return setupGenericLoadAnchorLayout(layoutKind, context, consumerLayout,
+                                      (chunkSize > 1), maxChunkSize,
+                                      resShapeSize, subgroupSize);
+}
+
+/// Sets up the anchor layout for a load matrix operation.
+/// TODO: enhance load matrix to indicate lowering to chunked load or not.
+xegpu::DistributeLayoutAttr
+xegpu::setupLoadMatrixAnchorLayout(xegpu::LayoutKind layoutKind,
+                                   VectorType resVecTy,
+                                   xegpu::DistributeLayoutAttr consumerLayout,
+                                   const xegpu::uArch::uArch *uArch) {
 
+  const int subgroupSize = uArch->getSubgroupSize();
   int resShapeSize = resVecTy.getShape().size();
   auto context = resVecTy.getContext();
 
-  const auto *uArchInstruction =
-      dyn_cast<xegpu::uArch::StoreScatterInstruction>(
-          uArch->getInstruction(xegpu::uArch::InstructionKind::StoreScatter));
+  const auto *uArchInstruction = dyn_cast<xegpu::uArch::LoadMatrixInstruction>(
+      uArch->getInstruction(xegpu::uArch::InstructionKind::LoadMatrix));
+  int maxChunkSize = uArchInstruction->getMaxLaneLoadStoreSize();
 
-  SmallVector<int64_t> consumerInstData =
-      consumerLayout.getEffectiveInstDataAsInt();
-  SmallVector<int64_t> consumerLaneData =
-      consumerLayout.getEffectiveLaneDataAsInt();
+  return setupGenericLoadAnchorLayout(layoutKind, context, consumerLayout,
+                                      false, maxChunkSize, resShapeSize,
+                                      subgroupSize);
+}
 
-  SmallVector<int> instData(resShapeSize, 1);
-  SmallVector<int> laneLayout(resShapeSize, 1);
-  SmallVector<int> laneData(resShapeSize, 1);
+/// Sets up the anchor layout for store scatter and store matrix operations.
+/// Store matrix lowers to store scatter and 1D block stores, so all of them
+/// share the same layout setup logic. Subgroup layout is not supported yet.
+/// non-chunked stores:
+///   InstData = {1, ..., subgroupSize}
+///   LaneLayout = {1, ..., subgroupSize}
+///   lane_data = {1, ..., 1}
+/// chunked stores:
+///   InstData = {subgroupSize, min(srcVec, maxLaneLoadStoreSize)}
+///   LaneLayout = {subgroupSize, 1}
+///   lane_data={1,min(srcVec, maxLaneLoadStoreSize)}
+static xegpu::DistributeLayoutAttr
+setupGenericStoreAnchorLayout(xegpu::LayoutKind layoutKind,
+                              mlir::MLIRContext *context, bool isChunkedStore,
+                              int maxChunkSize, ArrayRef<int64_t> srcShape,
+                              int subgroupSize) {
+
+  int srcShapeSize = srcShape.size();
+  SmallVector<int> instData(srcShapeSize, 1);
+  SmallVector<int> laneLayout(srcShapeSize, 1);
+  SmallVector<int> laneData(srcShapeSize, 1);
 
   if (layoutKind == xegpu::LayoutKind::Subgroup) {
-    return consumerLayout;
+    assert(true &&
+           "subgroup layout assignment not supported for storeScatter.");
+    return nullptr;
   }
 
-  if (chunkSize == 1) {
+  if (!isChunkedStore) {
     if (layoutKind == xegpu::LayoutKind::InstData) {
-      instData[resShapeSize - 1] =
-          std::min(static_cast<int>(consumerInstData[resShapeSize - 1]),
-                   uArchInstruction->getMaxLaneLoadStoreSize() * subgroupSize);
-      requiredLayout = xegpu::LayoutAttr::get(context, instData);
+      instData[srcShapeSize - 1] = subgroupSize;
+      return xegpu::LayoutAttr::get(context, instData);
     } else if (layoutKind == xegpu::LayoutKind::Lane) {
-      laneLayout[resShapeSize - 1] = subgroupSize;
-      laneData[resShapeSize - 1] =
-          std::min(static_cast<int>(consumerLaneData[resShapeSize - 1]),
-                   uArchInstruction->getMaxLaneLoadStoreSize());
-      requiredLayout = xegpu::LayoutAttr::get(context, laneLayout, laneData);
+      laneLayout[srcShapeSize - 1] = subgroupSize;
+      return xegpu::LayoutAttr::get(context, laneLayout, laneData);
     }
   } else {
-    assert(resVecTy.getRank() == 2 &&
-           "Chunked Store must access 2D tensor tile.");
+    assert(srcShapeSize == 2 && "Chunked Store must access 2D tensor tile.");
     if (layoutKind == xegpu::LayoutKind::InstData) {
       instData[0] = subgroupSize;
-      instData[1] = std::min(static_cast<int>(consumerInstData[1]),
-                             uArchInstruction->getMaxLaneLoadStoreSize());
-      requiredLayout = xegpu::LayoutAttr::get(context, instData);
+      instData[1] = std::min(static_cast<int>(srcShape[1]), maxChunkSize);
+      return xegpu::LayoutAttr::get(context, instData);
     } else if (layoutKind == xegpu::LayoutKind::Lane) {
       laneLayout[0] = subgroupSize;
-      laneData[1] = std::min(static_cast<int>(consumerLaneData[1]),
-                             uArchInstruction->getMaxLaneLoadStoreSize());
-      requiredLayout = xegpu::LayoutAttr::get(context, laneLayout, laneData);
+      laneData[1] = std::min(static_cast<int>(srcShape[1]), maxChunkSize);
+      return xegpu::LayoutAttr::get(context, laneLayout, laneData);
     }
   }
-  return requiredLayout;
+  return nullptr;
 }
 
+/// Sets up the anchor layout for a store scatter operation.
 xegpu::DistributeLayoutAttr
-xegpu::setupStoreScatterAnchorLayout(LayoutKind layoutKind, VectorType srcVecTy,
-                                     int chunkSize, const uArch::uArch *uArch) {
+xegpu::setupStoreScatterAnchorLayout(xegpu::LayoutKind layoutKind,
+                                     VectorType srcVecTy, int chunkSize,
+                                     const uArch::uArch *uArch) {
 
-  xegpu::DistributeLayoutAttr requiredLayout;
   const int subgroupSize = uArch->getSubgroupSize();
-
-  auto srcShape = srcVecTy.getShape();
-  int srcShapeSize = srcShape.size();
-  SmallVector<int> instData(srcShapeSize);
+  ArrayRef<int64_t> srcShape = srcVecTy.getShape();
+  auto context = srcVecTy.getContext();
 
   const auto *uArchInstruction =
       dyn_cast<xegpu::uArch::StoreScatterInstruction>(
           uArch->getInstruction(xegpu::uArch::InstructionKind::StoreScatter));
+  int maxChunkSize = uArchInstruction->getMaxLaneLoadStoreSize();
+
+  return setupGenericStoreAnchorLayout(layoutKind, context, (chunkSize > 1),
+                                       maxChunkSize, srcShape, subgroupSize);
+}
+
+/// Sets up the anchor layout for a store matrix operation.
+xegpu::DistributeLayoutAttr
+xegpu::setupStoreMatrixAnchorLayout(xegpu::LayoutKind layoutKind,
+                                    VectorType srcVecTy,
+                                    const xegpu::uArch::uArch *uArch) {
+
+  const int subgroupSize = uArch->getSubgroupSize();
+  ArrayRef<int64_t> srcShape = srcVecTy.getShape();
   auto context = srcVecTy.getContext();
 
-  switch (layoutKind) {
-  case xegpu::LayoutKind::Subgroup:
-    assert(
-        true &&
-        "subgroup layout assignment not supported yet for store scatter op.");
-    break;
-  case xegpu::LayoutKind::InstData:
-    if (srcVecTy.getRank() == 1) {
-      instData[0] = subgroupSize;
-    } else {
-      assert((srcVecTy.getRank() <= 2) && "StoreScatterOp can access 2D tensor "
-                                          "tile at maximum at subgroup level.");
-      if (chunkSize == 1) {
-        instData[0] = 1;
-        instData[1] = subgroupSize;
-      } else {
-        instData[0] = subgroupSize;
-        instData[1] = std::min(static_cast<int>(srcShape[1]),
-                               uArchInstruction->getMaxLaneLoadStoreSize());
-      }
-    }
-    requiredLayout = xegpu::LayoutAttr::get(
-        context, DenseI32ArrayAttr::get(context, instData));
-    break;
-  case xegpu::LayoutKind::Lane:
-    if (chunkSize == 1)
-      requiredLayout =
-          getDefaultLaneLayoutAttr(context, srcVecTy.getRank(), uArch);
-    else {
-      assert((srcVecTy.getRank() <= 2) && "StoreScatterOp can access 2D tensor "
-                                          "tile at maximum at subgroup level.");
-      assert(srcShape[1] <= static_cast<int64_t>(
-                                uArchInstruction->getMaxLaneLoadStoreSize()) &&
-             "StoreScatterOp lane size exceeds max lane load/store size.");
-      requiredLayout = xegpu::LayoutAttr::get(
-          context, {subgroupSize, 1}, {1, static_cast<int>(srcShape[1])});
-    }
-    break;
-  default:
-    llvm_unreachable("unsupported layout kind");
-  }
-  return requiredLayout;
+  const auto *uArchInstruction = dyn_cast<xegpu::uArch::StoreMatrixInstruction>(
+      uArch->getInstruction(xegpu::uArch::InstructionKind::StoreMatrix));
+  int maxChunkSize = uArchInstruction->getMaxLaneLoadStoreSize();
+
+  return setupGenericStoreAnchorLayout(layoutKind, context, false, maxChunkSize,
+                                       srcShape, subgroupSize);
 }
\ No newline at end of file
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 15547b33765fd..a445df492bf1e 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -308,27 +308,6 @@ static LayoutInfo getSIMTLayoutInfoBlockIO(Ty ty,
       ty.getContext(), {1, uArch->getSubgroupSize()}, {1, packingFactor}));
 }
 
-/// Helper to get the default layout for a vector type.
-static LayoutInfo getSIMTLayoutInfoScatterIO(VectorType vectorTy,
-                                             const xegpu::uArch::uArch *uArch) {
-  // Expecting a 1D or 2D vector.
-  assert((vectorTy.getRank() == 1 || vectorTy.getRank() == 2) &&
-         "Expected 1D or 2D vector.");
-  // Expecting int or float element type.
-  assert(vectorTy.getElementType().isIntOrFloat() &&
-         "Expected int or float element type.");
-  // If the rank is 1, then return default layout for 1D vector.
-  const unsigned packingSize{uArch->getGeneralPackedFormatBitSize()};
-  if (vectorTy.getRank() == 1)
-    return getDefaultSIMTLayoutInfo(vectorTy.getContext(), 1, uArch);
-  // Packing factor is determined by the element type bitwidth.
-  unsigned bitwidth = vectorTy.getElementType().getIntOrFloatBitWidth();
-  int packingFactor = bitwidth < packingSize ? packingSize / bitwidth : 1;
-  return LayoutInfo(xegpu::LayoutAttr::get(vectorTy.getContext(),
-                                           {uArch->getSubgroupSize(), 1},
-                                           {1, packingFactor}));
-}
-
 /// Helper Function to get the expected layouts for DPAS operands. `lane_data`
 /// is set according to the following criteria:
 /// * For A operand, the data must be packed in minimum

>From e1f27e3b33097214f60d913dbd8d4c3d23a5d0db Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Sat, 31 Jan 2026 01:15:52 +0000
Subject: [PATCH 27/35] add test for create_memdesc and alloca

---
 .../XeGPU/subgroup-distribute-unit.mlir       | 29 +++++++++++++++++++
 1 file changed, 29 insertions(+)

diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir
index 3a978f68ad9c0..4a6c81d1309c0 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir
@@ -483,7 +483,36 @@ gpu.func @memref_extract_aligned_pointer_as_index(%arg0 : memref<256x256xf16>, %
   gpu.return
 }
 
+// CHECK-LABEL: gpu.func @memref_alloca(
+// CHECK-NEXT:    %[[ALLOCA:.*]] = memref.alloca() : memref<2048xi8, 3>
+// CHECK-NEXT:    %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ALLOCA]] : memref<2048xi8, 3> -> index
+// CHECK-NEXT:    %[[CAST:.*]] = arith.index_cast %[[INTPTR]] : index to i64
+gpu.func @memref_alloca(%laneid: index) {
+  %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (memref<2048xi8, 3>) {
+    %alloca = memref.alloca() : memref<2048xi8, 3>
+    gpu.yield %alloca :  memref<2048xi8, 3>
+  }
+  %ptr = memref.extract_aligned_pointer_as_index %r : memref<2048xi8, 3> -> index
+  %ptr_i64 = arith.index_cast %ptr : index to i64
+  "some_user_op"(%ptr_i64) : (i64) -> ()
+  gpu.return
+}
 
+// CHECK-LABEL: gpu.func @create_memdesc(
+// CHECK:       %[[W:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (!xegpu.mem_desc<4x128xf32>, memref<2048xi8, 3>) {
+// CHECK:         gpu.yield %{{.*}}, %{{.*}} : !xegpu.mem_desc<4x128xf32>, memref<2048xi8, 3>
+// CHECK-NEXT:  }
+// CHECK-NEXT:  %[[MDesc:.*]] = xegpu.create_mem_desc %[[W]]#1 : memref<2048xi8, 3> -> !xegpu.mem_desc<4x128xf32>
+gpu.func @create_memdesc(%laneid: index, %arg0 : memref<2048xi8, 3>) {
+  %c0 = arith.constant 0 : index
+  %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (!xegpu.mem_desc<4x128xf32>) {
+    %mdesc = xegpu.create_mem_desc %arg0 : memref<2048xi8, 3> -> !xegpu.mem_desc<4x128xf32>
+    gpu.yield %mdesc :  !xegpu.mem_desc<4x128xf32>
+  }
+  %25 = xegpu.load_matrix  %r[%c0, %c0]: !xegpu.mem_desc<4x128xf32>, index, index -> vector<1x16xf32>
+  "some_user_op"(%25) : (vector<1x16xf32>) -> ()
+  gpu.return
+}
 
 // CHECK-LABEL: gpu.func @vector_transpose(
 // CHECK:       %[[W:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<2x1xf32>, vector<1x2xf32>) {

>From bea8f8994ba1f31dc4b90bde1607daf271af85ed Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Mon, 2 Feb 2026 19:26:43 +0000
Subject: [PATCH 28/35] remove alloca and create_memdesc pattern

---
 .../Transforms/XeGPUSubgroupDistribute.cpp    | 60 +------------------
 1 file changed, 1 insertion(+), 59 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index fe0042fc3827f..75875c60f1012 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -1896,63 +1896,6 @@ struct MemrefExtractAlignedPointerAsIndexDistribution final
   }
 };
 
-struct MemrefAllocaDistribution final : public gpu::WarpDistributionPattern {
-  using gpu::WarpDistributionPattern::WarpDistributionPattern;
-  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
-                                PatternRewriter &rewriter) const override {
-    OpOperand *operand = getWarpResult(warpOp, [&](Operation *op) {
-      // Check if the yield operand that was produced by the *last* scattered
-      // load op to avoid creating multiple copies due to multiple users.
-      return llvm::IsaPred<memref::AllocaOp>(op) &&
-             warpOp.getTerminator()->getPrevNode() == op;
-    });
-    if (!operand)
-      return rewriter.notifyMatchFailure(
-          warpOp, "warp result is not a memref::Alloca op");
-    auto allocaOp = operand->get().getDefiningOp<memref::AllocaOp>();
-    unsigned operandIdx = operand->getOperandNumber();
-    SmallVector<size_t> newRetIndices;
-    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-        rewriter, warpOp, ValueRange{}, TypeRange{}, newRetIndices);
-    rewriter.setInsertionPointAfter(newWarpOp);
-    auto newAllocaOp = memref::AllocaOp::create(rewriter, newWarpOp.getLoc(),
-                                                allocaOp.getType(), nullptr);
-    Value resultVal = newWarpOp.getResult(operandIdx);
-    rewriter.replaceAllUsesWith(resultVal, newAllocaOp.getResult());
-    return success();
-  }
-};
-
-struct CreateMemDescDistribution final : public gpu::WarpDistributionPattern {
-  using gpu::WarpDistributionPattern::WarpDistributionPattern;
-  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
-                                PatternRewriter &rewriter) const override {
-    OpOperand *operand = getWarpResult(warpOp, [&](Operation *op) {
-      // Check if the yield operand that was produced by the *last* scattered
-      // load op to avoid creating multiple copies due to multiple users.
-      return llvm::IsaPred<xegpu::CreateMemDescOp>(op) &&
-             warpOp.getTerminator()->getPrevNode() == op;
-    });
-    if (!operand)
-      return rewriter.notifyMatchFailure(
-          warpOp, "warp result is not a xegpu::CreateMemDesc op");
-    auto createMemDescOp =
-        operand->get().getDefiningOp<xegpu::CreateMemDescOp>();
-    unsigned operandIdx = operand->getOperandNumber();
-    SmallVector<size_t> newRetIndices;
-    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-        rewriter, warpOp, createMemDescOp.getSource(),
-        TypeRange{createMemDescOp.getSource().getType()}, newRetIndices);
-    rewriter.setInsertionPointAfter(newWarpOp);
-    auto newCreateMemDescOp = xegpu::CreateMemDescOp::create(
-        rewriter, newWarpOp.getLoc(), createMemDescOp.getType(),
-        newWarpOp.getResult(newRetIndices[0]));
-    Value resultVal = newWarpOp.getResult(operandIdx);
-    rewriter.replaceAllUsesWith(resultVal, newCreateMemDescOp.getResult());
-    return success();
-  }
-};
-
 /// Distribute a vector::BitCastOp feeding into yield op of an enclosing
 /// `gpu.warp_execute_on_lane_0` region. Bitcast only impacts the innermost
 /// diemension of the source/result vectors. Equivalent vector::BitCastOp is
@@ -2076,8 +2019,7 @@ void xegpu::populateXeGPUSubgroupDistributePatterns(
                LoadDistribution, StoreDistribution, VectorTransposeDistribution,
                VectorBitcastDistribution, LoadMatrixDistribution,
                StoreMatrixDistribution,
-               MemrefExtractAlignedPointerAsIndexDistribution,
-               MemrefAllocaDistribution, CreateMemDescDistribution>(
+               MemrefExtractAlignedPointerAsIndexDistribution>(
       patterns.getContext(),
       /*pattern benefit=*/PatternHierarchy::Regular);
   // For following patterns, we need to override the regular vector distribution

>From 9b96365ada303d1819100847aeb5902ca23007b8 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Mon, 2 Feb 2026 23:08:57 +0000
Subject: [PATCH 29/35] remove switch default

---
 .../XeGPU/Transforms/XeGPULayoutImpls.h       | 31 ++++++----
 .../XeGPU/Transforms/XeGPULayoutImpls.cpp     | 58 +++++++------------
 2 files changed, 39 insertions(+), 50 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpls.h b/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpls.h
index a4aebb3eb7cac..b7b7cf9dafd38 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpls.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpls.h
@@ -1,5 +1,4 @@
-//===- XeGPULayoutUtils.h - Layout Utilities --------------------------*- C++
-//-*-===//
+//===- XeGPULayoutImpls.h - Layout utility functions ------------*- C++ -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -7,8 +6,8 @@
 //
 //===----------------------------------------------------------------------===//
 
-#ifndef MLIR_DIALECT_XEGPU_UTILS_XEGPULAYOUTUTILS_H_
-#define MLIR_DIALECT_XEGPU_UTILS_XEGPULAYOUTUTILS_H_
+#ifndef MLIR_DIALECT_XEGPU_UTILS_XEGPULAYOUTIMPLS_H_
+#define MLIR_DIALECT_XEGPU_UTILS_XEGPULAYOUTIMPLS_H_
 
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
 #include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
@@ -132,28 +131,36 @@ DistributeLayoutAttr setupBitCastResultLayout(
     LayoutKind layoutKind, VectorType srcVectorTy, VectorType resVectorTy,
     DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch);
 
+/// Sets up the result layout for an insert strided slice operation.
+/// Creates a result layout based on the specified layout kind (InstData or
+/// Lane).
 DistributeLayoutAttr setupInsertStridedSliceResultLayout(
     LayoutKind layoutKind, VectorType resVectorTy,
     DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch);
 
+/// Sets up the anchor layout for a load gather operation.
 DistributeLayoutAttr
-setupLoadMatrixAnchorLayout(LayoutKind layoutKind, VectorType vectorTy,
-                            DistributeLayoutAttr consumerLayout,
+setupLoadGatherAnchorLayout(LayoutKind layoutKind, VectorType vectorTy,
+                            int chunkSize, DistributeLayoutAttr consumerLayout,
                             const uArch::uArch *uArch);
 
-DistributeLayoutAttr setupStoreMatrixAnchorLayout(LayoutKind layoutKind,
-                                                  VectorType vectorTy,
-                                                  const uArch::uArch *uArch);
-
+/// Sets up the anchor layout for a load matrix operation.
 DistributeLayoutAttr
-setupLoadGatherAnchorLayout(LayoutKind layoutKind, VectorType vectorTy,
-                            int chunkSize, DistributeLayoutAttr consumerLayout,
+setupLoadMatrixAnchorLayout(LayoutKind layoutKind, VectorType vectorTy,
+                            DistributeLayoutAttr consumerLayout,
                             const uArch::uArch *uArch);
 
+/// Sets up the anchor layout for a store scatter operation.
 DistributeLayoutAttr setupStoreScatterAnchorLayout(LayoutKind layoutKind,
                                                    VectorType vectorTy,
                                                    int chunkSize,
                                                    const uArch::uArch *uArch);
+
+/// Sets up the anchor layout for a store matrix operation.
+DistributeLayoutAttr setupStoreMatrixAnchorLayout(LayoutKind layoutKind,
+                                                  VectorType vectorTy,
+                                                  const uArch::uArch *uArch);
+
 } // namespace xegpu
 
 } // namespace mlir
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpls.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpls.cpp
index a9428be4cadeb..4732da6db6f7a 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpls.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpls.cpp
@@ -1,4 +1,5 @@
-//===---- XeGPUUtils.cpp - MLIR Utilities for XeGPUOps   ------------------===//
+//===---- XeGPULayoutImpls.cpp - MLIR Utilities for XeGPUOps
+//------------------===//
 //
 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -6,7 +7,8 @@
 //
 //===----------------------------------------------------------------------===//
 //
-// This file implements utility methods for working with the XeGPU dialect.
+// This file implements layout utility functions for XeGPU dialect
+// transformations.
 //
 //===----------------------------------------------------------------------===//
 
@@ -454,8 +456,7 @@ xegpu::SliceAttr xegpu::setupMultiReductionResultLayout(
 
   xegpu::DistributeLayoutAttr srcLayout;
 
-  switch (layoutKind) {
-  case xegpu::LayoutKind::Subgroup: {
+  if (layoutKind == xegpu::LayoutKind::Subgroup) {
     SmallVector<int64_t> sgLayout(srcRank), sgData(srcRank);
     SmallVector<int64_t> consumerSgLayout =
         consumerLayout.getEffectiveSgLayoutAsInt();
@@ -490,19 +491,17 @@ xegpu::SliceAttr xegpu::setupMultiReductionResultLayout(
     srcLayout =
         xegpu::LayoutAttr::get(context, toInt32Attr(sgLayout),
                                toInt32Attr(sgData), consumerLayout.getOrder());
-    break;
-  }
 
-  case xegpu::LayoutKind::InstData: {
+  } else if (layoutKind == xegpu::LayoutKind::InstData) {
+
     SmallVector<int64_t> instData(srcRank, 1);
     instData[srcRank - 2] =
         std::min(maxReduceVectorSize, srcShape[srcRank - 2]);
     instData[srcRank - 1] = subgroupSize;
     srcLayout = xegpu::LayoutAttr::get(context, toInt32Attr(instData));
-    break;
-  }
 
-  case xegpu::LayoutKind::Lane: {
+  } else if (layoutKind == xegpu::LayoutKind::Lane) {
+
     SmallVector<int64_t> laneLayout(srcRank, 1), laneData(srcRank, 1);
     laneLayout[srcRank - 1] = subgroupSize;
     laneData[srcRank - 2] =
@@ -510,11 +509,6 @@ xegpu::SliceAttr xegpu::setupMultiReductionResultLayout(
     srcLayout = xegpu::LayoutAttr::get(context, toInt32Attr(laneLayout),
                                        toInt32Attr(laneData),
                                        consumerLayout.getOrder());
-    break;
-  }
-
-  default:
-    llvm_unreachable("unsupported layout kind");
   }
 
   return xegpu::SliceAttr::get(context, srcLayout,
@@ -550,13 +544,11 @@ xegpu::DistributeLayoutAttr xegpu::setupBitCastResultLayout(
     // source layout.
     int bitWidthRatio = srcElemTyBitWidth / resElemTyBitWidth;
     int innermostDimLaneLayout = subgroupSize;
-    switch (layoutKind) {
-    case xegpu::LayoutKind::Subgroup:
+    if (layoutKind == xegpu::LayoutKind::Subgroup) {
       assert(sgData.size() == srcShape.size() &&
              "sgData must be available for all dimensions");
       sgDataValue = sgData[dim];
-      break;
-    case xegpu::LayoutKind::InstData:
+    } else if (layoutKind == xegpu::LayoutKind::InstData) {
       assert(instData.size() == srcShape.size() &&
              "instData must be available for all dimensions");
       instDataValue = instData[dim];
@@ -567,17 +559,13 @@ xegpu::DistributeLayoutAttr xegpu::setupBitCastResultLayout(
         instDataValue *= 2;
       assert((srcShape[dim] % instDataValue) == 0 &&
              "srcShape, instData, and lanelayout for innermost must be 2^n !");
-      break;
-    case xegpu::LayoutKind::Lane:
+    } else if (layoutKind == xegpu::LayoutKind::Lane) {
       assert(laneData.size() == srcShape.size() &&
              "laneData must be available for all dimensions");
       laneDataValue = laneData[dim];
       while ((laneDataValue <= srcShape[dim]) &&
              (laneDataValue % bitWidthRatio != 0))
         laneDataValue *= 2;
-      break;
-    default:
-      llvm_unreachable("unsupported layout kind");
     }
     // Now set only instData and laneData, preserving sgData
     xegpu::DistributeLayoutAttr resLayout;
@@ -589,13 +577,13 @@ xegpu::DistributeLayoutAttr xegpu::setupBitCastResultLayout(
 }
 
 /// Sets up the result layout for an insert strided slice operation.
-/// Creates a default layout based on the specified layout kind (InstData or
+/// Creates a result layout based on the specified layout kind (InstData or
 /// Lane).
 /// Subgroup layout is currently not supported for this operation.
-/// InstData layout requires {1, .., subgroupSize} by default.
-/// Lane layout requires {1, ..., subgroupSize} with lane data {1, ..., 1}.
-/// The instData and laneData is adjusted to contain packed data, by checking if
-/// the consumerLayout's innermost dimension.
+/// InstData layout is initially set to {1, ..., subgroupSize}.
+/// Lane layout is initially set to {1, ..., subgroupSize} with lane data
+/// {1, ..., 1}. The instData and laneData are then widened to hold packed
+/// data if the consumerLayout's innermost dimension already uses packing.
 xegpu::DistributeLayoutAttr xegpu::setupInsertStridedSliceResultLayout(
     xegpu::LayoutKind layoutKind, VectorType resVectorTy,
     xegpu::DistributeLayoutAttr consumerLayout,
@@ -620,26 +608,20 @@ xegpu::DistributeLayoutAttr xegpu::setupInsertStridedSliceResultLayout(
   int packingFactor = bitwidth < packingSize ? packingSize / bitwidth : 1;
   int packedDataSize = subgroupSize * packingFactor;
 
-  switch (layoutKind) {
-  case xegpu::LayoutKind::Subgroup:
+  if (layoutKind == xegpu::LayoutKind::Subgroup) {
     assert(true &&
            "subgroup layout assignment not supported for insertStridedSlice.");
-    break;
-  case xegpu::LayoutKind::InstData:
+  } else if (layoutKind == xegpu::LayoutKind::InstData) {
     instData[resShapeSize - 1] = subgroupSize;
     if (consumerInstData[resShapeSize - 1] == packedDataSize)
       instData[resShapeSize - 1] = packedDataSize;
     requiredResLayout = xegpu::LayoutAttr::get(context, instData);
-    break;
-  case xegpu::LayoutKind::Lane:
+  } else if (layoutKind == xegpu::LayoutKind::Lane) {
     laneLayout[resShapeSize - 1] = subgroupSize;
     laneData[resShapeSize - 1] = 1;
     if (consumerLaneData[resShapeSize - 1] == packingFactor)
       laneData[resShapeSize - 1] = packingFactor;
     requiredResLayout = xegpu::LayoutAttr::get(context, laneLayout, laneData);
-    break;
-  default:
-    llvm_unreachable("unsupported layout kind");
   }
   return requiredResLayout;
 }

>From f92c9629a4923a4c5d16decab39830e3e0f6df53 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Tue, 3 Feb 2026 00:03:28 +0000
Subject: [PATCH 30/35] address feedback

---
 .../XeGPU/Transforms/XeGPULayoutImpls.h       |  2 +-
 .../XeGPU/Transforms/XeGPULayoutImpls.cpp     | 46 +++++++++----------
 2 files changed, 22 insertions(+), 26 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpls.h b/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpls.h
index b7b7cf9dafd38..078a2d1b2ad0d 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpls.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpls.h
@@ -72,7 +72,7 @@ dropSgLayoutAndDataOnAttrs(ArrayRef<NamedAttribute> attrs);
 SmallVector<NamedAttribute> dropInstDataOnAttrs(ArrayRef<NamedAttribute> attrs);
 
 /// Infers the source layout attribute for a broadcast operation given the
-/// result layout attribute, result shape, source shape, and broadcasted dims.
+/// result layout attribute, result shape, and source shape.
 DistributeLayoutAttr inferBroadcastSourceLayout(DistributeLayoutAttr resLayout,
                                                 ArrayRef<int64_t> resShape,
                                                 ArrayRef<int64_t> srcShape);
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpls.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpls.cpp
index 4732da6db6f7a..213e3c8a8293c 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpls.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpls.cpp
@@ -120,18 +120,14 @@ xegpu::removeLayoutAttr<mlir::OpOperand>(const mlir::OpOperand &operand);
 
 void xegpu::removeLayoutAttrs(Operation *op) {
   op->walk([&](Operation *nestOp) {
-    for (OpOperand &opr : nestOp->getOpOperands())
-      removeLayoutAttr(opr);
-    for (OpResult result : nestOp->getOpResults())
-      removeLayoutAttr(result);
-    if (op->hasAttrOfType<DistributeLayoutAttr>("layout"))
-      op->removeAttr("layout");
-    if (op->hasAttrOfType<DistributeLayoutAttr>("layout_a"))
-      op->removeAttr("layout_a");
-    if (op->hasAttrOfType<DistributeLayoutAttr>("layout_b"))
-      op->removeAttr("layout_b");
-    if (op->hasAttrOfType<DistributeLayoutAttr>("layout_cd"))
-      op->removeAttr("layout_cd");
+    // Remove all attributes of DistributeLayoutAttr type
+    SmallVector<StringAttr> attrsToRemove;
+    for (auto namedAttr : nestOp->getAttrs()) {
+      if (isa<DistributeLayoutAttr>(namedAttr.getValue()))
+        attrsToRemove.push_back(namedAttr.getName());
+    }
+    for (auto attrName : attrsToRemove)
+      nestOp->removeAttr(attrName);
   });
 }
 
@@ -448,15 +444,15 @@ xegpu::SliceAttr xegpu::setupMultiReductionResultLayout(
       consumerSliceLayout ? consumerSliceLayout.flatten().getParent()
                           : consumerLayout;
 
-  auto sgLayoutVec = plainLayout.getEffectiveSgLayoutAsInt();
-  const int workgroupSize = std::accumulate(
-      sgLayoutVec.begin(), sgLayoutVec.end(), 1, std::multiplies<int64_t>());
   const int subgroupSize = uArch->getSubgroupSize();
   int64_t maxReduceVectorSize = 1; // could extend to spirv vector Size
 
   xegpu::DistributeLayoutAttr srcLayout;
 
   if (layoutKind == xegpu::LayoutKind::Subgroup) {
+    auto sgLayoutVec = plainLayout.getEffectiveSgLayoutAsInt();
+    const int workgroupSize = std::accumulate(
+        sgLayoutVec.begin(), sgLayoutVec.end(), 1, std::multiplies<int64_t>());
     SmallVector<int64_t> sgLayout(srcRank), sgData(srcRank);
     SmallVector<int64_t> consumerSgLayout =
         consumerLayout.getEffectiveSgLayoutAsInt();
@@ -631,17 +627,21 @@ xegpu::DistributeLayoutAttr xegpu::setupInsertStridedSliceResultLayout(
 /// same layout setup logic.
 /// For Subgroup layout, uses the consumer layout directly.
 /// non-chunked loads:
-///   InstData = {1, ..., min(consumer, maxLaneLoadStoreSize * subgroupSize)}
+///   InstData = {1, ..., min(consumer, maxLaneLoadSize * subgroupSize)}
 ///   LaneLayout = {1, ..., subgroupSize}
-///   lane_data = {1, ..., min(consumer, maxLaneLoadStoreSize)}
+///   lane_data = {1, ..., min(consumer, maxLaneLoadSize)}
 /// chunked loads:
-///   InstData = {subgroupSize, min(consumer, maxLaneLoadStoreSize)}
+///   InstData = {subgroupSize, min(consumer, maxLaneLoadSize)}
 ///   LaneLayout = {subgroupSize, 1}
-///   lane_data={1,min(consumer, maxLaneLoadStoreSize)}
+///   lane_data={1,min(consumer, maxLaneLoadSize)}
 static xegpu::DistributeLayoutAttr setupGenericLoadAnchorLayout(
     xegpu::LayoutKind layoutKind, mlir::MLIRContext *context,
     xegpu::DistributeLayoutAttr consumerLayout, bool isChunkedLoad,
     int maxChunkSize, int valShapeSize, int subgroupSize) {
+
+  if (layoutKind == xegpu::LayoutKind::Subgroup)
+    return consumerLayout;
+
   SmallVector<int64_t> consumerInstData =
       consumerLayout.getEffectiveInstDataAsInt();
   SmallVector<int64_t> consumerLaneData =
@@ -651,10 +651,6 @@ static xegpu::DistributeLayoutAttr setupGenericLoadAnchorLayout(
   SmallVector<int> laneLayout(valShapeSize, 1);
   SmallVector<int> laneData(valShapeSize, 1);
 
-  if (layoutKind == xegpu::LayoutKind::Subgroup) {
-    return consumerLayout;
-  }
-
   if (!isChunkedLoad) {
     if (layoutKind == xegpu::LayoutKind::InstData) {
       instData[valShapeSize - 1] =
@@ -731,9 +727,9 @@ xegpu::setupLoadMatrixAnchorLayout(xegpu::LayoutKind layoutKind,
 ///   LaneLayout = {1, ..., subgroupSize}
 ///   lane_data = {1, ..., 1}
 /// chunked stores:
-///   InstData = {subgroupSize, min(srcVec, maxLaneLoadStoreSize)}
+///   InstData = {subgroupSize, min(srcVec, maxLaneStoreSize)}
 ///   LaneLayout = {subgroupSize, 1}
-///   lane_data={1,min(srcVec, maxLaneLoadStoreSize)}
+///   lane_data={1,min(srcVec, maxLaneStoreSize)}
 static xegpu::DistributeLayoutAttr
 setupGenericStoreAnchorLayout(xegpu::LayoutKind layoutKind,
                               mlir::MLIRContext *context, bool isChunkedStore,

>From e9a38ecd5a1b25c74ceefeba319e6931c46a9136 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Tue, 3 Feb 2026 00:24:49 +0000
Subject: [PATCH 31/35] add load/store instruction interface

---
 .../mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h    | 44 +++++--------------
 .../mlir/Dialect/XeGPU/uArch/uArchBase.h      | 44 +++++++++++++++++++
 .../XeGPU/Transforms/XeGPULayoutImpls.cpp     | 14 +++---
 3 files changed, 64 insertions(+), 38 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
index 0f138b9defb5c..92b47b9c7d448 100644
--- a/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
@@ -215,48 +215,28 @@ struct SubgroupMatrixMultiplyAcc : public Instruction,
   const unsigned packedFormatBitSizeB;
 };
 
-struct StoreScatterInstruction : public Instruction {
-  StoreScatterInstruction()
-      : Instruction(InstructionKind::StoreScatter, InstructionScope::Lane) {}
-  static bool classof(const Instruction *B) {
-    return B->getInstructionKind() == InstructionKind::StoreScatter;
+struct LoadGatherInstruction : public LoadGatherInstructionInterface {
+  int32_t getMaxLaneLoadSize(int32_t bitWidth) const override {
+    return 16; // SPIRV restricts vector size
   }
-
-  // SPIRV restricts vector size
-  int32_t getMaxLaneLoadStoreSize() const { return 16; }
 };
 
-struct LoadGatherInstruction : public Instruction {
-  LoadGatherInstruction()
-      : Instruction(InstructionKind::LoadGather, InstructionScope::Lane) {}
-  static bool classof(const Instruction *B) {
-    return B->getInstructionKind() == InstructionKind::LoadGather;
+struct StoreScatterInstruction : public StoreScatterInstructionInterface {
+  int32_t getMaxLaneStoreSize(int32_t bitWidth) const override {
+    return 16; // SPIRV restricts vector size
   }
-
-  // SPIRV restricts vector size
-  int32_t getMaxLaneLoadStoreSize() const { return 16; }
 };
 
-struct StoreMatrixInstruction : public Instruction {
-  StoreMatrixInstruction()
-      : Instruction(InstructionKind::StoreMatrix, InstructionScope::Lane) {}
-  static bool classof(const Instruction *B) {
-    return B->getInstructionKind() == InstructionKind::StoreMatrix;
+struct LoadMatrixInstruction : public LoadMatrixInstructionInterface {
+  int32_t getMaxLaneLoadSize(int32_t bitWidth) const override {
+    return 16; // SPIRV restricts vector size
   }
-
-  // SPIRV restricts vector size
-  int32_t getMaxLaneLoadStoreSize() const { return 16; }
 };
 
-struct LoadMatrixInstruction : public Instruction {
-  LoadMatrixInstruction()
-      : Instruction(InstructionKind::LoadMatrix, InstructionScope::Lane) {}
-  static bool classof(const Instruction *B) {
-    return B->getInstructionKind() == InstructionKind::LoadMatrix;
+struct StoreMatrixInstruction : public StoreMatrixInstructionInterface {
+  int32_t getMaxLaneStoreSize(int32_t bitWidth) const override {
+    return 16; // SPIRV restricts vector size
   }
-
-  // SPIRV restricts vector size
-  int32_t getMaxLaneLoadStoreSize() const { return 16; }
 };
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
index 75a483baa116d..f927c75d3bbe5 100644
--- a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
@@ -256,6 +256,50 @@ struct MMAInstructionInterface {
   virtual ~MMAInstructionInterface() = default;
 };
 
+struct LoadGatherInstructionInterface : public Instruction {
+  LoadGatherInstructionInterface()
+      : Instruction(InstructionKind::LoadGather, InstructionScope::Lane) {}
+  static bool classof(const Instruction *B) {
+    return B->getInstructionKind() == InstructionKind::LoadGather;
+  }
+
+  virtual int32_t getMaxLaneLoadSize(int32_t bitWidth) const = 0;
+  virtual ~LoadGatherInstructionInterface() = default;
+};
+
+struct StoreScatterInstructionInterface : public Instruction {
+  StoreScatterInstructionInterface()
+      : Instruction(InstructionKind::StoreScatter, InstructionScope::Lane) {}
+  static bool classof(const Instruction *B) {
+    return B->getInstructionKind() == InstructionKind::StoreScatter;
+  }
+
+  virtual int32_t getMaxLaneStoreSize(int32_t bitWidth) const = 0;
+  virtual ~StoreScatterInstructionInterface() = default;
+};
+
+struct LoadMatrixInstructionInterface : public Instruction {
+  LoadMatrixInstructionInterface()
+      : Instruction(InstructionKind::LoadMatrix, InstructionScope::Lane) {}
+  static bool classof(const Instruction *B) {
+    return B->getInstructionKind() == InstructionKind::LoadMatrix;
+  }
+
+  virtual int32_t getMaxLaneLoadSize(int32_t bitWidth) const = 0;
+  virtual ~LoadMatrixInstructionInterface() = default;
+};
+
+struct StoreMatrixInstructionInterface : public Instruction {
+  StoreMatrixInstructionInterface()
+      : Instruction(InstructionKind::StoreMatrix, InstructionScope::Lane) {}
+  static bool classof(const Instruction *B) {
+    return B->getInstructionKind() == InstructionKind::StoreMatrix;
+  }
+
+  virtual int32_t getMaxLaneStoreSize(int32_t bitWidth) const = 0;
+  virtual ~StoreMatrixInstructionInterface() = default;
+};
+
 } // namespace uArch
 } // namespace xegpu
 } // namespace mlir
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpls.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpls.cpp
index 213e3c8a8293c..c3845abbc81b9 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpls.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpls.cpp
@@ -688,10 +688,11 @@ xegpu::DistributeLayoutAttr xegpu::setupLoadGatherAnchorLayout(
   const int subgroupSize = uArch->getSubgroupSize();
   int resShapeSize = resVecTy.getShape().size();
   auto context = resVecTy.getContext();
+  auto elemBitWidth = resVecTy.getElementType().getIntOrFloatBitWidth();
 
   const auto *uArchInstruction = dyn_cast<xegpu::uArch::LoadGatherInstruction>(
       uArch->getInstruction(xegpu::uArch::InstructionKind::LoadGather));
-  int maxChunkSize = uArchInstruction->getMaxLaneLoadStoreSize();
+  int maxChunkSize = uArchInstruction->getMaxLaneLoadSize(elemBitWidth);
 
   return setupGenericLoadAnchorLayout(layoutKind, context, consumerLayout,
                                       (chunkSize > 1), maxChunkSize,
@@ -709,11 +710,11 @@ xegpu::setupLoadMatrixAnchorLayout(xegpu::LayoutKind layoutKind,
   const int subgroupSize = uArch->getSubgroupSize();
   int resShapeSize = resVecTy.getShape().size();
   auto context = resVecTy.getContext();
+  auto elemBitWidth = resVecTy.getElementType().getIntOrFloatBitWidth();
 
   const auto *uArchInstruction = dyn_cast<xegpu::uArch::LoadMatrixInstruction>(
       uArch->getInstruction(xegpu::uArch::InstructionKind::LoadMatrix));
-  int maxChunkSize = uArchInstruction->getMaxLaneLoadStoreSize();
-
+  int maxChunkSize = uArchInstruction->getMaxLaneLoadSize(elemBitWidth);
   return setupGenericLoadAnchorLayout(layoutKind, context, consumerLayout,
                                       false, maxChunkSize, resShapeSize,
                                       subgroupSize);
@@ -779,12 +780,12 @@ xegpu::setupStoreScatterAnchorLayout(xegpu::LayoutKind layoutKind,
   const int subgroupSize = uArch->getSubgroupSize();
   ArrayRef<int64_t> srcShape = srcVecTy.getShape();
   auto context = srcVecTy.getContext();
+  auto elemBitWidth = srcVecTy.getElementType().getIntOrFloatBitWidth();
 
   const auto *uArchInstruction =
       dyn_cast<xegpu::uArch::StoreScatterInstruction>(
           uArch->getInstruction(xegpu::uArch::InstructionKind::StoreScatter));
-  int maxChunkSize = uArchInstruction->getMaxLaneLoadStoreSize();
-
+  int maxChunkSize = uArchInstruction->getMaxLaneStoreSize(elemBitWidth);
   return setupGenericStoreAnchorLayout(layoutKind, context, (chunkSize > 1),
                                        maxChunkSize, srcShape, subgroupSize);
 }
@@ -798,10 +799,11 @@ xegpu::setupStoreMatrixAnchorLayout(xegpu::LayoutKind layoutKind,
   const int subgroupSize = uArch->getSubgroupSize();
   ArrayRef<int64_t> srcShape = srcVecTy.getShape();
   auto context = srcVecTy.getContext();
+  auto elemBitWidth = srcVecTy.getElementType().getIntOrFloatBitWidth();
 
   const auto *uArchInstruction = dyn_cast<xegpu::uArch::StoreMatrixInstruction>(
       uArch->getInstruction(xegpu::uArch::InstructionKind::StoreMatrix));
-  int maxChunkSize = uArchInstruction->getMaxLaneLoadStoreSize();
+  int maxChunkSize = uArchInstruction->getMaxLaneStoreSize(elemBitWidth);
 
   return setupGenericStoreAnchorLayout(layoutKind, context, false, maxChunkSize,
                                        srcShape, subgroupSize);

>From e2fe8b892bf623d03d164df0f66094d960d6f51e Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Tue, 3 Feb 2026 18:30:07 +0000
Subject: [PATCH 32/35] address feedback: add check for collapseDims

---
 .../mlir/Dialect/XeGPU/IR/XeGPUAttrs.td       |  8 ++++----
 mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp    | 19 +++++++++++++++++--
 .../XeGPU/Transforms/XeGPULayoutImpls.cpp     |  3 +--
 3 files changed, 22 insertions(+), 8 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index 62ce7623ea14f..20dc8dd367203 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -246,8 +246,8 @@ def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> {
                           "int64_t": $instData,
                           "int64_t": $laneData)>,              
     InterfaceMethod<[{Derive a new layout by collapsing groups of dimensions. Each inner array in
-                      `dimGroups` specifies a group of dimensions that are collapsed into a single
-                      dimension in the derived layout.}],
+                      `dimGroups` specifies a group of adjacent dimensions that are collapsed into
+                       a single dimension in the derived layout.}],
                     "xegpu::DistributeLayoutAttr",
                     "collapseDims",
                     (ins "SmallVector<SmallVector<int64_t>>": $dimGroups)>,                                                   
@@ -527,7 +527,7 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttr]> {
     DistributeLayoutAttr setDimData(int64_t dim, int64_t sgData, int64_t instData, int64_t laneData);
 
     // Derive a new layout by collapsing groups of dimensions.
-    // Each inner array in `dimGroups` specifies a set of dimensions
+    // Each inner array in `dimGroups` specifies a set of adjacent dimensions
     // that are collapsed into a single dimension in the derived layout.
     DistributeLayoutAttr collapseDims(SmallVector<SmallVector<int64_t>> dimGroups) const;
 
@@ -708,7 +708,7 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> {
     DistributeLayoutAttr setDimData(int64_t dim, int64_t sgData, int64_t instData, int64_t laneData);
 
     // Derive a new layout by collapsing groups of dimensions.
-    // Each inner array in `dimGroups` specifies a set of dimensions
+    // Each inner array in `dimGroups` specifies a set of adjacent dimensions
     // that are collapsed into a single dimension in the derived layout.
     DistributeLayoutAttr collapseDims(SmallVector<SmallVector<int64_t>> dimGroups) const;
 
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 7df4af66e7b01..b36149baadec9 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -502,7 +502,7 @@ DistributeLayoutAttr LayoutAttr::setDimData(int64_t dim, int64_t sgData,
 }
 
 // Derive a new layout by collapsing groups of dimensions.
-// Each inner array in `dimGroups` specifies a set of dimensions
+// Each inner array in `dimGroups` specifies a set of adjacent dimensions
 // that are collapsed into a single dimension in the derived layout.
 DistributeLayoutAttr
 LayoutAttr::collapseDims(SmallVector<SmallVector<int64_t>> dimGroups) const {
@@ -527,6 +527,7 @@ LayoutAttr::collapseDims(SmallVector<SmallVector<int64_t>> dimGroups) const {
   SmallVector<int64_t> collapsedLaneLayout;
   SmallVector<int64_t> collapsedLaneData;
   SmallVector<int64_t> collapsedOrder;
+  SetVector<int64_t> coveredDims;
 
   for (const auto &group : dimGroups) {
 
@@ -534,8 +535,17 @@ LayoutAttr::collapseDims(SmallVector<SmallVector<int64_t>> dimGroups) const {
     int64_t collapsedSg = 1, collapsedSgD = 1, collapsedInst = 1;
     int64_t collapsedLaneL = 1, collapsedLaneD = 1;
     int64_t collapsedOrderValue = -1;
-
+    int64_t dimBeforeCurrent = group.front() - 1;
     for (int64_t dimIdx : group) {
+      // no two groups can cover the same dimension
+      if (!coveredDims.insert(dimIdx))
+        llvm::report_fatal_error(Twine("dimension ") + Twine(dimIdx) +
+                                 " is covered more than once");
+      // Dims within a group must be adjacent and in increasing order.
+      if (dimBeforeCurrent != (dimIdx - 1))
+        llvm::report_fatal_error("dimensions being collapsed must be adjacent");
+      dimBeforeCurrent = dimIdx;
+
       collapsedSg *= sgLayout[dimIdx];
       collapsedSgD *= sgData[dimIdx];
       collapsedInst *= instData[dimIdx];
@@ -553,6 +563,11 @@ LayoutAttr::collapseDims(SmallVector<SmallVector<int64_t>> dimGroups) const {
     collapsedOrder.push_back(collapsedOrderValue);
   }
 
+  // Check that every dimension is covered by some group.
+  if (coveredDims.size() != sgLayout.size())
+    llvm::report_fatal_error(
+        "not all dimensions are covered in collapseGroups");
+
   // Create collapsed layout
   SmallVector<int32_t> collapsedSgLayout32(collapsedSgLayout.begin(),
                                            collapsedSgLayout.end());
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpls.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpls.cpp
index c3845abbc81b9..201075b900fdb 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpls.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpls.cpp
@@ -225,7 +225,6 @@ xegpu::inferBitCastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
     }
   }
 
-  // Now set only instData and laneData, preserving sgData
   xegpu::DistributeLayoutAttr finalSrcLayout;
   finalSrcLayout =
       resLayout.setDimData(dim, sgDataValue, instDataValue, laneDataValue);
@@ -327,11 +326,11 @@ xegpu::inferShapeCastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
                             ArrayRef<int64_t> dst) -> bool {
     // each dim in src can be mapped to one or more dims in dst whose product
     // equals to the src dim
-    splitDimGroups.clear();
     size_t srcIdx = 0;
     int64_t accumulatedSize = 1;
     SmallVector<int64_t> currentDstDims;
 
+    splitDimGroups.clear();
     for (size_t dstIdx = 0; dstIdx < dst.size(); ++dstIdx) {
       if (srcIdx >= src.size())
         return false;

>From 3b0bf7a3782200b811fba713d0713993a025f03f Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Tue, 3 Feb 2026 19:42:29 +0000
Subject: [PATCH 33/35] address feedback: move Dim expansion checkers to
 utility functions

---
 .../mlir/Dialect/XeGPU/IR/XeGPUAttrs.td       | 12 ++--
 .../mlir/Dialect/XeGPU/Utils/XeGPUUtils.h     |  9 +++
 mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp    | 32 +++++------
 .../XeGPU/Transforms/XeGPULayoutImpls.cpp     | 55 ++-----------------
 .../Transforms/XeGPUSubgroupDistribute.cpp    |  5 +-
 .../Transforms/XeGPUWgToSgDistribute.cpp      | 23 +-------
 mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp   | 55 +++++++++++++++++++
 7 files changed, 97 insertions(+), 94 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index 20dc8dd367203..1251631955580 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -226,11 +226,11 @@ def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> {
     InterfaceMethod<"Derive a new layout with sg_data, inst_data and lane_data set to 1 for the specified unit dims",
                     "xegpu::DistributeLayoutAttr",
                     "setUnitDimData",
-                    /*args=*/(ins "const llvm::SetVector<int64_t>": $unitDims)>,
+                    /*args=*/(ins "const SmallVector<int64_t>": $unitDims)>,
     InterfaceMethod<"Derive a new layout with sg_lane and lane_layout set to 1 for the specified unit dims",
                     "xegpu::DistributeLayoutAttr",
                     "setUnitDimLayout",
-                    /*args=*/(ins "const llvm::SetVector<int64_t>": $unitDims)>,
+                    /*args=*/(ins "const SmallVector<int64_t>": $unitDims)>,
     InterfaceMethod<[{Delinearizes a linear ID into its multidimensional
                       indices based on the effective layout level.}],
                     "FailureOr<SmallVector<Value>>",
@@ -516,10 +516,10 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttr]> {
     }
 
     //set the layout for the sepcified unit dims: sg_data, inst_data and lane_data to 1
-    DistributeLayoutAttr setUnitDimData(SetVector<int64_t> unitDims) const;
+    DistributeLayoutAttr setUnitDimData(SmallVector<int64_t> unitDims) const;
 
     //set the layout for the sepcified unit dims: sg_lane and lane_layout to 1
-    DistributeLayoutAttr setUnitDimLayout(SetVector<int64_t> unitDims) const;
+    DistributeLayoutAttr setUnitDimLayout(SmallVector<int64_t> unitDims) const;
 
     // Derive a new layout with sg_data, inst_data and lane_data set to the 
     // specified values for the given dimension. Passing -1 for any parameter 
@@ -697,10 +697,10 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> {
     }
 
     //set the layout for the sepcified unit dims: sg_data, inst_data and lane_data to 1
-    DistributeLayoutAttr setUnitDimData(SetVector<int64_t> unitDims) const;
+    DistributeLayoutAttr setUnitDimData(SmallVector<int64_t> unitDims) const;
 
     //set the layout for the sepcified unit dims: sg_lane and lane_layout to 1
-    DistributeLayoutAttr setUnitDimLayout(SetVector<int64_t> unitDims) const;
+    DistributeLayoutAttr setUnitDimLayout(SmallVector<int64_t> unitDims) const;
 
     // Derive a new layout with sg_data, inst_data and lane_data set to the 
     // specified values for the given dimension. Passing -1 for any parameter 
diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
index d1fee2126daee..4443f86d1e4e2 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
@@ -186,6 +186,15 @@ bool requirePacked(const LayoutAttr layout);
 /// Helper function to check if the layout requires a transpose effect.
 bool requireTranspose(const LayoutAttr layout, const uArch::uArch *uArch);
 
+/// Checks if dst shape is an expansion of src shape by inserting unit dims.
+bool matchUnitDimExpansion(ArrayRef<int64_t> src, ArrayRef<int64_t> dst,
+                           SmallVector<int64_t> &expandedUnitDims);
+
+/// Checks if dst shape is an expansion of src shape where each dimension in
+/// src is split into one or more consecutive dimensions in dst.
+bool matchSplitDimExpansion(ArrayRef<int64_t> src, ArrayRef<int64_t> dst,
+                            SmallVector<SmallVector<int64_t>> &splitDimGroups);
+
 } // namespace xegpu
 
 } // namespace mlir
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index b36149baadec9..dcf128c94d20e 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -398,7 +398,7 @@ bool LayoutAttr::isEqualTo(const xegpu::DistributeLayoutAttr &other) {
 
 // set the layout for unit dims: sg_data, inst_data and lane_data to 1
 DistributeLayoutAttr
-LayoutAttr::setUnitDimData(SetVector<int64_t> unitDims) const {
+LayoutAttr::setUnitDimData(SmallVector<int64_t> unitDims) const {
   auto sgDataOpt = getSgData();
   auto instDataOpt = getInstData();
   auto laneDataOpt = getLaneData();
@@ -439,7 +439,7 @@ LayoutAttr::setUnitDimData(SetVector<int64_t> unitDims) const {
 
 // set the layout for the sepcified unit dims: sg_lane and lane_layout to 1
 DistributeLayoutAttr
-LayoutAttr::setUnitDimLayout(SetVector<int64_t> unitDims) const {
+LayoutAttr::setUnitDimLayout(SmallVector<int64_t> unitDims) const {
   auto sgLayoutOpt = getSgLayout();
   auto laneLayoutOpt = getLaneLayout();
 
@@ -767,8 +767,8 @@ bool SliceAttr::isEqualTo(const xegpu::DistributeLayoutAttr &other) {
 // shape is of rank 2, if we want to set unit dim [0] in sliced space, it maps
 // to dim [0] in parent space; if we want to set unit dim [1] in sliced space,
 // it maps to dim [2] in parent space.
-static SetVector<int64_t>
-mapSlicedDimsToParentSpace(const SetVector<int64_t> &dimsToMap,
+static SmallVector<int64_t>
+mapSlicedDimsToParentSpace(const SmallVector<int64_t> &dimsToMap,
                            ArrayRef<int64_t> sliceDims) {
   // Rather than recovering the exact parent rank, we compute a safe upper bound
   // so that dimsToMap can be adjusted safely. This upper bound is defined as
@@ -791,10 +791,10 @@ mapSlicedDimsToParentSpace(const SetVector<int64_t> &dimsToMap,
   }
 
   // Map unit dims from sliced space to parent space
-  SetVector<int64_t> adjustUnitDims;
+  SmallVector<int64_t> adjustUnitDims;
   for (auto dim : dimsToMap) {
     int64_t mappedDim = remainingDims[dim];
-    adjustUnitDims.insert(mappedDim);
+    adjustUnitDims.push_back(mappedDim);
   }
 
   return adjustUnitDims;
@@ -802,12 +802,12 @@ mapSlicedDimsToParentSpace(const SetVector<int64_t> &dimsToMap,
 
 // set the layout for unit dims: sg_data, inst_data and lane_data to 1
 DistributeLayoutAttr
-SliceAttr::setUnitDimData(SetVector<int64_t> unitDims) const {
+SliceAttr::setUnitDimData(SmallVector<int64_t> unitDims) const {
   DistributeLayoutAttr parentLayout = getParent();
 
   ArrayRef<int64_t> sliceDims = getDims().asArrayRef();
 
-  SetVector<int64_t> adjustUnitDims =
+  SmallVector<int64_t> adjustUnitDims =
       mapSlicedDimsToParentSpace(unitDims, sliceDims);
 
   return SliceAttr::get(getContext(),
@@ -816,12 +816,12 @@ SliceAttr::setUnitDimData(SetVector<int64_t> unitDims) const {
 
 // set the layout for the sepcified unit dims: sg_lane and lane_layout to 1
 DistributeLayoutAttr
-SliceAttr::setUnitDimLayout(SetVector<int64_t> unitDims) const {
+SliceAttr::setUnitDimLayout(SmallVector<int64_t> unitDims) const {
   DistributeLayoutAttr parentLayout = getParent();
 
   ArrayRef<int64_t> sliceDims = getDims().asArrayRef();
 
-  SetVector<int64_t> adjustUnitDims =
+  SmallVector<int64_t> adjustUnitDims =
       mapSlicedDimsToParentSpace(unitDims, sliceDims);
 
   return SliceAttr::get(
@@ -835,10 +835,10 @@ DistributeLayoutAttr SliceAttr::setDimData(int64_t dim, int64_t sgData,
   ArrayRef<int64_t> sliceDims = getDims().asArrayRef();
   auto parent = dyn_cast<LayoutAttr>(getParent());
 
-  SetVector<int64_t> dimSet;
-  dimSet.insert(dim);
-  SetVector<int64_t> adjustDims = mapSlicedDimsToParentSpace(dimSet, sliceDims);
-
+  SmallVector<int64_t> dimSet;
+  dimSet.push_back(dim);
+  SmallVector<int64_t> adjustDims =
+      mapSlicedDimsToParentSpace(dimSet, sliceDims);
   return SliceAttr::get(
       getContext(),
       parent.setDimData(adjustDims[0], sgData, instData, laneData), getDims());
@@ -856,8 +856,8 @@ SliceAttr::collapseDims(SmallVector<SmallVector<int64_t>> dimGroups) const {
   // go through dimGroups and map each dim from sliced space to parent space
   SmallVector<SmallVector<int64_t>> adjustedDimGroups;
   for (const auto &group : dimGroups) {
-    SetVector<int64_t> groupSet(group.begin(), group.end());
-    SetVector<int64_t> mappedDims =
+    SmallVector<int64_t> groupSet(group.begin(), group.end());
+    SmallVector<int64_t> mappedDims =
         mapSlicedDimsToParentSpace(groupSet, sliceDims);
     adjustedDimGroups.push_back(
         SmallVector<int64_t>(mappedDims.begin(), mappedDims.end()));
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpls.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpls.cpp
index 201075b900fdb..407ee36d16769 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpls.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpls.cpp
@@ -294,22 +294,8 @@ xegpu::inferShapeCastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
   // Use case 1: Check if shapes only differ by expanding unit dimensions (like
   // expand_dims)
   SmallVector<int64_t> expandedUnitDims;
-  auto checkOnlyExpandUnitDims = [&](ArrayRef<int64_t> src,
-                                     ArrayRef<int64_t> dst) -> bool {
-    // All unit dimensions in dst that don't appear in src are the expanded
-    // unit dimensions
-    size_t srcIdx = 0;
-    for (size_t dstIdx = 0; dstIdx < dst.size(); ++dstIdx)
-      if (srcIdx < src.size() && src[srcIdx] == dst[dstIdx])
-        srcIdx++;
-      else if (dst[dstIdx] == 1)
-        expandedUnitDims.push_back(dstIdx);
-      else
-        return false;
-    return srcIdx == src.size();
-  };
 
-  if (checkOnlyExpandUnitDims(srcShape, resShape)) {
+  if (xegpu::matchUnitDimExpansion(srcShape, resShape, expandedUnitDims)) {
     // create a slice layout for the source by removing the expanded unit dims
     auto sliceDimsAttr = DenseI64ArrayAttr::get(
         resLayout.getContext(), ArrayRef<int64_t>(expandedUnitDims));
@@ -321,42 +307,11 @@ xegpu::inferShapeCastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
   // Maps each source dimension to the range of destination dimensions it splits
   // into
   SmallVector<SmallVector<int64_t>> splitDimGroups;
-
-  auto checkSplitDims = [&](ArrayRef<int64_t> src,
-                            ArrayRef<int64_t> dst) -> bool {
-    // each dim in src can be mapped to one or more dims in dst whose product
-    // equals to the src dim
-    size_t srcIdx = 0;
-    int64_t accumulatedSize = 1;
-    SmallVector<int64_t> currentDstDims;
-
-    splitDimGroups.clear();
-    for (size_t dstIdx = 0; dstIdx < dst.size(); ++dstIdx) {
-      if (srcIdx >= src.size())
-        return false;
-      accumulatedSize *= dst[dstIdx];
-      currentDstDims.push_back(dstIdx);
-
-      if (accumulatedSize == src[srcIdx]) {
-        // Record the mapping: srcIdx -> currentDstDims
-        splitDimGroups.push_back(currentDstDims);
-        // move to next src dim
-        srcIdx++;
-        accumulatedSize = 1;
-        currentDstDims.clear();
-      } else if (accumulatedSize > src[srcIdx]) {
-        return false;
-      }
-    }
-    return srcIdx == src.size();
-  };
-
-  if (checkSplitDims(srcShape, resShape)) {
+  if (xegpu::matchSplitDimExpansion(srcShape, resShape, splitDimGroups))
     return resLayout.collapseDims(splitDimGroups);
-  }
 
-  auto checkCombineToInnerMostDim = [&](ArrayRef<int64_t> src,
-                                        ArrayRef<int64_t> dst) -> bool {
+  auto matchCollapseToInnermostDim = [&](ArrayRef<int64_t> src,
+                                         ArrayRef<int64_t> dst) -> bool {
     // only one non-unit dim in dst which is the innermost dim
     if ((dst.size() != 2) && (dst.size() != 1))
       return false;
@@ -367,7 +322,7 @@ xegpu::inferShapeCastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
     return (dst[0] == 1) && (dst[1] == srcSize);
   };
 
-  if (checkCombineToInnerMostDim(srcShape, resShape)) {
+  if (matchCollapseToInnermostDim(srcShape, resShape)) {
     int srcShapeSize = srcShape.size();
     int resShapeSize = resShape.size();
     auto context = resLayout.getContext();
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 75875c60f1012..4b762fb363fca 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -1533,8 +1533,9 @@ struct VectorBroadcastDistribution : public gpu::WarpDistributionPattern {
       }
       // case 2: source and result have same rank
       if (rankDiff == 0) {
-        SetVector<int64_t> broadcastUnitDims =
-            broadcastOp.computeBroadcastedUnitDims();
+        auto broadcastUnitDimsSet = broadcastOp.computeBroadcastedUnitDims();
+        SmallVector<int64_t> broadcastUnitDims(broadcastUnitDimsSet.begin(),
+                                               broadcastUnitDimsSet.end());
         bool isEqualTo = sourceLayout.isEqualTo(resultLayout);
         if (!isEqualTo)
           return rewriter.notifyMatchFailure(
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index c0487cd709c3d..a90c218bdaae3 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -1114,27 +1114,10 @@ struct WgToSgVectorShapeCastOp
       return failure();
 
     ArrayRef<int64_t> srcShape = srcType.getShape();
-    llvm::SetVector<int64_t> expandedUnitDims;
-
-    // Check if shapes only differ by expanding unit dimensions (like
-    // expand_dims)
-    auto checkOnlyExpandUnitDims = [&](ArrayRef<int64_t> src,
-                                       ArrayRef<int64_t> dst) -> bool {
-      // All unit dimensions in dst that don't appear in src are the expanded
-      // unit dimensions
-      size_t srcIdx = 0;
-      for (size_t dstIdx = 0; dstIdx < dst.size(); ++dstIdx)
-        if (srcIdx < src.size() && src[srcIdx] == dst[dstIdx])
-          srcIdx++;
-        else if (dst[dstIdx] == 1)
-          expandedUnitDims.insert(dstIdx);
-        else
-          return false;
-      return srcIdx == src.size();
-    };
-    xegpu::DistributeLayoutAttr layoutToDistribute = layout;
 
-    if (checkOnlyExpandUnitDims(srcShape, wgShape)) {
+    xegpu::DistributeLayoutAttr layoutToDistribute = layout;
+    SmallVector<int64_t> expandedUnitDims;
+    if (xegpu::matchUnitDimExpansion(srcShape, wgShape, expandedUnitDims)) {
       xegpu::DistributeLayoutAttr sourceLayout =
           xegpu::getTemporaryLayout(op->getOpOperand(0));
 
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index 259c5b3fa89c8..13d5a5fd54023 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -681,3 +681,58 @@ bool xegpu::requireTranspose(const xegpu::LayoutAttr layout,
     return false;
   return laneLayout[0] == uArch->getSubgroupSize() && laneLayout[1] == 1;
 }
+
+// Check if dst shape is an expansion of src shape by inserting unit dimensions.
+// Returns true if all dimensions in src match corresponding dimensions in dst
+// (after skipping unit dimensions), and populates expandedUnitDims with the
+// indices of the unit dimensions in dst that were added (not present in src).
+// Example: src=[2,3], dst=[1,2,3,1] -> true, expandedUnitDims=[0,3]
+bool xegpu::matchUnitDimExpansion(ArrayRef<int64_t> src, ArrayRef<int64_t> dst,
+                                  SmallVector<int64_t> &expandedUnitDims) {
+  // Unit dims in dst that don't match the next src dim are the expanded ones.
+  expandedUnitDims.clear();
+  size_t srcIdx = 0;
+  for (size_t dstIdx = 0; dstIdx < dst.size(); ++dstIdx)
+    if (srcIdx < src.size() && src[srcIdx] == dst[dstIdx])
+      srcIdx++;
+    else if (dst[dstIdx] == 1)
+      expandedUnitDims.push_back(dstIdx);
+    else
+      return false;
+  return srcIdx == src.size();
+}
+
+// Checks if dst shape is an expansion of src shape where each dimension in src
+// is split into one or more consecutive dimensions in dst whose product equals
+// the original dimension. Populates splitDimGroups with groups of dst indices
+// that correspond to each src dimension.
+// Example: src=[6,4], dst=[2,3,2,2] -> true, splitDimGroups=[[0,1],[2,3]]
+bool xegpu::matchSplitDimExpansion(
+    ArrayRef<int64_t> src, ArrayRef<int64_t> dst,
+    SmallVector<SmallVector<int64_t>> &splitDimGroups) {
+  // Greedily accumulate consecutive dst dims until their product equals the
+  // current src dim; any dst dim left over once src is exhausted is rejected.
+  size_t srcIdx = 0;
+  int64_t accumulatedSize = 1;
+  SmallVector<int64_t> currentDstDims;
+
+  splitDimGroups.clear();
+  for (size_t dstIdx = 0; dstIdx < dst.size(); ++dstIdx) {
+    if (srcIdx >= src.size())
+      return false;
+    accumulatedSize *= dst[dstIdx];
+    currentDstDims.push_back(dstIdx);
+
+    if (accumulatedSize == src[srcIdx]) {
+      // Record the mapping: srcIdx -> currentDstDims
+      splitDimGroups.push_back(currentDstDims);
+      // move to next src dim
+      srcIdx++;
+      accumulatedSize = 1;
+      currentDstDims.clear();
+    } else if (accumulatedSize > src[srcIdx]) {
+      return false;
+    }
+  }
+  return srcIdx == src.size();
+}

>From ff8691babeb45bce083d1755c259f20f7b97f191 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Tue, 3 Feb 2026 23:52:27 +0000
Subject: [PATCH 34/35] adding comments and tests

---
 .../XeGPU/Transforms/XeGPULayoutImpls.h       |   2 +-
 .../XeGPU/Transforms/XeGPULayoutImpls.cpp     | 120 +++++++++++++++---
 .../XeGPU/Transforms/XeGPUPropagateLayout.cpp |   2 +-
 .../XeGPU/propagate-layout-inst-data.mlir     |  38 ++++++
 .../XeGPU/propagate-layout-subgroup.mlir      |  18 +++
 mlir/test/Dialect/XeGPU/propagate-layout.mlir |  39 ++++++
 6 files changed, 202 insertions(+), 17 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpls.h b/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpls.h
index 078a2d1b2ad0d..758b783953732 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpls.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpls.h
@@ -135,7 +135,7 @@ DistributeLayoutAttr setupBitCastResultLayout(
 /// Creates a result layout based on the specified layout kind (InstData or
 /// Lane).
 DistributeLayoutAttr setupInsertStridedSliceResultLayout(
-    LayoutKind layoutKind, VectorType resVectorTy,
+    LayoutKind layoutKind, VectorType srcVectorTy, VectorType resVectorTy,
     DistributeLayoutAttr consumerLayout, const uArch::uArch *uArch);
 
 /// Sets up the anchor layout for a load gather operation.
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpls.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpls.cpp
index 407ee36d16769..a008990a0a063 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpls.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpls.cpp
@@ -330,22 +330,42 @@ xegpu::inferShapeCastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
     auto resLaneLayout = resLayout.getEffectiveLaneLayoutAsInt();
     auto resLaneData = resLayout.getEffectiveLaneDataAsInt();
 
-    // get the layout info from the innermost dim of result layout
+    // Extract layout info from result's innermost dimension and apply to
+    // source's innermost dimension while setting all other dimensions to 1.
+    // The inferred layout is restricted by srcShape to ensure it fits within
+    // the source dimensions.
+    // Examples:
+    //   srcShape=[8, 16, 32], resShape=[1, 4096], resInstData=[1, 16]
+    //   -> inferredInstData=[1, 1, min(16, 32)]=[1, 1, 16]
+    //   srcShape=[4, 8, 64], resShape=[2048], resLaneLayout=[16],
+    //   resLaneData=[2]
+    //   -> inferredLaneData=[1, 1, min(2, 64/16)]=[1, 1, 2]
+    //   -> inferredLaneLayout=[1, 1, 16]
     if (resInstData.size() != 0) {
+      // assert resInstData must be 1 for all but the innermost dim
+      for (int i = 0; i < resShapeSize - 1; i++) {
+        assert(resInstData[i] == 1 &&
+               "only innermost dim can have non-unit instData");
+      }
       SmallVector<int> inferredInstData(srcShapeSize, 1);
-      assert((resShapeSize == 1 || resInstData[0] == 1) &&
-             "only innermost dim can have data and instData layout");
-      inferredInstData[srcShapeSize - 1] = resInstData[resShapeSize - 1];
+      inferredInstData[srcShapeSize - 1] =
+          std::min(resLaneData[resShapeSize - 1], srcShape[srcShapeSize - 1]);
       return xegpu::LayoutAttr::get(context, inferredInstData);
     }
 
     if (resLaneLayout.size() != 0) {
+      // Every dim except the innermost must have unit lane layout and data.
+      for (int i = 0; i < resShapeSize - 1; i++)
+        assert(resLaneLayout[i] == 1 && resLaneData[i] == 1 &&
+               "only innermost dim can have non-unit lane layout/data");
+      assert(srcShape[srcShapeSize - 1] >= resLaneLayout[resShapeSize - 1] &&
+             "source innermost dim must be >= result lane layout");
       SmallVector<int> inferredLaneLayout(srcShapeSize, 1);
       SmallVector<int> inferredLaneData(srcShapeSize, 1);
-      assert((resShapeSize == 1 || resLaneLayout[0] == 1) &&
-             "only innermost dim can have data and lane layout");
       inferredLaneLayout[srcShapeSize - 1] = resLaneLayout[resShapeSize - 1];
-      inferredLaneData[srcShapeSize - 1] = resLaneData[resShapeSize - 1];
+      inferredLaneData[srcShapeSize - 1] = std::min(
+          resLaneData[resShapeSize - 1],
+          srcShape[srcShapeSize - 1] / inferredLaneLayout[srcShapeSize - 1]);
       return xegpu::LayoutAttr::get(context, inferredLaneLayout,
                                     inferredLaneData);
     }
@@ -365,12 +385,36 @@ xegpu::inferShapeCastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
 /// For subgroup layouts, it first tries to align the source layout's subgroup
 /// layout and data with the consumer's layout on non-reduction dimensions.
 /// Then, it distributes remaining subgroups across reduction dimensions. This
-/// avoid subgroup data redistribution overhead between the reduced result and
+/// avoids subgroup data redistribution overhead between the reduced result and
 /// its consumer.
 ///
 /// InstData requries {1, ..., min(maxReduceVectorSize, srcShape),subgroupSize}
 /// Lane Layout requires {1, ..., 1, subgroupSize}
 /// Lane data requires {1, ..., min(maxReduceVectorSize, srcShape), 1}
+///
+/// Examples:
+///   1. Subgroup layout - Row reduction on 2D tensor:
+///      srcShape=[32, 64], reductionDims=[1], resShape=[32], subgroupSize=16,
+///      workgroupSize=32
+///      Consumer layout:
+///      #xegpu.slice<#xegpu.layout<sg_layout=[4, 8], sg_data=[8, 8]>,
+///      dims=[0]>. Result: srcLayout with sgLayout=[4, 8], sgData=[8, 8]
+///      (matches the consumer on the non-reduction dim, minimizing data
+///      redistribution between the reduced result and its consumer)
+///   2. Subgroup layout - Same example above but consumer has different layout:
+///      sgLayout=[32], sgData=[1]
+///      Result: srcLayout with sgLayout=[32, 1], sgData=[1, 64]
+///      (distributes all subgroups on the non-reduction dim)
+///
+///   3. InstData layout - Column reduction:
+///      srcShape=[32, 64], reductionDims=[0], subgroupSize=16
+///      Result: instData=[1, 16] (maxReduceVectorSize=1, subgroupSize on
+///      innermost)
+///
+///   4. Lane layout - Multi-dimensional reduction:
+///      srcShape=[16, 32, 64], reductionDims=[1], subgroupSize=16
+///      Result: laneLayout=[1, 1, 16], laneData=[1, 1, 1]
+///      (subgroupSize on innermost dim, max vector size on reduction dim)
 
 xegpu::SliceAttr xegpu::setupMultiReductionResultLayout(
     xegpu::LayoutKind layoutKind, VectorType srcVecTy,
@@ -438,9 +482,10 @@ xegpu::SliceAttr xegpu::setupMultiReductionResultLayout(
     }
 
     assert(remainingSgCount == 1 && "not all subgroups distributed");
-    srcLayout =
-        xegpu::LayoutAttr::get(context, toInt32Attr(sgLayout),
-                               toInt32Attr(sgData), consumerLayout.getOrder());
+    srcLayout = xegpu::LayoutAttr::get(
+        context, toInt32Attr(sgLayout), toInt32Attr(sgData),
+        /*inst_data =*/nullptr, /*lane_layout =*/nullptr,
+        /*lane_data =*/nullptr, /*order =*/nullptr);
 
   } else if (layoutKind == xegpu::LayoutKind::InstData) {
 
@@ -470,6 +515,23 @@ xegpu::SliceAttr xegpu::setupMultiReductionResultLayout(
 /// instData, or laneData) by multiplying by the bitwidth ratio to ensure the
 /// result layout can be correctly divided back to the source layout during
 /// inference.
+///
+/// Examples:
+///   1. Casting f32 -> f16 (32-bit to 16-bit, bitWidthRatio = 2):
+///      Consumer layout: instData=[1, 16], subgroupSize=16
+///      Source shape: [8, 32]
+///      Result layout: instData=[1, 32] (16 * 2)
+///      The innermost dimension is multiplied by 2 to maintain consistency.
+///
+///   2. Casting f32 -> i8 (32-bit to 8-bit, bitWidthRatio = 4):
+///      Consumer layout: instData=[1, 16], subgroupSize=16
+///      Source shape: [4, 128]
+///      Result layout: instData=[1, 64] (16 * 4)
+///
+///   3. Casting i8 -> i32 (8-bit to 32-bit, bitWidthRatio = 1/4):
+///      Consumer layout: laneLayout=[1, 16], laneData=[1, 4]
+///      No adjustment needed - returns consumer layout directly.
+///
 xegpu::DistributeLayoutAttr xegpu::setupBitCastResultLayout(
     xegpu::LayoutKind layoutKind, VectorType srcVecTy, VectorType resVecTy,
     DistributeLayoutAttr consumerLayout, const xegpu::uArch::uArch *uArch) {
@@ -534,9 +596,31 @@ xegpu::DistributeLayoutAttr xegpu::setupBitCastResultLayout(
 /// Lane layout is first set to be {1, ..., subgroupSize} with lane data {1,
 /// ..., 1}. The instData and laneData is then adjusted to contain packed data,
 /// by checking if the consumerLayout's innermost dimension.
+///
+/// Examples:
+///   1. InstData layout without packing:
+///      resShape=[8, 32], subgroupSize=16, bitwidth=32
+///      packingFactor=1, packedDataSize=16
+///      consumerLayout: instData=[1, 16]
+///      Result: instData=[1, 16]
+///
+///   2. InstData layout with packing:
+///      resShape=[8, 64], subgroupSize=16, bitwidth=8, packingFactor=4
+///      consumerLayout: instData=[1, 64]
+///      Result: instData=[1, 64] (adjusted for packed data)
+///
+///   3. Lane layout without packing:
+///      resShape=[4, 64], subgroupSize=16, bitwidth=32
+///      consumerLayout: laneLayout=[1, 16], laneData=[1, 1]
+///      Result: laneLayout=[1, 16], laneData=[1, 1]
+///
+///   4. Lane layout with packing:
+///      resShape=[4, 64], subgroupSize=16, bitwidth=16, packingFactor=2
+///      consumerLayout: laneLayout=[1, 16], laneData=[1, 2]
+///      Result: laneLayout=[1, 16], laneData=[1, 2] (adjusted for packed data)
 xegpu::DistributeLayoutAttr xegpu::setupInsertStridedSliceResultLayout(
-    xegpu::LayoutKind layoutKind, VectorType resVectorTy,
-    xegpu::DistributeLayoutAttr consumerLayout,
+    xegpu::LayoutKind layoutKind, VectorType srcVectorTy,
+    VectorType resVectorTy, xegpu::DistributeLayoutAttr consumerLayout,
     const xegpu::uArch::uArch *uArch) {
 
   xegpu::DistributeLayoutAttr requiredResLayout;
@@ -544,6 +628,8 @@ xegpu::DistributeLayoutAttr xegpu::setupInsertStridedSliceResultLayout(
   auto context = resVectorTy.getContext();
   auto resShape = resVectorTy.getShape();
   int resShapeSize = resShape.size();
+  auto srcShape = srcVectorTy.getShape();
+  int srcShapeSize = srcVectorTy.getShape().size();
   SmallVector<int64_t> consumerInstData =
       consumerLayout.getEffectiveInstDataAsInt();
   SmallVector<int64_t> consumerLaneData =
@@ -562,14 +648,18 @@ xegpu::DistributeLayoutAttr xegpu::setupInsertStridedSliceResultLayout(
     assert(true &&
            "subgroup layout assignment not supported for insertStridedSlice.");
   } else if (layoutKind == xegpu::LayoutKind::InstData) {
+    assert(srcShape[srcShapeSize - 1] >= subgroupSize &&
+           "source innermost dim must be >= subgroupSize");
     instData[resShapeSize - 1] = subgroupSize;
-    if (consumerInstData[resShapeSize - 1] == packedDataSize)
+    if (consumerInstData[resShapeSize - 1] == packedDataSize &&
+        srcShape[srcShapeSize - 1] >= packedDataSize)
       instData[resShapeSize - 1] = packedDataSize;
     requiredResLayout = xegpu::LayoutAttr::get(context, instData);
   } else if (layoutKind == xegpu::LayoutKind::Lane) {
     laneLayout[resShapeSize - 1] = subgroupSize;
     laneData[resShapeSize - 1] = 1;
-    if (consumerLaneData[resShapeSize - 1] == packingFactor)
+    if (consumerLaneData[resShapeSize - 1] == packingFactor &&
+        srcShape[srcShapeSize - 1] >= packedDataSize)
       laneData[resShapeSize - 1] = packingFactor;
     requiredResLayout = xegpu::LayoutAttr::get(context, laneLayout, laneData);
   }
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index b723036449914..107192f162da4 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -1114,7 +1114,7 @@ void LayoutInfoPropagation::visitInsertStridedSliceOp(
   auto uArch = getUArch(xegpu::getChipStr(insertStridedSlice).value_or(""));
 
   auto requiredResLayoutAttr = xegpu::setupInsertStridedSliceResultLayout(
-      layoutKind, resVecType, consumerLayoutAttr, uArch);
+      layoutKind, srcVecType, resVecType, consumerLayoutAttr, uArch);
 
   xegpu::setTemporaryLayout(insertStridedSlice->getResult(0),
                             requiredResLayoutAttr);
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir b/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
index d53d8ea5bd643..595be3dfa29e5 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
@@ -234,3 +234,41 @@ func.func @scatter_ops_chunksize_slice(%src: memref<1024xf32>) {
   return
 }
 }
+
+// -----
+gpu.module @test {
+// CHECK-LABEL: func.func @insert_strided_slice_inst_data_no_packing(
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<8x32xf32>) {
+// CHECK: %[[CST_SMALL:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [1, 16]>} dense<1.000000e+00> : vector<4x16xf32>
+// CHECK: %[[CST_LARGE:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [1, 16]>} dense<0.000000e+00> : vector<8x32xf32>
+// CHECK: %[[INSERT:.*]] = vector.insert_strided_slice %[[CST_SMALL]], %[[CST_LARGE]] {layout_result_0 = #xegpu.layout<inst_data = [1, 16]>, offsets = [0, 0], strides = [1, 1]} : vector<4x16xf32> into vector<8x32xf32>
+// CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG0]][{{.*}}] : memref<8x32xf32> -> !xegpu.tensor_desc<8x32xf32, #xegpu.layout<inst_data = [8, 16]>>
+// CHECK: xegpu.store_nd %[[INSERT]], %[[TDESC]] <{layout = #xegpu.layout<inst_data = [8, 16]>}> : vector<8x32xf32>, !xegpu.tensor_desc<8x32xf32, #xegpu.layout<inst_data = [8, 16]>>
+func.func @insert_strided_slice_inst_data_no_packing(%arg0: memref<8x32xf32>) {
+  %c0 = arith.constant 0 : index
+  %cst_small = arith.constant dense<1.0> : vector<4x16xf32>
+  %cst_large = arith.constant dense<0.0> : vector<8x32xf32>
+  %insert = vector.insert_strided_slice %cst_small, %cst_large {offsets = [0, 0], strides = [1, 1]} : vector<4x16xf32> into vector<8x32xf32>
+  %tdesc = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x32xf32> -> !xegpu.tensor_desc<8x32xf32>
+  xegpu.store_nd %insert, %tdesc : vector<8x32xf32>, !xegpu.tensor_desc<8x32xf32>
+  return
+}
+}
+
+// -----
+gpu.module @test {
+// CHECK-LABEL: func.func @insert_strided_slice_inst_data_with_packing(
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<8x64xi8>) {
+// CHECK: %[[CST_SMALL:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [1, 64]>} dense<1> : vector<4x64xi8>
+// CHECK: %[[CST_LARGE:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [1, 64]>} dense<0> : vector<8x64xi8>
+// CHECK: %[[INSERT:.*]] = vector.insert_strided_slice %[[CST_SMALL]], %[[CST_LARGE]] {layout_result_0 = #xegpu.layout<inst_data = [1, 64]>, offsets = [0, 0], strides = [1, 1]} : vector<4x64xi8> into vector<8x64xi8>
+func.func @insert_strided_slice_inst_data_with_packing(%arg0: memref<8x64xi8>) {
+  %c0 = arith.constant 0 : index
+  %cst_small = arith.constant dense<1> : vector<4x64xi8>
+  %cst_large = arith.constant dense<0> : vector<8x64xi8>
+  %insert = vector.insert_strided_slice %cst_small, %cst_large {offsets = [0, 0], strides = [1, 1]} : vector<4x64xi8> into vector<8x64xi8>
+  %tdesc = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x64xi8> -> !xegpu.tensor_desc<8x64xi8, #xegpu.layout<inst_data = [8, 64]>>
+  xegpu.store_nd %insert, %tdesc <{layout = #xegpu.layout<inst_data = [8, 64]>}>: vector<8x64xi8>, !xegpu.tensor_desc<8x64xi8, #xegpu.layout<inst_data = [8, 64]>>
+  return
+}
+}
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir b/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir
index 29e5b51627fb6..d8668236116c8 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir
@@ -123,3 +123,21 @@ gpu.module @test {
     gpu.return
   }
 }
+
+// -----
+gpu.module @test {
+// CHECK-LABEL: vector_row_reduction_1
+// CHECK: %[[REDUCE:.*]] = vector.multi_reduction <add>, %{{.*}}, %{{.*}} {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [32, 1], sg_data = [1, 64]>, dims = [1]>}
+  gpu.func @vector_row_reduction_1(%src: memref<32x64xf32>, %dst: memref<32xf32>) kernel attributes
+      {known_block_size = array<i32: 1, 32, 1>} {
+    %cst = arith.constant dense<0.000000e+00> : vector<32xf32>
+    %tdesc_src = xegpu.create_nd_tdesc %src : memref<32x64xf32> -> !xegpu.tensor_desc<32x64xf32>
+    %load = xegpu.load_nd %tdesc_src : !xegpu.tensor_desc<32x64xf32> -> vector<32x64xf32>
+    %reduce = vector.multi_reduction <add>, %load, %cst [1] : vector<32x64xf32> to vector<32xf32>
+    %tdesc_dst = xegpu.create_nd_tdesc %dst : memref<32xf32> -> !xegpu.tensor_desc<32xf32, #xegpu.layout<sg_layout = [32], sg_data = [1]>>
+    xegpu.store_nd %reduce, %tdesc_dst <{layout = #xegpu.layout<sg_layout = [32], sg_data = [1]>}>
+      : vector<32xf32>, !xegpu.tensor_desc<32xf32, #xegpu.layout<sg_layout = [32], sg_data = [1]>>
+    gpu.return
+  }
+}
+
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout.mlir b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
index 91e790ccdb4fe..90d580cd5901e 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
@@ -707,3 +707,42 @@ func.func @store_matrix(%arg0: !xegpu.mem_desc<16x64xf16>) {
   return
 }
 }
+
+// -----
+gpu.module @test {
+// CHECK-LABEL: func.func @insert_strided_slice_lane_layout_no_packing(
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<4x64xf32>) {
+// CHECK: %[[CST_SMALL:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} dense<1.000000e+00> : vector<2x32xf32>
+// CHECK: %[[CST_LARGE:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} dense<0.000000e+00> : vector<4x64xf32>
+// CHECK: %[[INSERT:.*]] = vector.insert_strided_slice %[[CST_SMALL]], %[[CST_LARGE]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, offsets = [0, 0], strides = [1, 1]} : vector<2x32xf32> into vector<4x64xf32>
+// CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG0]][{{.*}}] : memref<4x64xf32> -> !xegpu.tensor_desc<4x64xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+// CHECK: xegpu.store_nd %[[INSERT]], %[[TDESC]] <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}> : vector<4x64xf32>, !xegpu.tensor_desc<4x64xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+func.func @insert_strided_slice_lane_layout_no_packing(%arg0: memref<4x64xf32>) {
+  %c0 = arith.constant 0 : index
+  %cst_small = arith.constant dense<1.0> : vector<2x32xf32>
+  %cst_large = arith.constant dense<0.0> : vector<4x64xf32>
+  %insert = vector.insert_strided_slice %cst_small, %cst_large {offsets = [0, 0], strides = [1, 1]} : vector<2x32xf32> into vector<4x64xf32>
+  %tdesc = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<4x64xf32> -> !xegpu.tensor_desc<4x64xf32>
+  xegpu.store_nd %insert, %tdesc : vector<4x64xf32>, !xegpu.tensor_desc<4x64xf32>
+  return
+}
+}
+
+// -----
+gpu.module @test {
+// CHECK-LABEL: func.func @insert_strided_slice_lane_layout_with_packing(
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<4x64xf16>) {
+// CHECK: %[[CST_SMALL:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>} dense<1.000000e+00> : vector<2x32xf16>
+// CHECK: %[[CST_LARGE:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>} dense<0.000000e+00> : vector<4x64xf16>
+// CHECK: %[[INSERT:.*]] = vector.insert_strided_slice %[[CST_SMALL]], %[[CST_LARGE]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>, offsets = [0, 0], strides = [1, 1]} : vector<2x32xf16> into vector<4x64xf16>
+func.func @insert_strided_slice_lane_layout_with_packing(%arg0: memref<4x64xf16>) {
+  %c0 = arith.constant 0 : index
+  %cst_small = arith.constant dense<1.0> : vector<2x32xf16>
+  %cst_large = arith.constant dense<0.0> : vector<4x64xf16>
+  %insert = vector.insert_strided_slice %cst_small, %cst_large {offsets = [0, 0], strides = [1, 1]} : vector<2x32xf16> into vector<4x64xf16>
+  %tdesc = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<4x64xf16> -> !xegpu.tensor_desc<4x64xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>>
+  xegpu.store_nd %insert, %tdesc <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>}>: vector<4x64xf16>, !xegpu.tensor_desc<4x64xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>>
+  return
+}
+}
+

>From 4044aabb0c9e4959cc7a92e2a223a275df88b0db Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Wed, 4 Feb 2026 05:11:11 +0000
Subject: [PATCH 35/35] add reduction tests

---
 .../XeGPU/Transforms/XeGPULayoutImpls.cpp     |  4 ++--
 .../XeGPU/propagate-layout-subgroup.mlir      | 23 +++++++++++++++++++
 2 files changed, 25 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpls.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpls.cpp
index a008990a0a063..cd940198ef0c6 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpls.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpls.cpp
@@ -472,8 +472,8 @@ xegpu::SliceAttr xegpu::setupMultiReductionResultLayout(
     // Second pass: Distribute remaining subgroups across reduction dimensions
     for (int i = srcRank - 1; i >= 0; i--) {
       if (llvm::is_contained(reductionDims, i)) {
-        sgLayout[i] = std::min(srcShape[i] / subgroupSize,
-                               static_cast<int64_t>(remainingSgCount));
+        sgLayout[i] =
+            std::min(srcShape[i], static_cast<int64_t>(remainingSgCount));
         assert((srcShape[i] % sgLayout[i] == 0) &&
                "source shape not divisible by sg_layout");
         sgData[i] = srcShape[i] / sgLayout[i];
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir b/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir
index d8668236116c8..5e07e5336b382 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir
@@ -141,3 +141,26 @@ gpu.module @test {
   }
 }
 
+// -----
+gpu.module @test {
+// CHECK-LABEL: vector_row_reduction_2
+  gpu.func @vector_row_reduction_2(%src: memref<32x128xf32>, %dst: memref<32xf32>) kernel attributes
+      {known_block_size = array<i32: 1, 32, 1>} {
+    %cst = arith.constant dense<0.000000e+00> : vector<32xf32>
+    %cst1 = arith.constant dense<0.000000e+00> : vector<32x128xf32>
+    %tdesc_src = xegpu.create_nd_tdesc %src : memref<32x128xf32> -> !xegpu.tensor_desc<32x128xf32>
+    %load = xegpu.load_nd %tdesc_src : !xegpu.tensor_desc<32x128xf32> -> vector<32x128xf32>
+    %bcast1 = vector.broadcast %load: vector<32x128xf32> to vector<4x32x128xf32>
+
+  // CHECK: %[[BCAST1:.*]] = vector.broadcast %{{.*}} {layout_result_0 = #xegpu.layout<sg_layout = [1, 4, 8], sg_data = [4, 8, 16]>} : vector<32x128xf32> to vector<4x32x128xf32>
+  // CHECK: %[[BCAST:.*]] = vector.multi_reduction <add>, %[[BCAST1]], %{{.*}} {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [1, 4, 8], sg_data = [4, 8, 16]>, dims = [0]>} [0] : vector<4x32x128xf32> to vector<32x128xf32>
+  // CHECK: %[[REDUCE:.*]] = vector.multi_reduction <add>, %[[BCAST]], %{{.*}} {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [4, 8], sg_data = [8, 16]>, dims = [1]>} [1] : vector<32x128xf32> to vector<32xf32>
+
+    %bcast = vector.multi_reduction <add>, %bcast1, %cst1 [0]: vector<4x32x128xf32> to vector<32x128xf32>
+    %reduce = vector.multi_reduction <add>, %bcast, %cst [1] : vector<32x128xf32> to vector<32xf32>
+    %mask = arith.constant dense<1>: vector<32xi1>
+    %offset = vector.step : vector<32xindex>
+    xegpu.store %reduce, %dst[%offset], %mask {layout = #xegpu.slice<#xegpu.layout<sg_layout=[4, 8], sg_data=[8, 16]>, dims = [1]>} : vector<32xf32>, memref<32xf32>, vector<32xindex>, vector<32xi1>
+    gpu.return
+  }
+}
\ No newline at end of file



More information about the Mlir-commits mailing list