[Mlir-commits] [mlir] [MLIR][XeGPU] add dpas_mx op definition and layout propagation rule (PR #194117)

Jianhui Li llvmlistbot at llvm.org
Fri Apr 24 22:58:00 PDT 2026


https://github.com/Jianhui-Li created https://github.com/llvm/llvm-project/pull/194117

This PR extends the DpasMx operation to support MXFP (microscaling floating point) matrix multiply with separate scale factor layouts.

1. Op Definition
    Adds layout_a_scale and layout_b_scale attributes to the DpasMx op.
    Removes the AllElementTypesMatch<["a", "b"]> trait so that A and B can have different element types when scales are used.
2. Layout Infrastructure
    setupDpasMxLayout(): Creates anchor layouts for all 5 operands (A, B, C/D, scale_a, scale_b).
    Derives scale layouts from the parent matrix layouts by dividing the innermost dimension by the scale factor (parent innermost size / scale innermost size).
    Supports all layout kinds: Subgroup, InstData, Lane.
    Fixes a bug in getupDpasSubgroupLayouts(): sg_data of the A/B matrix should keep the full K dimension.
3. Layout Propagation
    visitDpasMxOp(): Propagates layout attributes from the op to its operands during dataflow analysis.
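
The scale-layout derivation in item 2 can be illustrated with a minimal standalone sketch. It is not the code from this PR: Layout2D and deriveScaleLayout are hypothetical names, plain std::vector stands in for the MLIR layout attribute, and the lane_layout/lane_data handling is left out. It only shows the step of dividing the innermost sg_data and inst_data extents by the ratio of the parent and scale innermost sizes.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for the 2D fields of a layout attribute. This is
// not the MLIR DistributeLayoutAttr; an empty vector means "field absent".
struct Layout2D {
  std::vector<int64_t> sgData;
  std::vector<int64_t> instData;
};

// Derive a scale layout from its parent matrix layout: the scale factor is
// the ratio of the parent and scale innermost extents, and the innermost
// entry of every present field is divided by it.
Layout2D deriveScaleLayout(const Layout2D &parent,
                           const std::vector<int64_t> &parentShape,
                           const std::vector<int64_t> &scaleShape) {
  assert(parentShape.size() == 2 && scaleShape.size() == 2 &&
         "sketch assumes 2D shapes");
  assert(scaleShape.back() > 0 &&
         parentShape.back() % scaleShape.back() == 0 &&
         "innermost parent extent must be a multiple of the scale extent");
  const int64_t scaleFactor = parentShape.back() / scaleShape.back();

  Layout2D scale = parent;
  if (!scale.sgData.empty())
    scale.sgData.back() /= scaleFactor;
  if (!scale.instData.empty())
    scale.instData.back() /= scaleFactor;
  return scale;
}
```

With made-up values, a 128x256 A matrix and a 128x8 scale give a scale factor of 32, so a parent layout with sg_data = [32, 256] and inst_data = [8, 32] maps to sg_data = [32, 8] and inst_data = [8, 1]. The sketch assumes the innermost extents divide evenly; a parent inst_data entry smaller than the scale factor would truncate to 0 under integer division and needs separate handling.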

>From 618b5fe929e70aada84a6baa8fafde7bb4f1697e Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Thu, 23 Apr 2026 23:33:41 +0000
Subject: [PATCH 1/7] [mlir][XeGPU] Add DpasMx op definition and layout support

This patch extends the DpasMx operation with scale layout attributes and
implements layout setup and propagation for MXFP (microscaling floating point)
operations.

Op definition changes (XeGPUOps.td):
- Add layout_a_scale and layout_b_scale attributes to DpasMx op
- Remove restrictive AllElementTypesMatch trait to allow different types for
  A and B operands with scale factors

Layout support (XeGPULayoutImpl.cpp/h):
- setupDpasMxLayout: Creates anchor layouts for all DpasMx operands (A, B, C/D,
  scale_a, scale_b) across different layout kinds (Subgroup, InstData, Lane)
- Derives scale layouts from parent matrix layouts by adjusting dimensions based
  on scaling factors (K/32 for scales)

Layout propagation (XeGPUPropagateLayout.cpp):
- visitDpasMxOp: Propagates layout attributes from DpasMx op to its operands
  during dataflow analysis

This infrastructure enables proper layout tracking for mixed-precision matrix
operations with separate scale factors during workgroup-to-subgroup distribution.

Co-Authored-By: Claude Sonnet 4.5 <noreply at anthropic.com>
---
 .../include/mlir/Dialect/XeGPU/IR/XeGPUOps.td |   5 +-
 .../XeGPU/Transforms/XeGPULayoutImpl.h        |  18 ++-
 .../XeGPU/Transforms/XeGPULayoutImpl.cpp      | 151 ++++++++++++++++++
 .../XeGPU/Transforms/XeGPUPropagateLayout.cpp |  39 +++++
 4 files changed, 205 insertions(+), 8 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 31fe93d209a6d..8b6763e234091 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -1554,7 +1554,6 @@ def XeGPU_TruncfOp
 }
 
 def XeGPU_DpasMxOp : XeGPU_Op<"dpas_mx", [Pure, AttrSizedOperandSegments,
-                                          AllElementTypesMatch<["a", "b"]>,
                                           AnchorLayoutInterface]> {
   let summary = "It performs scaled mma computation";
 
@@ -1601,7 +1600,9 @@ def XeGPU_DpasMxOp : XeGPU_Op<"dpas_mx", [Pure, AttrSizedOperandSegments,
                           VectorOfRankAndType<[1, 2], [F8E8M0FNU]>]>>:$scale_b,
       OptionalAttr<DistributeLayoutAttr>:$layout_a,
       OptionalAttr<DistributeLayoutAttr>:$layout_b,
-      OptionalAttr<DistributeLayoutAttr>:$layout_cd);
+      OptionalAttr<DistributeLayoutAttr>:$layout_cd,
+      OptionalAttr<DistributeLayoutAttr>:$layout_a_scale,
+      OptionalAttr<DistributeLayoutAttr>:$layout_b_scale);
   let results = (outs XeGPU_DpasResType:$result);
   let extraClassDeclaration = [{
 
diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h b/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h
index 83eb939cf1bec..c68e7334434d6 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h
@@ -39,12 +39,6 @@ LogicalResult propagateLayouts(OpBuilder &builder, Operation *target,
 
 LogicalResult resolveLayoutConflicts(Operation *target);
 
-/// [to-be-deprecated] Set the DistributeLayoutAttr for each OpOperand and
-/// OpResult of of the given operation. If the operation contains regions, it is
-/// also applied recursively to the contained operations operation.
-/// TODO: To be replaced by recoverTemporaryLayouts()
-void recoverTemporaryLayoutsDeprecated(Operation *op);
-
 /// Attach layout attributes to all vector-type operands of operations within
 /// the given operation's nested region. Reports an error if any vector operand
 /// lacks a layout attribute.
@@ -199,6 +193,18 @@ setupDpasLayout(LayoutKind layoutKind, VectorType aTy, VectorType bTy,
                 VectorType cdTy, DistributeLayoutAttr consumerLayout, int numSg,
                 const uArch::uArch *uArch);
 
+/// Sets up the anchor layouts for dpas_mx operands (A, B, C/D, A_scale, and
+/// B_scale). The numSg and consumerLayout (optional) are only used by sg layout
+/// creation. A_scale and B_scale are optional.
+std::optional<std::tuple<DistributeLayoutAttr, DistributeLayoutAttr,
+                         DistributeLayoutAttr, DistributeLayoutAttr,
+                         DistributeLayoutAttr>>
+setupDpasMxLayout(LayoutKind layoutKind, VectorType aTy, VectorType bTy,
+                  VectorType cdTy, std::optional<VectorType> aScaleTy,
+                  std::optional<VectorType> bScaleTy,
+                  DistributeLayoutAttr consumerLayout, int numSg,
+                  const uArch::uArch *uArch);
+
 /// Gets the expected layout for a given consumer operand. This will check if
 /// the owning operation of the consumer operand is one of the special layout
 /// users and determine the expected layout accordingly.
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
index 7d48315eec6ff..8f07ddf09e9a4 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
@@ -1426,3 +1426,154 @@ xegpu::DistributeLayoutAttr xegpu::getConsumerLayoutAt(OpOperand &operand) {
   // the operand.
   return xegpu::getDistributeLayoutAttr(operand.get());
 }
+/// Sets up the anchor layouts for dpas_mx operands (A, B, C/D, A_scale, and
+/// B_scale). The numSg and consumerLayout (optional) are only used by sg layout
+/// creation.
+std::optional<
+    std::tuple<xegpu::DistributeLayoutAttr, xegpu::DistributeLayoutAttr,
+               xegpu::DistributeLayoutAttr, xegpu::DistributeLayoutAttr,
+               xegpu::DistributeLayoutAttr>>
+xegpu::setupDpasMxLayout(xegpu::LayoutKind layoutKind, VectorType aTy,
+                         VectorType bTy, VectorType cdTy,
+                         std::optional<VectorType> aScaleTy,
+                         std::optional<VectorType> bScaleTy,
+                         xegpu::DistributeLayoutAttr consumerLayout, int numSg,
+                         const xegpu::uArch::uArch *uArch) {
+  auto context = aTy.getContext();
+  const int subgroupSize = uArch->getSubgroupSize();
+
+  // Helper to create scale layout from parent layout
+  auto createScaleLayout = [&](VectorType parentTy, VectorType scaleTy,
+                               xegpu::DistributeLayoutAttr parentLayout,
+                               bool isBScale) -> xegpu::DistributeLayoutAttr {
+    if (!scaleTy || !parentLayout)
+      return nullptr;
+
+    // Calculate scaling factor by dividing parent shape by scale shape
+    ArrayRef<int64_t> parentShape = parentTy.getShape();
+    ArrayRef<int64_t> scaleShape = scaleTy.getShape();
+    int64_t scaleFactor = parentShape.back() / scaleShape.back();
+    int64_t rank = parentLayout.getRank();
+    assert(rank == 2 && "dpas layouts must be two dimensions");
+
+    SmallVector<int64_t> sgLayout = parentLayout.getEffectiveSgLayoutAsInt();
+    SmallVector<int64_t> sgData = parentLayout.getEffectiveSgDataAsInt();
+    SmallVector<int64_t> instData = parentLayout.getEffectiveInstDataAsInt();
+    SmallVector<int64_t> laneLayout =
+        parentLayout.getEffectiveLaneLayoutAsInt();
+    SmallVector<int64_t> laneData = parentLayout.getEffectiveLaneDataAsInt();
+    auto order = parentLayout.getOrder();
+
+    // Divide last dimension by scaling factor
+    if (!sgData.empty())
+      sgData.back() = sgData.back() / scaleFactor;
+    if (!instData.empty())
+      instData.back() = instData.back() / scaleFactor;
+
+    if (isBScale) {
+      // For B scale: lane_layout = [min(subgroupSize, scaleFactor), 1]
+      // lane_data = [1, number of scale elements]
+      laneLayout[rank - 2] =
+          std::min(static_cast<int64_t>(subgroupSize), scaleShape[rank-2]);
+      laneLayout[rank - 1] = 1;
+      laneData[rank - 2] = 1;
+      laneData[rank - 1] = scaleShape.back();
+    } else {
+      laneData.back() = scaleShape.back();
+    }
+
+    return xegpu::LayoutAttr::get(
+        context,
+        sgLayout.empty()
+            ? nullptr
+            : DenseI32ArrayAttr::get(
+                  context, SmallVector<int>(sgLayout.begin(), sgLayout.end())),
+        sgData.empty()
+            ? nullptr
+            : DenseI32ArrayAttr::get(
+                  context, SmallVector<int>(sgData.begin(), sgData.end())),
+        instData.empty()
+            ? nullptr
+            : DenseI32ArrayAttr::get(
+                  context, SmallVector<int>(instData.begin(), instData.end())),
+        laneLayout.empty() ? nullptr
+                           : DenseI32ArrayAttr::get(
+                                 context, SmallVector<int>(laneLayout.begin(),
+                                                           laneLayout.end())),
+        laneData.empty()
+            ? nullptr
+            : DenseI32ArrayAttr::get(
+                  context, SmallVector<int>(laneData.begin(), laneData.end())),
+        order);
+  };
+
+  if (layoutKind == xegpu::LayoutKind::Subgroup) {
+    assert(numSg > 0 &&
+           "Number of subgroups must be provided for sg layout creation.");
+    auto dpasLayouts =
+        getupDpasSubgroupLayouts(context, aTy, bTy, cdTy, consumerLayout, numSg, uArch);
+    if (!dpasLayouts)
+      return std::nullopt;
+
+    auto [dpasALayout, dpasBLayout, dpasCDLayout] = *dpasLayouts;
+
+    // Create scale layouts
+    auto aScaleLayout =
+        aScaleTy.has_value()
+            ? createScaleLayout(aTy, *aScaleTy, dpasALayout, false)
+            : nullptr;
+    auto bScaleLayout =
+        bScaleTy.has_value()
+            ? createScaleLayout(bTy, *bScaleTy, dpasBLayout, true)
+            : nullptr;
+
+    return std::make_tuple(dpasALayout, dpasBLayout, dpasCDLayout, aScaleLayout,
+                           bScaleLayout);
+  } else if (layoutKind == xegpu::LayoutKind::InstData) {
+    auto instDataVecs = getDpasInstDataVectors(aTy, bTy, cdTy, uArch);
+    if (!instDataVecs)
+      return std::nullopt;
+    auto [instDataA, instDataB, instDataCD] = *instDataVecs;
+
+    auto dpasALayout = xegpu::LayoutAttr::get(
+        context, SmallVector<int>(instDataA.begin(), instDataA.end()));
+    auto dpasBLayout = xegpu::LayoutAttr::get(
+        context, SmallVector<int>(instDataB.begin(), instDataB.end()));
+    auto dpasCDLayout = xegpu::LayoutAttr::get(
+        context, SmallVector<int>(instDataCD.begin(), instDataCD.end()));
+
+    // Create scale layouts
+    auto aScaleLayout =
+        aScaleTy.has_value()
+            ? createScaleLayout(aTy, *aScaleTy, dpasALayout, false)
+            : nullptr;
+    auto bScaleLayout =
+        bScaleTy.has_value()
+            ? createScaleLayout(bTy, *bScaleTy, dpasBLayout, true)
+            : nullptr;
+
+    return std::make_tuple(dpasALayout, dpasBLayout, dpasCDLayout, aScaleLayout,
+                           bScaleLayout);
+  } else if (layoutKind == xegpu::LayoutKind::Lane) {
+    const auto *uArchInstruction =
+        dyn_cast<xegpu::uArch::SubgroupMatrixMultiplyAcc>(uArch->getInstruction(
+            xegpu::uArch::InstructionKind::SubgroupMatrixMultiplyAcc));
+    auto aLayout = getDefaultLaneLayout2DBlockIo(
+        aTy, uArch, uArchInstruction->getPackedFormatBitSizeA());
+    auto bLayout = getDefaultLaneLayout2DBlockIo(
+        bTy, uArch, uArchInstruction->getPackedFormatBitSizeB(), true);
+    auto cdLayout = getDefaultLaneLayout2DBlockIo(cdTy, uArch);
+
+    // Create scale layouts
+    auto aScaleLayout = aScaleTy.has_value()
+                            ? createScaleLayout(aTy, *aScaleTy, aLayout, false)
+                            : nullptr;
+    auto bScaleLayout = bScaleTy.has_value()
+                            ? createScaleLayout(bTy, *bScaleTy, bLayout, true)
+                            : nullptr;
+
+    return std::make_tuple(aLayout, bLayout, cdLayout, aScaleLayout,
+                           bScaleLayout);
+  }
+  return std::nullopt;
+}
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 43998ed41f7aa..f1709e95d34a9 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -316,6 +316,9 @@ class LayoutInfoPropagation
   void visitDpasOp(xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands,
                    ArrayRef<const LayoutInfoLattice *> results);
 
+  void visitDpasMxOp(xegpu::DpasMxOp dpasMx, ArrayRef<LayoutInfoLattice *> operands,
+                     ArrayRef<const LayoutInfoLattice *> results);
+
   void visitStoreNdOp(xegpu::StoreNdOp store,
                       ArrayRef<LayoutInfoLattice *> operands,
                       ArrayRef<const LayoutInfoLattice *> results);
@@ -426,6 +429,9 @@ LogicalResult LayoutInfoPropagation::visitOperation(
   TypeSwitch<Operation *>(op)
       .Case(
           [&](xegpu::DpasOp dpasOp) { visitDpasOp(dpasOp, operands, results); })
+      .Case([&](xegpu::DpasMxOp dpasMxOp) {
+        visitDpasMxOp(dpasMxOp, operands, results);
+      })
       .Case([&](xegpu::StoreNdOp storeNdOp) {
         visitStoreNdOp(storeNdOp, operands, results);
       })
@@ -804,6 +810,39 @@ void LayoutInfoPropagation::visitDpasOp(
     std::tie(requiredALayout, requiredBLayout, requiredCDLayoutAttr) = *layouts;
 
     dpas.setLayoutAAttr(requiredALayout);
+
+/// Propagate layout for DpasMxOp operands using the layout attributes.
+/// DpasMxOp has operands: a, b, acc (optional), scale_a (optional), scale_b (optional)
+void LayoutInfoPropagation::visitDpasMxOp(
+    xegpu::DpasMxOp dpasMx, ArrayRef<LayoutInfoLattice *> operands,
+    ArrayRef<const LayoutInfoLattice *> results) {
+
+  // Get the layout attributes from the operation
+  xegpu::DistributeLayoutAttr layoutA = dpasMx.getLayoutAAttr();
+  xegpu::DistributeLayoutAttr layoutB = dpasMx.getLayoutBAttr();
+  xegpu::DistributeLayoutAttr layoutCD = dpasMx.getLayoutCdAttr();
+  xegpu::DistributeLayoutAttr layoutAScale = dpasMx.getLayoutAScaleAttr();
+  xegpu::DistributeLayoutAttr layoutBScale = dpasMx.getLayoutBScaleAttr();
+
+  // Propagate layouts to operands based on their positions:
+  // operands[0] = a, operands[1] = b, operands[2] = acc (optional),
+  // operands[3] = scale_a (optional), operands[4] = scale_b (optional)
+
+  if (layoutA && operands.size() > 0)
+    propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(layoutA)));
+
+  if (layoutB && operands.size() > 1)
+    propagateIfChanged(operands[1], operands[1]->meet(LayoutInfo(layoutB)));
+
+  if (layoutCD && operands.size() > 2)
+    propagateIfChanged(operands[2], operands[2]->meet(LayoutInfo(layoutCD)));
+
+  if (layoutAScale && operands.size() > 3)
+    propagateIfChanged(operands[3], operands[3]->meet(LayoutInfo(layoutAScale)));
+
+  if (layoutBScale && operands.size() > 4)
+    propagateIfChanged(operands[4], operands[4]->meet(LayoutInfo(layoutBScale)));
+}
     dpas.setLayoutBAttr(requiredBLayout);
     dpas.setLayoutCdAttr(requiredCDLayoutAttr);
     dpasALayout = LayoutInfo(requiredALayout);

>From 495c869b81faf00af2e6429b143ca20009955f1f Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Fri, 24 Apr 2026 03:28:40 +0000
Subject: [PATCH 2/7] [mlir][XeGPU] Implement layout propagation for DpasMx
 operation

This commit implements complete layout propagation support for the
DpasMx operation, including support for scale_a and scale_b operands.

Key changes:
1. Implemented visitDpasMxOp in XeGPUPropagateLayout.cpp to compute
   and propagate layouts for all operands (a, b, acc, scale_a, scale_b)
2. Fixed createScaleLayout in XeGPULayoutImpl.cpp to properly handle
   scale layouts for inst_data and lane layout kinds
3. Added comprehensive tests for inst_data and lane layout propagation

The implementation follows the same pattern as regular DPAS layout
propagation but extends it to handle the additional scale operands.
Scale layouts are derived from the parent operand layouts and adapted
to match the scale tensor dimensions.

Tests added:
- propagate-layout-inst-data.mlir: dpas_mx with inst_data layouts
- propagate-layout.mlir: dpas_mx with lane layouts

Co-Authored-By: Claude Sonnet 4.5 <noreply at anthropic.com>
---
 .../XeGPU/Transforms/XeGPULayoutImpl.cpp      | 468 ++++++++++--------
 .../XeGPU/Transforms/XeGPUPropagateLayout.cpp | 137 +++--
 .../XeGPU/propagate-layout-inst-data.mlir     |  47 ++
 mlir/test/Dialect/XeGPU/propagate-layout.mlir |  42 ++
 4 files changed, 448 insertions(+), 246 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
index 8f07ddf09e9a4..b435a9dc112f0 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
@@ -1179,6 +1179,122 @@ getValidLayouts(ArrayRef<int64_t> wgShape, ArrayRef<int64_t> instData,
   return candidates;
 }
 
+/// Helper function to compute inst_data vectors for DPAS operands A, B, and
+/// C/D.
+static std::optional<std::tuple<SmallVector<int64_t>, SmallVector<int64_t>,
+                                SmallVector<int64_t>>>
+getDpasInstDataVectors(VectorType aTy, VectorType bTy, VectorType cdTy,
+                       const xegpu::uArch::uArch *uArch) {
+  const int subgroupSize = uArch->getSubgroupSize();
+  const auto *uArchInstruction =
+      dyn_cast<xegpu::uArch::SubgroupMatrixMultiplyAcc>(uArch->getInstruction(
+          xegpu::uArch::InstructionKind::SubgroupMatrixMultiplyAcc));
+
+  const unsigned dataALen = aTy.getShape().front();
+  auto supportedALen = uArchInstruction->getSupportedM(aTy.getElementType());
+  const int maxALen =
+      xegpu::getLargestDivisor(dataALen, ArrayRef<unsigned>(supportedALen));
+
+  const unsigned dataBLen = bTy.getShape().back();
+  auto supportedBLen = uArchInstruction->getSupportedN(bTy.getElementType());
+  const int maxBLen =
+      xegpu::getLargestDivisor(dataBLen, ArrayRef<unsigned>(supportedBLen));
+
+  auto supportedCLen = uArchInstruction->getSupportedN(cdTy.getElementType());
+  const int maxCLen =
+      xegpu::getLargestDivisor(dataBLen, ArrayRef<unsigned>(supportedCLen));
+  if (maxALen == -1 || maxBLen == -1 || maxCLen == -1)
+    return std::nullopt;
+
+  SmallVector<int64_t> instDataA(aTy.getRank(), 1);
+  instDataA[aTy.getRank() - 2] = maxALen;
+  instDataA[aTy.getRank() - 1] = subgroupSize;
+  SmallVector<int64_t> instDataB(bTy.getRank(), 1);
+  instDataB[bTy.getRank() - 2] = subgroupSize;
+  instDataB[bTy.getRank() - 1] = maxBLen;
+  SmallVector<int64_t> instDataCD(cdTy.getRank(), 1);
+  instDataCD[cdTy.getRank() - 2] = maxALen;
+  instDataCD[cdTy.getRank() - 1] = maxCLen;
+  return std::make_tuple(instDataA, instDataB, instDataCD);
+}
+
+/// Helper function to set up subgroup layouts for DPAS operands A, B, and C/D.
+/// Returns the three layouts if successful, nullopt otherwise.
+static std::optional<std::tuple<xegpu::DistributeLayoutAttr,
+                                xegpu::DistributeLayoutAttr,
+                                xegpu::DistributeLayoutAttr>>
+getupDpasSubgroupLayouts(mlir::MLIRContext *context, VectorType aTy,
+                         VectorType bTy, VectorType cdTy,
+                         xegpu::DistributeLayoutAttr consumerLayout, int numSg,
+                         const xegpu::uArch::uArch *uArch) {
+  auto instDataVecs = getDpasInstDataVectors(aTy, bTy, cdTy, uArch);
+  if (!instDataVecs)
+    return std::nullopt;
+  auto [instDataA, instDataB, instDataCD] = *instDataVecs;
+  assert(instDataA.size() == 2 && instDataB.size() == 2 &&
+         instDataCD.size() == 2 &&
+         "Sg layout creation expects valid 2D inst data");
+
+  std::optional<LayoutRepresentation> consumerSgLayout = std::nullopt;
+  if (consumerLayout && consumerLayout.isForWorkgroup()) {
+    SmallVector<int64_t> sgLayoutD =
+        consumerLayout.getEffectiveSgLayoutAsInt();
+    consumerSgLayout = std::make_pair(sgLayoutD[0], sgLayoutD[1]);
+  }
+
+  // Get all valid layouts for A, B and C/D operands
+  auto layoutsA = getValidLayouts(aTy.getShape(), instDataA, numSg);
+  auto layoutsB = getValidLayouts(bTy.getShape(), instDataB, numSg);
+  auto layoutsCD = getValidLayouts(cdTy.getShape(), instDataCD, numSg);
+  if (layoutsA.empty() || layoutsB.empty() || layoutsCD.empty())
+    return std::nullopt;
+
+  // Pick the best subgroup layout
+  llvm::DenseSet<LayoutRepresentation> setA(layoutsA.begin(), layoutsA.end());
+  llvm::DenseSet<LayoutRepresentation> setCD(layoutsCD.begin(),
+                                              layoutsCD.end());
+  std::optional<LayoutRepresentation> bestPick;
+  for (auto &sgLayout : layoutsB) {
+    if (setA.contains(sgLayout) && setCD.contains(sgLayout)) {
+      if (consumerSgLayout.has_value() && sgLayout == *consumerSgLayout) {
+        bestPick = sgLayout;
+        break;
+      }
+      if (!bestPick)
+        bestPick = sgLayout;
+    }
+  }
+  if (!bestPick)
+    return std::nullopt;
+
+  SmallVector<int> sgLayout = {static_cast<int>(bestPick->first),
+                               static_cast<int>(bestPick->second)};
+  SmallVector<int> sgDataA = {
+      static_cast<int>(aTy.getShape()[0] / sgLayout[0]),
+      static_cast<int>(aTy.getShape()[1] / sgLayout[1])};
+  SmallVector<int> sgDataB = {
+      static_cast<int>(bTy.getShape()[0] / sgLayout[0]),
+      static_cast<int>(bTy.getShape()[1] / sgLayout[1])};
+  SmallVector<int> sgDataCD = {
+      static_cast<int>(cdTy.getShape()[0] / sgLayout[0]),
+      static_cast<int>(cdTy.getShape()[1] / sgLayout[1])};
+
+  auto dpasALayout = xegpu::LayoutAttr::get(
+      context, DenseI32ArrayAttr::get(context, sgLayout),
+      DenseI32ArrayAttr::get(context, sgDataA), nullptr, nullptr, nullptr,
+      nullptr);
+  auto dpasBLayout = xegpu::LayoutAttr::get(
+      context, DenseI32ArrayAttr::get(context, sgLayout),
+      DenseI32ArrayAttr::get(context, sgDataB), nullptr, nullptr, nullptr,
+      nullptr);
+  auto dpasCDLayout = xegpu::LayoutAttr::get(
+      context, DenseI32ArrayAttr::get(context, sgLayout),
+      DenseI32ArrayAttr::get(context, sgDataCD), nullptr, nullptr, nullptr,
+      nullptr);
+
+  return std::make_tuple(dpasALayout, dpasBLayout, dpasCDLayout);
+}
+
 /// Sets up the anchor layouts for dpas operands (A, B, and C/D).
 /// The numSg and consumerLayout (optional) are only used by sg layout
 /// creation.
@@ -1194,122 +1310,13 @@ xegpu::setupDpasLayout(xegpu::LayoutKind layoutKind, VectorType aTy,
       dyn_cast<xegpu::uArch::SubgroupMatrixMultiplyAcc>(uArch->getInstruction(
           xegpu::uArch::InstructionKind::SubgroupMatrixMultiplyAcc));
 
-  auto getInstDataVectors = [&]()
-      -> std::optional<std::tuple<SmallVector<int64_t>, SmallVector<int64_t>,
-                                  SmallVector<int64_t>>> {
-    const int subgroupSize = uArch->getSubgroupSize();
-    const unsigned dataALen = aTy.getShape().front();
-    auto supportedALen = uArchInstruction->getSupportedM(aTy.getElementType());
-    const int maxALen =
-        xegpu::getLargestDivisor(dataALen, ArrayRef<unsigned>(supportedALen));
-
-    const unsigned dataBLen = bTy.getShape().back();
-    auto supportedBLen = uArchInstruction->getSupportedN(bTy.getElementType());
-    const int maxBLen =
-        xegpu::getLargestDivisor(dataBLen, ArrayRef<unsigned>(supportedBLen));
-
-    auto supportedCLen = uArchInstruction->getSupportedN(cdTy.getElementType());
-    const int maxCLen =
-        xegpu::getLargestDivisor(dataBLen, ArrayRef<unsigned>(supportedCLen));
-    if (maxALen == -1 || maxBLen == -1 || maxCLen == -1)
-      return std::nullopt;
-
-    SmallVector<int64_t> instDataA(aTy.getRank(), 1);
-    instDataA[aTy.getRank() - 2] = maxALen;
-    instDataA[aTy.getRank() - 1] = subgroupSize;
-    SmallVector<int64_t> instDataB(bTy.getRank(), 1);
-    instDataB[bTy.getRank() - 2] = subgroupSize;
-    instDataB[bTy.getRank() - 1] = maxBLen;
-    SmallVector<int64_t> instDataCD(cdTy.getRank(), 1);
-    instDataCD[cdTy.getRank() - 2] = maxALen;
-    instDataCD[cdTy.getRank() - 1] = maxCLen;
-    return std::make_tuple(instDataA, instDataB, instDataCD);
-  };
-
   if (layoutKind == xegpu::LayoutKind::Subgroup) {
     assert(numSg > 0 &&
            "Number of subgroups must be provided for sg layout creation.");
-    auto instDataVecs = getInstDataVectors();
-    if (!instDataVecs)
-      return std::nullopt;
-    auto [instDataA, instDataB, instDataCD] = *instDataVecs;
-    assert(instDataA.size() == 2 && instDataB.size() == 2 &&
-           instDataCD.size() == 2 &&
-           "Sg layout creation expects valid 2D inst data");
-
-    std::optional<LayoutRepresentation> consumerSgLayout = std::nullopt;
-    if (consumerLayout && consumerLayout.isForWorkgroup()) {
-      SmallVector<int64_t> sgLayoutD =
-          consumerLayout.getEffectiveSgLayoutAsInt();
-      consumerSgLayout = std::make_pair(sgLayoutD[0], sgLayoutD[1]);
-    }
-
-    // Step 1. Get all valid layouts for A, B and C/D operands.
-    // Order them from most balanced to least balanced.
-    auto layoutsA = getValidLayouts(aTy.getShape(), instDataA, numSg);
-    auto layoutsB = getValidLayouts(bTy.getShape(), instDataB, numSg);
-    auto layoutsCD = getValidLayouts(cdTy.getShape(), instDataCD, numSg);
-    if (layoutsA.empty() || layoutsB.empty() || layoutsCD.empty())
-      return std::nullopt;
-
-    // Step 2. If the consumer layout can be reused for all operands, that
-    // layout is chosen. Otherwise, pick the most balanced subgroup layout
-    // that is valid for A, B and C (if present) operands
-    llvm::DenseSet<LayoutRepresentation> setA(layoutsA.begin(), layoutsA.end());
-    llvm::DenseSet<LayoutRepresentation> setCD(layoutsCD.begin(),
-                                               layoutsCD.end());
-    std::optional<LayoutRepresentation> bestPick;
-    for (auto &sgLayout : layoutsB) {
-      if (setA.contains(sgLayout) && setCD.contains(sgLayout)) {
-        // Is in (A and B and CD) and matches consumer -> best pick
-        if (consumerSgLayout.has_value() && sgLayout == *consumerSgLayout) {
-          bestPick = sgLayout;
-          break;
-        }
-        // Is in (A and B and CD) layoutsB is ordered from most
-        // balanced to least. So the first one we see is the most balanced
-        // one, remember it and later only update if there is one that matches
-        // the consumer.
-        if (!bestPick)
-          bestPick = sgLayout;
-      }
-    }
-    // Step 3. If there is no subgroup layout compatible with A, B and C (if
-    // present) operands, we fail.
-    if (!bestPick)
-      return std::nullopt;
-    SmallVector<int> sgLayout = {static_cast<int>(bestPick->first),
-                                 static_cast<int>(bestPick->second)};
-    SmallVector<int> sgDataA = {
-        static_cast<int>(aTy.getShape()[0] / sgLayout[0]),
-        static_cast<int>(aTy.getShape()[1] / sgLayout[1])};
-    SmallVector<int> sgDataB = {
-        static_cast<int>(bTy.getShape()[0] / sgLayout[0]),
-        static_cast<int>(bTy.getShape()[1] / sgLayout[1])};
-    SmallVector<int> sgDataCD = {
-        static_cast<int>(cdTy.getShape()[0] / sgLayout[0]),
-        static_cast<int>(cdTy.getShape()[1] / sgLayout[1])};
-
-    auto dpasALayout = xegpu::LayoutAttr::get(
-        context, DenseI32ArrayAttr::get(context, sgLayout),
-        DenseI32ArrayAttr::get(context, sgDataA),
-        /*inst_data =*/nullptr, /*lane_layout =*/nullptr,
-        /*lane_data =*/nullptr, /*order =*/nullptr);
-
-    auto dpasBLayout = xegpu::LayoutAttr::get(
-        context, DenseI32ArrayAttr::get(context, sgLayout),
-        DenseI32ArrayAttr::get(context, sgDataB),
-        /*inst_data =*/nullptr, /*lane_layout =*/nullptr,
-        /*lane_data =*/nullptr, /*order =*/nullptr);
-
-    auto dpasCDLayout = xegpu::LayoutAttr::get(
-        context, DenseI32ArrayAttr::get(context, sgLayout),
-        DenseI32ArrayAttr::get(context, sgDataCD),
-        /*inst_data =*/nullptr, /*lane_layout =*/nullptr,
-        /*lane_data =*/nullptr, /*order =*/nullptr);
-    return std::make_tuple(dpasALayout, dpasBLayout, dpasCDLayout);
+    return getupDpasSubgroupLayouts(context, aTy, bTy, cdTy, consumerLayout,
+                                    numSg, uArch);
   } else if (layoutKind == xegpu::LayoutKind::InstData) {
-    auto instDataVecs = getInstDataVectors();
+    auto instDataVecs = getDpasInstDataVectors(aTy, bTy, cdTy, uArch);
     if (!instDataVecs)
       return std::nullopt;
     auto [instDataA, instDataB, instDataCD] = *instDataVecs;
@@ -1332,100 +1339,6 @@ xegpu::setupDpasLayout(xegpu::LayoutKind layoutKind, VectorType aTy,
   return std::nullopt;
 }
 
-xegpu::DistributeLayoutAttr
-xegpu::inferSourceLayoutFromResult(OpOperand &operand,
-                                   xegpu::DistributeLayoutAttr resLayout) {
-  if (!resLayout)
-    return nullptr;
-  Operation *op = operand.getOwner();
-  unsigned idx = operand.getOperandNumber();
-
-  // For vector::BroadcastOp, infer the source layout from the result layout.
-  if (auto broadcast = dyn_cast<vector::BroadcastOp>(op)) {
-    auto srcTy = dyn_cast<VectorType>(broadcast.getSourceType());
-    if (!srcTy)
-      return nullptr;
-    return xegpu::inferBroadcastSourceLayout(
-        resLayout, broadcast.getResultVectorType().getShape(),
-        srcTy.getShape());
-  }
-
-  // For vector::MultiDimReductionOp, infer source layout from result layout
-  // using reduction dims. Acc operand is expected to have the same layout as
-  // the result.
-  if (auto reduction = dyn_cast<vector::MultiDimReductionOp>(op)) {
-    if (idx == 0) {
-      SmallVector<int64_t> reductionDims(reduction.getReductionDims());
-      return xegpu::inferMultiReductionSourceLayout(resLayout, reductionDims);
-    }
-    if (idx == 1)
-      return resLayout;
-  }
-
-  if (auto reduction = dyn_cast<vector::ReductionOp>(op))
-    return xegpu::inferReductionSourceLayout(resLayout);
-
-  // For vector::BitCastOp, infer source layout from result layout using
-  // element type bitwidths.
-  if (auto bitcast = dyn_cast<vector::BitCastOp>(op)) {
-    int resElemBitWidth =
-        bitcast.getResultVectorType().getElementType().getIntOrFloatBitWidth();
-    int srcElemBitWidth =
-        bitcast.getSourceVectorType().getElementType().getIntOrFloatBitWidth();
-    return xegpu::inferBitCastSourceLayout(resLayout, resElemBitWidth,
-                                           srcElemBitWidth);
-  }
-
-  // For vector::ShapeCastOp, infer source layout from result layout using
-  // shapes.
-  if (auto shapeCast = dyn_cast<vector::ShapeCastOp>(op)) {
-    return xegpu::inferShapeCastSourceLayout(
-        resLayout, shapeCast.getResultVectorType().getShape(),
-        shapeCast.getSourceVectorType().getShape());
-  }
-
-  // For vector::InsertStridedSliceOp, infer source layout from result layout.
-  // Dest vector must have the same layout as the result.
-  if (auto insertSlice = dyn_cast<vector::InsertStridedSliceOp>(op)) {
-    if (idx == 0) {
-      return xegpu::inferInsertStridedSliceSourceLayout(
-          resLayout, insertSlice.getDestVectorType().getShape(),
-          insertSlice.getSourceVectorType().getShape());
-    }
-    if (idx == 1)
-      return resLayout;
-  }
-
-  // For vector::TransposeOp, infer source layout from result layout using
-  // permutation.
-  if (auto transpose = dyn_cast<vector::TransposeOp>(op)) {
-    return xegpu::inferTransposeSourceLayout(resLayout,
-                                             transpose.getPermutation());
-  }
-
-  // For vector::ExtractStridedSliceOp, simply return result layout
-  if (dyn_cast<vector::ExtractStridedSliceOp>(op))
-    return resLayout;
-  // For elementwise operations, all operands must have the same layout as the
-  // result.
-  if (OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() == 1)
-    return resLayout;
-
-  return nullptr;
-}
-
-xegpu::DistributeLayoutAttr xegpu::getConsumerLayoutAt(OpOperand &operand) {
-  Operation *op = operand.getOwner();
-  xegpu::DistributeLayoutAttr resLayout;
-  if (op->getNumResults() == 1)
-    resLayout = xegpu::getDistributeLayoutAttr(op->getResult(0));
-  auto inferredOperandLayout = inferSourceLayoutFromResult(operand, resLayout);
-  if (inferredOperandLayout)
-    return inferredOperandLayout;
-  // By default, assume no layout conflict and return the current layout of
-  // the operand.
-  return xegpu::getDistributeLayoutAttr(operand.get());
-}
 /// Sets up the anchor layouts for dpas_mx operands (A, B, C/D, A_scale, and
 /// B_scale). The numSg and consumerLayout (optional) are only used by sg layout
 /// creation.
@@ -1452,6 +1365,11 @@ xegpu::setupDpasMxLayout(xegpu::LayoutKind layoutKind, VectorType aTy,
     // Calculate scaling factor by dividing parent shape by scale shape
     ArrayRef<int64_t> parentShape = parentTy.getShape();
     ArrayRef<int64_t> scaleShape = scaleTy.getShape();
+
+    // Scale shapes may be 1D or 2D; reject rank-0 shapes so .back() is safe.
+    if (scaleShape.empty())
+      return nullptr;
+
     int64_t scaleFactor = parentShape.back() / scaleShape.back();
     int64_t rank = parentLayout.getRank();
     assert(rank == 2 && "dpas layouts must be two dimensions");
@@ -1467,19 +1385,38 @@ xegpu::setupDpasMxLayout(xegpu::LayoutKind layoutKind, VectorType aTy,
     // Divide last dimension by scaling factor
     if (!sgData.empty())
       sgData.back() = sgData.back() / scaleFactor;
-    if (!instData.empty())
-      instData.back() = instData.back() / scaleFactor;
-
-    if (isBScale) {
-      // For B scale: lane_layout = [min(subgroupSize, scaleFactor), 1]
-      // lane_data = [1, number of scale elements]
-      laneLayout[rank - 2] =
-          std::min(static_cast<int64_t>(subgroupSize), scaleShape[rank-2]);
-      laneLayout[rank - 1] = 1;
+
+    // For inst_data only layouts (no lane info), create a simple inst_data layout for scales
+    // The scale dimensions are much smaller, so we use the scale shape directly
+    if (!instData.empty() && laneLayout.empty() && laneData.empty()) {
+      // For inst_data layout, create simple inst_data for the scale
+      SmallVector<int64_t> scaleInstData(rank, 1);
+      if (scaleShape.size() >= 2) {
+        scaleInstData[rank - 2] = scaleShape[rank - 2];
+        scaleInstData[rank - 1] = scaleShape[rank - 1];
+      } else if (scaleShape.size() == 1) {
+        scaleInstData[rank - 1] = scaleShape[0];
+      }
+      return xegpu::LayoutAttr::get(
+          context, nullptr, nullptr,
+          DenseI32ArrayAttr::get(
+              context, SmallVector<int>(scaleInstData.begin(), scaleInstData.end())),
+          nullptr, nullptr, order);
+    }
+
+    // Handle lane layout for subgroup/lane layout kinds
+    if (!laneLayout.empty() && !laneData.empty()) {
+      // For scales, lane_layout should match the scale shape dimensions
+      // and lane_data should be [1, 1] since each lane holds one scale element
+      if (scaleShape.size() >= 2) {
+        laneLayout[rank - 2] = scaleShape[rank - 2];
+        laneLayout[rank - 1] = scaleShape[rank - 1];
+      } else if (scaleShape.size() == 1) {
+        laneLayout[rank - 2] = 1;
+        laneLayout[rank - 1] = scaleShape[0];
+      }
       laneData[rank - 2] = 1;
-      laneData[rank - 1] = scaleShape.back();
-    } else {
-      laneData.back() = scaleShape.back();
+      laneData[rank - 1] = 1;
     }
 
     return xegpu::LayoutAttr::get(
@@ -1577,3 +1514,98 @@ xegpu::setupDpasMxLayout(xegpu::LayoutKind layoutKind, VectorType aTy,
   }
   return std::nullopt;
 }
+
+xegpu::DistributeLayoutAttr
+xegpu::inferSourceLayoutFromResult(OpOperand &operand,
+                                   xegpu::DistributeLayoutAttr resLayout) {
+  if (!resLayout)
+    return nullptr;
+  Operation *op = operand.getOwner();
+  unsigned idx = operand.getOperandNumber();
+
+  // For vector::BroadcastOp, infer the source layout from the result layout.
+  if (auto broadcast = dyn_cast<vector::BroadcastOp>(op)) {
+    auto srcTy = dyn_cast<VectorType>(broadcast.getSourceType());
+    if (!srcTy)
+      return nullptr;
+    return xegpu::inferBroadcastSourceLayout(
+        resLayout, broadcast.getResultVectorType().getShape(),
+        srcTy.getShape());
+  }
+
+  // For vector::MultiDimReductionOp, infer source layout from result layout
+  // using reduction dims. Acc operand is expected to have the same layout as
+  // the result.
+  if (auto reduction = dyn_cast<vector::MultiDimReductionOp>(op)) {
+    if (idx == 0) {
+      SmallVector<int64_t> reductionDims(reduction.getReductionDims());
+      return xegpu::inferMultiReductionSourceLayout(resLayout, reductionDims);
+    }
+    if (idx == 1)
+      return resLayout;
+  }
+
+  if (auto reduction = dyn_cast<vector::ReductionOp>(op))
+    return xegpu::inferReductionSourceLayout(resLayout);
+
+  // For vector::BitCastOp, infer source layout from result layout using
+  // element type bitwidths.
+  if (auto bitcast = dyn_cast<vector::BitCastOp>(op)) {
+    int resElemBitWidth =
+        bitcast.getResultVectorType().getElementType().getIntOrFloatBitWidth();
+    int srcElemBitWidth =
+        bitcast.getSourceVectorType().getElementType().getIntOrFloatBitWidth();
+    return xegpu::inferBitCastSourceLayout(resLayout, resElemBitWidth,
+                                           srcElemBitWidth);
+  }
+
+  // For vector::ShapeCastOp, infer source layout from result layout using
+  // shapes.
+  if (auto shapeCast = dyn_cast<vector::ShapeCastOp>(op)) {
+    return xegpu::inferShapeCastSourceLayout(
+        resLayout, shapeCast.getResultVectorType().getShape(),
+        shapeCast.getSourceVectorType().getShape());
+  }
+
+  // For vector::InsertStridedSliceOp, infer source layout from result layout.
+  // Dest vector must have the same layout as the result.
+  if (auto insertSlice = dyn_cast<vector::InsertStridedSliceOp>(op)) {
+    if (idx == 0) {
+      return xegpu::inferInsertStridedSliceSourceLayout(
+          resLayout, insertSlice.getDestVectorType().getShape(),
+          insertSlice.getSourceVectorType().getShape());
+    }
+    if (idx == 1)
+      return resLayout;
+  }
+
+  // For vector::TransposeOp, infer source layout from result layout using
+  // permutation.
+  if (auto transpose = dyn_cast<vector::TransposeOp>(op)) {
+    return xegpu::inferTransposeSourceLayout(resLayout,
+                                             transpose.getPermutation());
+  }
+
+  // For vector::ExtractStridedSliceOp, simply return result layout
+  if (dyn_cast<vector::ExtractStridedSliceOp>(op))
+    return resLayout;
+  // For elementwise operations, all operands must have the same layout as the
+  // result.
+  if (OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() == 1)
+    return resLayout;
+
+  return nullptr;
+}
+
+xegpu::DistributeLayoutAttr xegpu::getConsumerLayoutAt(OpOperand &operand) {
+  Operation *op = operand.getOwner();
+  xegpu::DistributeLayoutAttr resLayout;
+  if (op->getNumResults() == 1)
+    resLayout = xegpu::getDistributeLayoutAttr(op->getResult(0));
+  auto inferredOperandLayout = inferSourceLayoutFromResult(operand, resLayout);
+  if (inferredOperandLayout)
+    return inferredOperandLayout;
+  // By default, assume no layout conflict and return the current layout of
+  // the operand.
+  return xegpu::getDistributeLayoutAttr(operand.get());
+}
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index f1709e95d34a9..f6df68c9ca0ae 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -810,6 +810,18 @@ void LayoutInfoPropagation::visitDpasOp(
     std::tie(requiredALayout, requiredBLayout, requiredCDLayoutAttr) = *layouts;
 
     dpas.setLayoutAAttr(requiredALayout);
+    dpas.setLayoutBAttr(requiredBLayout);
+    dpas.setLayoutCdAttr(requiredCDLayoutAttr);
+    dpasALayout = LayoutInfo(requiredALayout);
+    dpasBLayout = LayoutInfo(requiredBLayout);
+    dpasCDLayout = LayoutInfo(requiredCDLayoutAttr);
+  }
+  propagateIfChanged(operands[0], operands[0]->meet(dpasALayout));
+  propagateIfChanged(operands[1], operands[1]->meet(dpasBLayout));
+  if (operands.size() > 2)
+    propagateIfChanged(operands[2], operands[2]->meet(dpasCDLayout));
+}
+
 
 /// Propagate layout for DpasMxOp operands using the layout attributes.
 /// DpasMxOp has operands: a, b, acc (optional), scale_a (optional), scale_b (optional)
@@ -817,42 +829,111 @@ void LayoutInfoPropagation::visitDpasMxOp(
     xegpu::DpasMxOp dpasMx, ArrayRef<LayoutInfoLattice *> operands,
     ArrayRef<const LayoutInfoLattice *> results) {
 
-  // Get the layout attributes from the operation
-  xegpu::DistributeLayoutAttr layoutA = dpasMx.getLayoutAAttr();
-  xegpu::DistributeLayoutAttr layoutB = dpasMx.getLayoutBAttr();
-  xegpu::DistributeLayoutAttr layoutCD = dpasMx.getLayoutCdAttr();
-  xegpu::DistributeLayoutAttr layoutAScale = dpasMx.getLayoutAScaleAttr();
-  xegpu::DistributeLayoutAttr layoutBScale = dpasMx.getLayoutBScaleAttr();
+  // Initialize layout variables
+  LayoutInfo dpasMxALayout, dpasMxBLayout, dpasMxCDLayout;
+  LayoutInfo dpasMxAScaleLayout, dpasMxBScaleLayout;
+
+  // Get existing layout attributes from the operation
+  xegpu::DistributeLayoutAttr anchorLayoutA = dpasMx.getLayoutAAttr();
+  xegpu::DistributeLayoutAttr anchorLayoutB = dpasMx.getLayoutBAttr();
+  xegpu::DistributeLayoutAttr anchorLayoutCD = dpasMx.getLayoutCdAttr();
+
+  // Check if all layouts are already set
+  if (anchorLayoutA && anchorLayoutB && anchorLayoutCD &&
+      hasParamsOfLayoutKind(anchorLayoutA) &&
+      hasParamsOfLayoutKind(anchorLayoutB) &&
+      hasParamsOfLayoutKind(anchorLayoutCD)) {
+    dpasMxALayout = LayoutInfo(anchorLayoutA);
+    dpasMxBLayout = LayoutInfo(anchorLayoutB);
+    dpasMxCDLayout = LayoutInfo(anchorLayoutCD);
+
+    // Get scale layouts if available
+    xegpu::DistributeLayoutAttr anchorLayoutAScale = dpasMx.getLayoutAScaleAttr();
+    xegpu::DistributeLayoutAttr anchorLayoutBScale = dpasMx.getLayoutBScaleAttr();
+    if (anchorLayoutAScale)
+      dpasMxAScaleLayout = LayoutInfo(anchorLayoutAScale);
+    if (anchorLayoutBScale)
+      dpasMxBScaleLayout = LayoutInfo(anchorLayoutBScale);
+  } else {
+    // Need to compute layouts
+    const uArch *uArch = getUArch(getChipStr(dpasMx).value_or(""));
+    if (!uArch)
+      return;
 
-  // Propagate layouts to operands based on their positions:
-  // operands[0] = a, operands[1] = b, operands[2] = acc (optional),
-  // operands[3] = scale_a (optional), operands[4] = scale_b (optional)
+    VectorType aTy = dpasMx.getAType();
+    VectorType bTy = dpasMx.getBType();
+    VectorType cdTy = dpasMx.getResultType();
 
-  if (layoutA && operands.size() > 0)
-    propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(layoutA)));
+    // Get scale types if present
+    std::optional<VectorType> aScaleTy = std::nullopt;
+    std::optional<VectorType> bScaleTy = std::nullopt;
+    Value scaleA = dpasMx.getScaleA();
+    Value scaleB = dpasMx.getScaleB();
+    if (scaleA)
+      aScaleTy = dyn_cast<VectorType>(scaleA.getType());
+    if (scaleB)
+      bScaleTy = dyn_cast<VectorType>(scaleB.getType());
 
-  if (layoutB && operands.size() > 1)
-    propagateIfChanged(operands[1], operands[1]->meet(LayoutInfo(layoutB)));
+    xegpu::DistributeLayoutAttr consumerLayoutAttr = nullptr;
+    xegpu::DistributeLayoutAttr requiredCDLayoutAttr, requiredALayout,
+        requiredBLayout, requiredAScaleLayout, requiredBScaleLayout;
 
-  if (layoutCD && operands.size() > 2)
-    propagateIfChanged(operands[2], operands[2]->meet(LayoutInfo(layoutCD)));
+    int numSg = 0;
+    if (layoutKind == xegpu::LayoutKind::Subgroup) {
+      LayoutInfo consumerLayout = results[0]->getValue();
+      if (!consumerLayout.isAssigned())
+        return;
+      consumerLayoutAttr =
+          dyn_cast<xegpu::DistributeLayoutAttr>(consumerLayout.get());
+      auto numSgOrErr = getNumSg(dpasMx, uArch->getSubgroupSize());
+      if (failed(numSgOrErr)) {
+        dpasMx.emitWarning(
+            "Unable to determine the number of subgroups for the operation.");
+        return;
+      }
+      numSg = numSgOrErr.value();
+    }
 
-  if (layoutAScale && operands.size() > 3)
-    propagateIfChanged(operands[3], operands[3]->meet(LayoutInfo(layoutAScale)));
+    auto layouts = xegpu::setupDpasMxLayout(layoutKind, aTy, bTy, cdTy,
+                                            aScaleTy, bScaleTy,
+                                            consumerLayoutAttr, numSg, uArch);
+    if (!layouts.has_value()) {
+      dpasMx.emitWarning(
+          "Failed to determine required layouts for DPAS_MX operands.");
+      return;
+    }
 
-  if (layoutBScale && operands.size() > 4)
-    propagateIfChanged(operands[4], operands[4]->meet(LayoutInfo(layoutBScale)));
-}
-    dpas.setLayoutBAttr(requiredBLayout);
-    dpas.setLayoutCdAttr(requiredCDLayoutAttr);
-    dpasALayout = LayoutInfo(requiredALayout);
-    dpasBLayout = LayoutInfo(requiredBLayout);
-    dpasCDLayout = LayoutInfo(requiredCDLayoutAttr);
+    std::tie(requiredALayout, requiredBLayout, requiredCDLayoutAttr,
+             requiredAScaleLayout, requiredBScaleLayout) = *layouts;
+
+    dpasMx.setLayoutAAttr(requiredALayout);
+    dpasMx.setLayoutBAttr(requiredBLayout);
+    dpasMx.setLayoutCdAttr(requiredCDLayoutAttr);
+    if (requiredAScaleLayout)
+      dpasMx.setLayoutAScaleAttr(requiredAScaleLayout);
+    if (requiredBScaleLayout)
+      dpasMx.setLayoutBScaleAttr(requiredBScaleLayout);
+
+    dpasMxALayout = LayoutInfo(requiredALayout);
+    dpasMxBLayout = LayoutInfo(requiredBLayout);
+    dpasMxCDLayout = LayoutInfo(requiredCDLayoutAttr);
+    if (requiredAScaleLayout)
+      dpasMxAScaleLayout = LayoutInfo(requiredAScaleLayout);
+    if (requiredBScaleLayout)
+      dpasMxBScaleLayout = LayoutInfo(requiredBScaleLayout);
   }
-  propagateIfChanged(operands[0], operands[0]->meet(dpasALayout));
-  propagateIfChanged(operands[1], operands[1]->meet(dpasBLayout));
+
+  // Propagate layouts to operands
+  // operands[0] = a, operands[1] = b, operands[2] = acc (optional),
+  // operands[3] = scale_a (optional), operands[4] = scale_b (optional)
+  propagateIfChanged(operands[0], operands[0]->meet(dpasMxALayout));
+  propagateIfChanged(operands[1], operands[1]->meet(dpasMxBLayout));
   if (operands.size() > 2)
-    propagateIfChanged(operands[2], operands[2]->meet(dpasCDLayout));
+    propagateIfChanged(operands[2], operands[2]->meet(dpasMxCDLayout));
+  if (operands.size() > 3 && dpasMxAScaleLayout.isAssigned())
+    propagateIfChanged(operands[3], operands[3]->meet(dpasMxAScaleLayout));
+  if (operands.size() > 4 && dpasMxBScaleLayout.isAssigned())
+    propagateIfChanged(operands[4], operands[4]->meet(dpasMxBScaleLayout));
 }
 
 /// Set the layout for the value and tensor descriptor operands in StoreNdOp.
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir b/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
index 5a95185c8de48..891aac22eea2c 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
@@ -391,3 +391,50 @@ gpu.module @test{
     gpu.return
   }
 }
+// =============================================================================
+// dpas_mx with f8E5M2 inputs and f8E8M0FNU scales: checks that inst_data
+// layouts are propagated to A, B, C/D, and both scale operands.
+// =============================================================================
+
+// -----
+
+// CHECK-LABEL: func.func @dpas_mx_f8e5m2
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<16x64xf8E5M2>, %[[ARG1:[0-9a-zA-Z]+]]: memref<64x32xf8E5M2>, %[[ARG2:[0-9a-zA-Z]+]]: memref<16x32xbf16>
+// CHECK-SAME: %[[ARG3:[0-9a-zA-Z]+]]: memref<16x2xf8E8M0FNU>, %[[ARG4:[0-9a-zA-Z]+]]: memref<2x32xf8E8M0FNU>
+// CHECK: %[[CST:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [8, 16]>} dense<0.000000e+00> : vector<16x32xbf16>
+// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][{{.*}}] : memref<16x64xf8E5M2> -> !xegpu.tensor_desc<16x64xf8E5M2, #xegpu.layout<inst_data = [8, 16]>>
+// CHECK: %[[T1:.*]] = xegpu.create_nd_tdesc %[[ARG1]][{{.*}}] : memref<64x32xf8E5M2> -> !xegpu.tensor_desc<64x32xf8E5M2, #xegpu.layout<inst_data = [16, 16]>>
+// CHECK: %[[T2:.*]] = xegpu.load_nd %[[T0]] <{layout = #xegpu.layout<inst_data = [8, 16]>}> :
+// CHECK-SAME: !xegpu.tensor_desc<16x64xf8E5M2, #xegpu.layout<inst_data = [8, 16]>> -> vector<16x64xf8E5M2>
+// CHECK: %[[T3:.*]] = xegpu.load_nd %[[T1]] <{layout = #xegpu.layout<inst_data = [16, 16]>}> :
+// CHECK-SAME: !xegpu.tensor_desc<64x32xf8E5M2, #xegpu.layout<inst_data = [16, 16]>> -> vector<64x32xf8E5M2>
+// CHECK: %[[T4:.*]] = xegpu.create_nd_tdesc %[[ARG3]][{{.*}}] : memref<16x2xf8E8M0FNU> -> !xegpu.tensor_desc<16x2xf8E8M0FNU, #xegpu.layout<inst_data = [16, 2]>>
+// CHECK: %[[T5:.*]] = xegpu.load_nd %[[T4]] <{layout = #xegpu.layout<inst_data = [16, 2]>}> :
+// CHECK-SAME: !xegpu.tensor_desc<16x2xf8E8M0FNU, #xegpu.layout<inst_data = [16, 2]>> -> vector<16x2xf8E8M0FNU>
+// CHECK: %[[T6:.*]] = xegpu.create_nd_tdesc %[[ARG4]][{{.*}}] : memref<2x32xf8E8M0FNU> -> !xegpu.tensor_desc<2x32xf8E8M0FNU, #xegpu.layout<inst_data = [2, 32]>>
+// CHECK: %[[T7:.*]] = xegpu.load_nd %[[T6]] <{layout = #xegpu.layout<inst_data = [2, 32]>}> :
+// CHECK-SAME: !xegpu.tensor_desc<2x32xf8E8M0FNU, #xegpu.layout<inst_data = [2, 32]>> -> vector<2x32xf8E8M0FNU>
+// CHECK: %[[T8:.*]] = xegpu.dpas_mx %[[T2]], %[[T3]], %[[CST]] scale_a = %[[T5]] scale_b = %[[T7]]
+// CHECK-SAME: {layout_a = #xegpu.layout<inst_data = [8, 16]>, layout_a_scale = #xegpu.layout<inst_data = [16, 2]>, layout_b = #xegpu.layout<inst_data = [16, 16]>, layout_b_scale = #xegpu.layout<inst_data = [2, 32]>, layout_cd = #xegpu.layout<inst_data = [8, 16]>} :
+// CHECK-SAME: vector<16x64xf8E5M2>, vector<64x32xf8E5M2>, vector<16x32xbf16>, vector<16x2xf8E8M0FNU>, vector<2x32xf8E8M0FNU> -> vector<16x32xbf16>
+// CHECK: %[[T9:.*]] = xegpu.create_nd_tdesc %[[ARG2]][{{.*}}] : memref<16x32xbf16> -> !xegpu.tensor_desc<16x32xbf16, #xegpu.layout<inst_data = [8, 16]>>
+// CHECK: xegpu.store_nd %[[T8]], %[[T9]] <{layout = #xegpu.layout<inst_data = [8, 16]>}> : vector<16x32xbf16>, !xegpu.tensor_desc<16x32xbf16, #xegpu.layout<inst_data = [8, 16]>>
+gpu.module @test {
+func.func @dpas_mx_f8e5m2(%arg0: memref<16x64xf8E5M2>, %arg1: memref<64x32xf8E5M2>, %arg2: memref<16x32xbf16>,
+    %arg3: memref<16x2xf8E8M0FNU>, %arg4: memref<2x32xf8E8M0FNU>) {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant dense<0.000000e+00> : vector<16x32xbf16>
+  %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<16x64xf8E5M2> -> !xegpu.tensor_desc<16x64xf8E5M2>
+  %1 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<64x32xf8E5M2> -> !xegpu.tensor_desc<64x32xf8E5M2>
+  %2 = xegpu.load_nd %0  : !xegpu.tensor_desc<16x64xf8E5M2> -> vector<16x64xf8E5M2>
+  %3 = xegpu.load_nd %1  : !xegpu.tensor_desc<64x32xf8E5M2> -> vector<64x32xf8E5M2>
+  %4 = xegpu.create_nd_tdesc %arg3[%c0, %c0] : memref<16x2xf8E8M0FNU> -> !xegpu.tensor_desc<16x2xf8E8M0FNU>
+  %5 = xegpu.load_nd %4  : !xegpu.tensor_desc<16x2xf8E8M0FNU> -> vector<16x2xf8E8M0FNU>
+  %6 = xegpu.create_nd_tdesc %arg4[%c0, %c0] : memref<2x32xf8E8M0FNU> -> !xegpu.tensor_desc<2x32xf8E8M0FNU>
+  %7 = xegpu.load_nd %6  : !xegpu.tensor_desc<2x32xf8E8M0FNU> -> vector<2x32xf8E8M0FNU>
+  %8 = xegpu.dpas_mx %2, %3, %cst scale_a = %5 scale_b = %7 : vector<16x64xf8E5M2>, vector<64x32xf8E5M2>, vector<16x32xbf16>, vector<16x2xf8E8M0FNU>, vector<2x32xf8E8M0FNU> -> vector<16x32xbf16>
+  %9 = xegpu.create_nd_tdesc %arg2[%c0, %c0] : memref<16x32xbf16> -> !xegpu.tensor_desc<16x32xbf16>
+  xegpu.store_nd %8, %9  : vector<16x32xbf16>, !xegpu.tensor_desc<16x32xbf16>
+  return
+}
+}
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout.mlir b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
index c87dbf3ec2108..5cec3cf54b655 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
@@ -990,3 +990,45 @@ gpu.module @test{
     return
   }
 }
+
+// -----
+gpu.module @test {
+// CHECK-LABEL: func.func @dpas_mx_f8e5m2
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<8x32xf8E5M2>, %[[ARG1:[0-9a-zA-Z]+]]: memref<32x16xf8E5M2>, %[[ARG2:[0-9a-zA-Z]+]]: memref<8x16xbf16>
+// CHECK-SAME: %[[ARG3:[0-9a-zA-Z]+]]: memref<8x1xf8E8M0FNU>, %[[ARG4:[0-9a-zA-Z]+]]: memref<1x16xf8E8M0FNU>
+// CHECK: %[[CST:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} dense<0.000000e+00> : vector<8x16xbf16>
+// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][{{.*}}] : memref<8x32xf8E5M2> -> !xegpu.tensor_desc<8x32xf8E5M2, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>>
+// CHECK: %[[T1:.*]] = xegpu.create_nd_tdesc %[[ARG1]][{{.*}}] : memref<32x16xf8E5M2> -> !xegpu.tensor_desc<32x16xf8E5M2, #xegpu.layout<lane_layout = [1, 16], lane_data = [4, 1]>>
+// CHECK: %[[T2:.*]] = xegpu.load_nd %[[T0]] <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>}> :
+// CHECK-SAME: !xegpu.tensor_desc<8x32xf8E5M2, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>> -> vector<8x32xf8E5M2>
+// CHECK: %[[T3:.*]] = xegpu.load_nd %[[T1]] <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [4, 1]>}> :
+// CHECK-SAME: !xegpu.tensor_desc<32x16xf8E5M2, #xegpu.layout<lane_layout = [1, 16], lane_data = [4, 1]>> -> vector<32x16xf8E5M2>
+// CHECK: %[[T4:.*]] = xegpu.create_nd_tdesc %[[ARG3]][{{.*}}] : memref<8x1xf8E8M0FNU> -> !xegpu.tensor_desc<8x1xf8E8M0FNU, #xegpu.layout<lane_layout = [8, 1], lane_data = [1, 1]>>
+// CHECK: %[[T5:.*]] = xegpu.load_nd %[[T4]] <{layout = #xegpu.layout<lane_layout = [8, 1], lane_data = [1, 1]>}> :
+// CHECK-SAME: !xegpu.tensor_desc<8x1xf8E8M0FNU, #xegpu.layout<lane_layout = [8, 1], lane_data = [1, 1]>> -> vector<8x1xf8E8M0FNU>
+// CHECK: %[[T6:.*]] = xegpu.create_nd_tdesc %[[ARG4]][{{.*}}] : memref<1x16xf8E8M0FNU> -> !xegpu.tensor_desc<1x16xf8E8M0FNU, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+// CHECK: %[[T7:.*]] = xegpu.load_nd %[[T6]] <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}> :
+// CHECK-SAME: !xegpu.tensor_desc<1x16xf8E8M0FNU, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<1x16xf8E8M0FNU>
+// CHECK: %[[T8:.*]] = xegpu.dpas_mx %[[T2]], %[[T3]], %[[CST]] scale_a = %[[T5]] scale_b = %[[T7]]
+// CHECK-SAME: {layout_a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>, layout_a_scale = #xegpu.layout<lane_layout = [8, 1], lane_data = [1, 1]>, layout_b = #xegpu.layout<lane_layout = [1, 16], lane_data = [4, 1]>, layout_b_scale = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, layout_cd = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} :
+// CHECK-SAME: vector<8x32xf8E5M2>, vector<32x16xf8E5M2>, vector<8x16xbf16>, vector<8x1xf8E8M0FNU>, vector<1x16xf8E8M0FNU> -> vector<8x16xbf16>
+// CHECK: %[[T9:.*]] = xegpu.create_nd_tdesc %[[ARG2]][{{.*}}] : memref<8x16xbf16> -> !xegpu.tensor_desc<8x16xbf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+// CHECK: xegpu.store_nd %[[T8]], %[[T9]] <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}> : vector<8x16xbf16>, !xegpu.tensor_desc<8x16xbf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+func.func @dpas_mx_f8e5m2(%arg0: memref<8x32xf8E5M2>, %arg1: memref<32x16xf8E5M2>, %arg2: memref<8x16xbf16>,
+    %arg3: memref<8x1xf8E8M0FNU>, %arg4: memref<1x16xf8E8M0FNU>) {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant dense<0.000000e+00> : vector<8x16xbf16>
+  %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x32xf8E5M2> -> !xegpu.tensor_desc<8x32xf8E5M2>
+  %1 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<32x16xf8E5M2> -> !xegpu.tensor_desc<32x16xf8E5M2>
+  %2 = xegpu.load_nd %0  : !xegpu.tensor_desc<8x32xf8E5M2> -> vector<8x32xf8E5M2>
+  %3 = xegpu.load_nd %1  : !xegpu.tensor_desc<32x16xf8E5M2> -> vector<32x16xf8E5M2>
+  %4 = xegpu.create_nd_tdesc %arg3[%c0, %c0] : memref<8x1xf8E8M0FNU> -> !xegpu.tensor_desc<8x1xf8E8M0FNU>
+  %5 = xegpu.load_nd %4  : !xegpu.tensor_desc<8x1xf8E8M0FNU> -> vector<8x1xf8E8M0FNU>
+  %6 = xegpu.create_nd_tdesc %arg4[%c0, %c0] : memref<1x16xf8E8M0FNU> -> !xegpu.tensor_desc<1x16xf8E8M0FNU>
+  %7 = xegpu.load_nd %6  : !xegpu.tensor_desc<1x16xf8E8M0FNU> -> vector<1x16xf8E8M0FNU>
+  %8 = xegpu.dpas_mx %2, %3, %cst scale_a = %5 scale_b = %7 : vector<8x32xf8E5M2>, vector<32x16xf8E5M2>, vector<8x16xbf16>, vector<8x1xf8E8M0FNU>, vector<1x16xf8E8M0FNU> -> vector<8x16xbf16>
+  %9 = xegpu.create_nd_tdesc %arg2[%c0, %c0] : memref<8x16xbf16> -> !xegpu.tensor_desc<8x16xbf16>
+  xegpu.store_nd %8, %9  : vector<8x16xbf16>, !xegpu.tensor_desc<8x16xbf16>
+  return
+}
+}

>From 1b63a66b2882a7ee49b28e6249f3f548677fca9b Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Fri, 24 Apr 2026 15:01:10 +0000
Subject: [PATCH 3/7] add tests

---
 .../XeGPU/Transforms/XeGPULayoutImpl.h        |  6 ++
 .../XeGPU/Transforms/XeGPULayoutImpl.cpp      | 13 ++-
 mlir/test/Dialect/XeGPU/ops.mlir              |  6 +-
 .../XeGPU/propagate-layout-subgroup.mlir      | 82 +++++++++++++++++++
 mlir/test/Dialect/XeGPU/propagate-layout.mlir | 43 ++++++++++
 5 files changed, 144 insertions(+), 6 deletions(-)
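
Note on the scale-layout rules as they stand after this patch: the sketch below restates the three derivations in plain C++, with no MLIR dependencies, so the expected values in the new tests can be checked by hand. The function names (scaleSgData, scaleInstData, scaleLaneLayout, scaleLaneData) are hypothetical and do not exist in XeGPULayoutImpl.cpp. The divisibility assert is an assumption added for the sketch; the patch itself divides without checking. Only the 2D case is modeled.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

using Shape = std::vector<int64_t>;

// Subgroup kind: sg_data = scale_shape / sg_layout, per dimension.
// (Hypothetical helper; the divisibility check is an assumption of this sketch.)
Shape scaleSgData(const Shape &scaleShape, const Shape &sgLayout) {
  assert(scaleShape.size() == sgLayout.size());
  Shape sgData(scaleShape.size());
  for (std::size_t i = 0; i < scaleShape.size(); ++i) {
    assert(sgLayout[i] > 0 && scaleShape[i] % sgLayout[i] == 0);
    sgData[i] = scaleShape[i] / sgLayout[i];
  }
  return sgData;
}

// InstData kind: the scale tile is small, so inst_data is the scale shape.
Shape scaleInstData(const Shape &scaleShape) { return scaleShape; }

// Lane kind: lane_layout follows the scale shape ...
Shape scaleLaneLayout(const Shape &scaleShape) { return scaleShape; }

// ... and lane_data is all ones (one scale element per lane).
Shape scaleLaneData(const Shape &scaleShape) {
  return Shape(scaleShape.size(), 1);
}
```

With sg_layout = [8, 8], a 128x16 A scale gives sg_data = [16, 2] and a 16x256 B scale gives sg_data = [2, 32], which is what the CHECK lines in propagate-layout-subgroup.mlir expect. The lane and inst_data cases likewise reproduce the [8, 1] / [1, 16] and [16, 2] / [2, 32] layouts in the other two test files.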

diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h b/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h
index c68e7334434d6..5defec222347f 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h
@@ -39,6 +39,12 @@ LogicalResult propagateLayouts(OpBuilder &builder, Operation *target,
 
 LogicalResult resolveLayoutConflicts(Operation *target);
 
+/// [to-be-deprecated] Set the DistributeLayoutAttr for each OpOperand and
+/// OpResult of the given operation. If the operation contains regions, it is
+/// also applied recursively to the contained operations.
+/// TODO: To be replaced by recoverTemporaryLayouts()
+void recoverTemporaryLayoutsDeprecated(Operation *op);
+
 /// Attach layout attributes to all vector-type operands of operations within
 /// the given operation's nested region. Reports an error if any vector operand
 /// lacks a layout attribute.
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
index b435a9dc112f0..737596daedb23 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
@@ -1382,9 +1382,16 @@ xegpu::setupDpasMxLayout(xegpu::LayoutKind layoutKind, VectorType aTy,
     SmallVector<int64_t> laneData = parentLayout.getEffectiveLaneDataAsInt();
     auto order = parentLayout.getOrder();
 
-    // Divide last dimension by scaling factor
-    if (!sgData.empty())
-      sgData.back() = sgData.back() / scaleFactor;
+    // For subgroup layouts, compute sg_data based on scale shape / sg_layout
+    if (!sgLayout.empty() && !sgData.empty()) {
+      // sg_data = scale_shape / sg_layout
+      if (scaleShape.size() >= 2) {
+        sgData[rank - 2] = scaleShape[rank - 2] / sgLayout[rank - 2];
+        sgData[rank - 1] = scaleShape[rank - 1] / sgLayout[rank - 1];
+      } else if (scaleShape.size() == 1) {
+        sgData[rank - 1] = scaleShape[0] / sgLayout[rank - 1];
+      }
+    }
 
     // For inst_data only layouts (no lane info), create a simple inst_data layout for scales
     // The scale dimensions are much smaller, so we use the scale shape directly
diff --git a/mlir/test/Dialect/XeGPU/ops.mlir b/mlir/test/Dialect/XeGPU/ops.mlir
index b32e297b60fc8..faa85f78adf63 100644
--- a/mlir/test/Dialect/XeGPU/ops.mlir
+++ b/mlir/test/Dialect/XeGPU/ops.mlir
@@ -646,9 +646,9 @@ gpu.func @truncf(%a: vector<8x16xf16>) {
 }
 
 // CHECK-LABEL: gpu.func @dpas_mx
-gpu.func @dpas_mx(%a : vector<8x16xf8E5M2>, %b: vector<16x16xf8E5M2>, %acc: vector<8x16xbf16>) {
-  // CHECK: %{{.+}} = xegpu.dpas_mx %{{.+}}, %{{.+}}, %{{.+}} : vector<8x16xf8E5M2>, vector<16x16xf8E5M2>, vector<8x16xbf16> -> vector<8x16xbf16>
-  %1 = xegpu.dpas_mx %a, %b, %acc : vector<8x16xf8E5M2>, vector<16x16xf8E5M2>, vector<8x16xbf16> -> vector<8x16xbf16>
+gpu.func @dpas_mx(%a : vector<8x32xf8E5M2>, %b: vector<32x16xf8E5M2>, %acc: vector<8x16xbf16>, %a_scale: vector<8x1xf8E8M0FNU>, %b_scale: vector<1x16xf8E8M0FNU>) {
+  // CHECK: %{{.+}} = xegpu.dpas_mx %{{.+}}, %{{.+}}, %{{.+}} scale_a = %{{.+}} scale_b = %{{.+}} : vector<8x32xf8E5M2>, vector<32x16xf8E5M2>, vector<8x16xbf16>, vector<8x1xf8E8M0FNU>, vector<1x16xf8E8M0FNU> -> vector<8x16xbf16>
+  %1 = xegpu.dpas_mx %a, %b, %acc scale_a = %a_scale scale_b = %b_scale : vector<8x32xf8E5M2>, vector<32x16xf8E5M2>, vector<8x16xbf16>, vector<8x1xf8E8M0FNU>, vector<1x16xf8E8M0FNU> -> vector<8x16xbf16>
   gpu.return
 }
 
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir b/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir
index d8a07d7c85a6c..03cf4210700f1 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir
@@ -350,3 +350,85 @@ gpu.module @test {
     gpu.return
   }
 }
+
+// -----
+
+gpu.module @test {
+  // CHECK-LABEL: dpas_mx
+  gpu.func @dpas_mx(%arg0: memref<128x512xf8E5M2>, %arg1: memref<512x256xf8E5M2>, %arg2: memref<128x256xbf16>,
+      %arg3: memref<128x16xf8E8M0FNU>, %arg4: memref<16x256xf8E8M0FNU>) kernel attributes
+      {known_block_size = array<i32: 1, 64, 16>} {
+    %cst = arith.constant dense<0.000000e+00> : vector<128x256xbf16>
+    %tdesc_a = xegpu.create_nd_tdesc %arg0 : memref<128x512xf8E5M2> -> !xegpu.tensor_desc<128x512xf8E5M2>
+    %load_a =  xegpu.load_nd %tdesc_a : !xegpu.tensor_desc<128x512xf8E5M2> -> vector<128x512xf8E5M2>
+    %tdesc_b = xegpu.create_nd_tdesc %arg1 : memref<512x256xf8E5M2> -> !xegpu.tensor_desc<512x256xf8E5M2>
+    %load_b =  xegpu.load_nd %tdesc_b : !xegpu.tensor_desc<512x256xf8E5M2> -> vector<512x256xf8E5M2>
+    %tdesc_a_scale = xegpu.create_nd_tdesc %arg3 : memref<128x16xf8E8M0FNU> -> !xegpu.tensor_desc<128x16xf8E8M0FNU>
+    %load_a_scale =  xegpu.load_nd %tdesc_a_scale : !xegpu.tensor_desc<128x16xf8E8M0FNU> -> vector<128x16xf8E8M0FNU>
+    %tdesc_b_scale = xegpu.create_nd_tdesc %arg4 : memref<16x256xf8E8M0FNU> -> !xegpu.tensor_desc<16x256xf8E8M0FNU>
+    %load_b_scale =  xegpu.load_nd %tdesc_b_scale : !xegpu.tensor_desc<16x256xf8E8M0FNU> -> vector<16x256xf8E8M0FNU>
+    %dpas_mx = xegpu.dpas_mx %load_a, %load_b, %cst scale_a = %load_a_scale scale_b = %load_b_scale : vector<128x512xf8E5M2>, vector<512x256xf8E5M2>, vector<128x256xbf16>, vector<128x16xf8E8M0FNU>, vector<16x256xf8E8M0FNU> -> vector<128x256xbf16>
+    %tdesc_cd = xegpu.create_nd_tdesc %arg2 : memref<128x256xbf16> -> !xegpu.tensor_desc<128x256xbf16>
+    xegpu.store_nd %dpas_mx, %tdesc_cd[0, 0] <{layout = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 32]>}> : vector<128x256xbf16>, !xegpu.tensor_desc<128x256xbf16>
+    gpu.return
+  }
+  // CHECK: %[[CST:.*]] = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 32]>} dense<0.000000e+00> : vector<128x256xbf16>
+  // CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<128x512xf8E5M2> -> !xegpu.tensor_desc<128x512xf8E5M2, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 64]>>
+  // CHECK: %[[T1:.*]] = xegpu.load_nd %[[T0]] <{layout = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 64]>}> :
+  // CHECK-SAME: !xegpu.tensor_desc<128x512xf8E5M2, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 64]>> -> vector<128x512xf8E5M2>
+  // CHECK: %[[T2:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<512x256xf8E5M2> -> !xegpu.tensor_desc<512x256xf8E5M2, #xegpu.layout<sg_layout = [8, 8], sg_data = [64, 32]>>
+  // CHECK: %[[T3:.*]] = xegpu.load_nd %[[T2]] <{layout = #xegpu.layout<sg_layout = [8, 8], sg_data = [64, 32]>}> :
+  // CHECK-SAME: !xegpu.tensor_desc<512x256xf8E5M2, #xegpu.layout<sg_layout = [8, 8], sg_data = [64, 32]>> -> vector<512x256xf8E5M2>
+  // CHECK: %[[T4:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<128x16xf8E8M0FNU> -> !xegpu.tensor_desc<128x16xf8E8M0FNU, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 2]>>
+  // CHECK: %[[T5:.*]] = xegpu.load_nd %[[T4]] <{layout = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 2]>}> :
+  // CHECK-SAME: !xegpu.tensor_desc<128x16xf8E8M0FNU, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 2]>> -> vector<128x16xf8E8M0FNU>
+  // CHECK: %[[T6:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<16x256xf8E8M0FNU> -> !xegpu.tensor_desc<16x256xf8E8M0FNU, #xegpu.layout<sg_layout = [8, 8], sg_data = [2, 32]>>
+  // CHECK: %[[T7:.*]] = xegpu.load_nd %[[T6]] <{layout = #xegpu.layout<sg_layout = [8, 8], sg_data = [2, 32]>}> :
+  // CHECK-SAME: !xegpu.tensor_desc<16x256xf8E8M0FNU, #xegpu.layout<sg_layout = [8, 8], sg_data = [2, 32]>> -> vector<16x256xf8E8M0FNU>
+  // CHECK: %[[T8:.*]] = xegpu.dpas_mx %[[T1]], %[[T3]], %[[CST]] scale_a = %[[T5]] scale_b = %[[T7]]
+  // CHECK-SAME: {layout_a = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 64]>, layout_a_scale = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 2]>, layout_b = #xegpu.layout<sg_layout = [8, 8], sg_data = [64, 32]>, layout_b_scale = #xegpu.layout<sg_layout = [8, 8], sg_data = [2, 32]>, layout_cd = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 32]>} :
+  // CHECK-SAME: vector<128x512xf8E5M2>, vector<512x256xf8E5M2>, vector<128x256xbf16>, vector<128x16xf8E8M0FNU>, vector<16x256xf8E8M0FNU> -> vector<128x256xbf16>
+  // CHECK: %[[T9:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<128x256xbf16> -> !xegpu.tensor_desc<128x256xbf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 32]>>
+  // CHECK: xegpu.store_nd %[[T8]], %[[T9]][0, 0] <{layout = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 32]>}> : vector<128x256xbf16>, !xegpu.tensor_desc<128x256xbf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 32]>>
+}
+
+// -----
+
+gpu.module @test {
+  // CHECK-LABEL: dpas_mx_fp4
+  gpu.func @dpas_mx_fp4(%arg0: memref<128x512xf4E2M1FN>, %arg1: memref<512x128xf4E2M1FN>, %arg2: memref<128x128xf32>,
+      %arg3: memref<128x16xf8E8M0FNU>, %arg4: memref<16x128xf8E8M0FNU>) kernel attributes
+      {known_block_size = array<i32: 1, 64, 16>} {
+    %cst = arith.constant dense<0.000000e+00> : vector<128x128xf32>
+    %tdesc_a = xegpu.create_nd_tdesc %arg0 : memref<128x512xf4E2M1FN> -> !xegpu.tensor_desc<128x512xf4E2M1FN>
+    %load_a =  xegpu.load_nd %tdesc_a : !xegpu.tensor_desc<128x512xf4E2M1FN> -> vector<128x512xf4E2M1FN>
+    %tdesc_b = xegpu.create_nd_tdesc %arg1 : memref<512x128xf4E2M1FN> -> !xegpu.tensor_desc<512x128xf4E2M1FN>
+    %load_b =  xegpu.load_nd %tdesc_b : !xegpu.tensor_desc<512x128xf4E2M1FN> -> vector<512x128xf4E2M1FN>
+    %tdesc_a_scale = xegpu.create_nd_tdesc %arg3 : memref<128x16xf8E8M0FNU> -> !xegpu.tensor_desc<128x16xf8E8M0FNU>
+    %load_a_scale =  xegpu.load_nd %tdesc_a_scale : !xegpu.tensor_desc<128x16xf8E8M0FNU> -> vector<128x16xf8E8M0FNU>
+    %tdesc_b_scale = xegpu.create_nd_tdesc %arg4 : memref<16x128xf8E8M0FNU> -> !xegpu.tensor_desc<16x128xf8E8M0FNU>
+    %load_b_scale =  xegpu.load_nd %tdesc_b_scale : !xegpu.tensor_desc<16x128xf8E8M0FNU> -> vector<16x128xf8E8M0FNU>
+    %dpas_mx = xegpu.dpas_mx %load_a, %load_b, %cst scale_a = %load_a_scale scale_b = %load_b_scale : vector<128x512xf4E2M1FN>, vector<512x128xf4E2M1FN>, vector<128x128xf32>, vector<128x16xf8E8M0FNU>, vector<16x128xf8E8M0FNU> -> vector<128x128xf32>
+    %tdesc_cd = xegpu.create_nd_tdesc %arg2 : memref<128x128xf32> -> !xegpu.tensor_desc<128x128xf32>
+    xegpu.store_nd %dpas_mx, %tdesc_cd[0, 0] <{layout = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>}> : vector<128x128xf32>, !xegpu.tensor_desc<128x128xf32>
+    gpu.return
+  }
+  // CHECK: %[[CST:.*]] = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>} dense<0.000000e+00> : vector<128x128xf32>
+  // CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<128x512xf4E2M1FN> -> !xegpu.tensor_desc<128x512xf4E2M1FN, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 64]>>
+  // CHECK: %[[T1:.*]] = xegpu.load_nd %[[T0]] <{layout = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 64]>}> :
+  // CHECK-SAME: !xegpu.tensor_desc<128x512xf4E2M1FN, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 64]>> -> vector<128x512xf4E2M1FN>
+  // CHECK: %[[T2:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<512x128xf4E2M1FN> -> !xegpu.tensor_desc<512x128xf4E2M1FN, #xegpu.layout<sg_layout = [8, 8], sg_data = [64, 16]>>
+  // CHECK: %[[T3:.*]] = xegpu.load_nd %[[T2]] <{layout = #xegpu.layout<sg_layout = [8, 8], sg_data = [64, 16]>}> :
+  // CHECK-SAME: !xegpu.tensor_desc<512x128xf4E2M1FN, #xegpu.layout<sg_layout = [8, 8], sg_data = [64, 16]>> -> vector<512x128xf4E2M1FN>
+  // CHECK: %[[T4:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<128x16xf8E8M0FNU> -> !xegpu.tensor_desc<128x16xf8E8M0FNU, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 2]>>
+  // CHECK: %[[T5:.*]] = xegpu.load_nd %[[T4]] <{layout = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 2]>}> :
+  // CHECK-SAME: !xegpu.tensor_desc<128x16xf8E8M0FNU, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 2]>> -> vector<128x16xf8E8M0FNU>
+  // CHECK: %[[T6:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<16x128xf8E8M0FNU> -> !xegpu.tensor_desc<16x128xf8E8M0FNU, #xegpu.layout<sg_layout = [8, 8], sg_data = [2, 16]>>
+  // CHECK: %[[T7:.*]] = xegpu.load_nd %[[T6]] <{layout = #xegpu.layout<sg_layout = [8, 8], sg_data = [2, 16]>}> :
+  // CHECK-SAME: !xegpu.tensor_desc<16x128xf8E8M0FNU, #xegpu.layout<sg_layout = [8, 8], sg_data = [2, 16]>> -> vector<16x128xf8E8M0FNU>
+  // CHECK: %[[T8:.*]] = xegpu.dpas_mx %[[T1]], %[[T3]], %[[CST]] scale_a = %[[T5]] scale_b = %[[T7]]
+  // CHECK-SAME: {layout_a = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 64]>, layout_a_scale = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 2]>, layout_b = #xegpu.layout<sg_layout = [8, 8], sg_data = [64, 16]>, layout_b_scale = #xegpu.layout<sg_layout = [8, 8], sg_data = [2, 16]>, layout_cd = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>} :
+  // CHECK-SAME: vector<128x512xf4E2M1FN>, vector<512x128xf4E2M1FN>, vector<128x128xf32>, vector<128x16xf8E8M0FNU>, vector<16x128xf8E8M0FNU> -> vector<128x128xf32>
+  // CHECK: %[[T9:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<128x128xf32> -> !xegpu.tensor_desc<128x128xf32, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>>
+  // CHECK: xegpu.store_nd %[[T8]], %[[T9]][0, 0] <{layout = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>}> : vector<128x128xf32>, !xegpu.tensor_desc<128x128xf32, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>>
+}
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout.mlir b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
index 5cec3cf54b655..d6437a75ea9df 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
@@ -1032,3 +1032,46 @@ func.func @dpas_mx_f8e5m2(%arg0: memref<8x32xf8E5M2>, %arg1: memref<32x16xf8E5M2
   return
 }
 }
+
+// -----
+
+gpu.module @test {
+// CHECK-LABEL: func.func @dpas_mx_fp4
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<8x64xf4E2M1FN>, %[[ARG1:[0-9a-zA-Z]+]]: memref<64x16xf4E2M1FN>, %[[ARG2:[0-9a-zA-Z]+]]: memref<8x16xbf16>
+// CHECK-SAME: %[[ARG3:[0-9a-zA-Z]+]]: memref<8x2xf8E8M0FNU>, %[[ARG4:[0-9a-zA-Z]+]]: memref<2x16xf8E8M0FNU>
+// CHECK: %[[CST:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} dense<0.000000e+00> : vector<8x16xbf16>
+// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][{{.*}}] : memref<8x64xf4E2M1FN> -> !xegpu.tensor_desc<8x64xf4E2M1FN, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 4]>>
+// CHECK: %[[T1:.*]] = xegpu.create_nd_tdesc %[[ARG1]][{{.*}}] : memref<64x16xf4E2M1FN> -> !xegpu.tensor_desc<64x16xf4E2M1FN, #xegpu.layout<lane_layout = [1, 16], lane_data = [8, 1]>>
+// CHECK: %[[T2:.*]] = xegpu.load_nd %[[T0]] <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 4]>}> :
+// CHECK-SAME: !xegpu.tensor_desc<8x64xf4E2M1FN, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 4]>> -> vector<8x64xf4E2M1FN>
+// CHECK: %[[T3:.*]] = xegpu.load_nd %[[T1]] <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [8, 1]>}> :
+// CHECK-SAME: !xegpu.tensor_desc<64x16xf4E2M1FN, #xegpu.layout<lane_layout = [1, 16], lane_data = [8, 1]>> -> vector<64x16xf4E2M1FN>
+// CHECK: %[[T4:.*]] = xegpu.create_nd_tdesc %[[ARG3]][{{.*}}] : memref<8x2xf8E8M0FNU> -> !xegpu.tensor_desc<8x2xf8E8M0FNU, #xegpu.layout<lane_layout = [8, 2], lane_data = [1, 1]>>
+// CHECK: %[[T5:.*]] = xegpu.load_nd %[[T4]] <{layout = #xegpu.layout<lane_layout = [8, 2], lane_data = [1, 1]>}> :
+// CHECK-SAME: !xegpu.tensor_desc<8x2xf8E8M0FNU, #xegpu.layout<lane_layout = [8, 2], lane_data = [1, 1]>> -> vector<8x2xf8E8M0FNU>
+// CHECK: %[[T6:.*]] = xegpu.create_nd_tdesc %[[ARG4]][{{.*}}] : memref<2x16xf8E8M0FNU> -> !xegpu.tensor_desc<2x16xf8E8M0FNU, #xegpu.layout<lane_layout = [2, 16], lane_data = [1, 1]>>
+// CHECK: %[[T7:.*]] = xegpu.load_nd %[[T6]] <{layout = #xegpu.layout<lane_layout = [2, 16], lane_data = [1, 1]>}> :
+// CHECK-SAME: !xegpu.tensor_desc<2x16xf8E8M0FNU, #xegpu.layout<lane_layout = [2, 16], lane_data = [1, 1]>> -> vector<2x16xf8E8M0FNU>
+// CHECK: %[[T8:.*]] = xegpu.dpas_mx %[[T2]], %[[T3]], %[[CST]] scale_a = %[[T5]] scale_b = %[[T7]]
+// CHECK-SAME: {layout_a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 4]>, layout_a_scale = #xegpu.layout<lane_layout = [8, 2], lane_data = [1, 1]>, layout_b = #xegpu.layout<lane_layout = [1, 16], lane_data = [8, 1]>, layout_b_scale = #xegpu.layout<lane_layout = [2, 16], lane_data = [1, 1]>, layout_cd = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} :
+// CHECK-SAME: vector<8x64xf4E2M1FN>, vector<64x16xf4E2M1FN>, vector<8x16xbf16>, vector<8x2xf8E8M0FNU>, vector<2x16xf8E8M0FNU> -> vector<8x16xbf16>
+// CHECK: %[[T9:.*]] = xegpu.create_nd_tdesc %[[ARG2]][{{.*}}] : memref<8x16xbf16> -> !xegpu.tensor_desc<8x16xbf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+// CHECK: xegpu.store_nd %[[T8]], %[[T9]] <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}> : vector<8x16xbf16>, !xegpu.tensor_desc<8x16xbf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+func.func @dpas_mx_fp4(%arg0: memref<8x64xf4E2M1FN>, %arg1: memref<64x16xf4E2M1FN>, %arg2: memref<8x16xbf16>,
+    %arg3: memref<8x2xf8E8M0FNU>, %arg4: memref<2x16xf8E8M0FNU>) {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant dense<0.000000e+00> : vector<8x16xbf16>
+  %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x64xf4E2M1FN> -> !xegpu.tensor_desc<8x64xf4E2M1FN>
+  %1 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<64x16xf4E2M1FN> -> !xegpu.tensor_desc<64x16xf4E2M1FN>
+  %2 = xegpu.load_nd %0  : !xegpu.tensor_desc<8x64xf4E2M1FN> -> vector<8x64xf4E2M1FN>
+  %3 = xegpu.load_nd %1  : !xegpu.tensor_desc<64x16xf4E2M1FN> -> vector<64x16xf4E2M1FN>
+  %4 = xegpu.create_nd_tdesc %arg3[%c0, %c0] : memref<8x2xf8E8M0FNU> -> !xegpu.tensor_desc<8x2xf8E8M0FNU>
+  %5 = xegpu.load_nd %4  : !xegpu.tensor_desc<8x2xf8E8M0FNU> -> vector<8x2xf8E8M0FNU>
+  %6 = xegpu.create_nd_tdesc %arg4[%c0, %c0] : memref<2x16xf8E8M0FNU> -> !xegpu.tensor_desc<2x16xf8E8M0FNU>
+  %7 = xegpu.load_nd %6  : !xegpu.tensor_desc<2x16xf8E8M0FNU> -> vector<2x16xf8E8M0FNU>
+  %8 = xegpu.dpas_mx %2, %3, %cst scale_a = %5 scale_b = %7 : vector<8x64xf4E2M1FN>, vector<64x16xf4E2M1FN>, vector<8x16xbf16>, vector<8x2xf8E8M0FNU>, vector<2x16xf8E8M0FNU> -> vector<8x16xbf16>
+  %9 = xegpu.create_nd_tdesc %arg2[%c0, %c0] : memref<8x16xbf16> -> !xegpu.tensor_desc<8x16xbf16>
+  xegpu.store_nd %8, %9  : vector<8x16xbf16>, !xegpu.tensor_desc<8x16xbf16>
+  return
+}
+}

>From de2ec0c6afd6fb0b40c8be61b0f097ed3ffaddff Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Fri, 24 Apr 2026 21:17:24 +0000
Subject: [PATCH 4/7] [XeGPU] Fix propagate-layout-inst-data test to use CRI
 chip for dpas_mx

The test file includes a dpas_mx operation test case which requires
the SubgroupScaledMatrixMultiplyAcc instruction. This instruction is
only registered in the CRI (Crescent Island) architecture, not PVC
(Ponte Vecchio).

Changed the RUN line from chip=pvc to chip=cri to match the instruction
requirements.

Fixes the assertion failure "Instruction not found in registry".
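
The failure mode can be illustrated with a minimal, self-contained sketch.
It is not the uArch API: chipRegistry and hasInstruction are hypothetical
names, and the per-chip contents only mirror what this commit message states
(the scaled instruction is registered for cri but not for pvc).

```cpp
#include <map>
#include <string>
#include <vector>

enum class InstructionKind {
  SubgroupMatrixMultiplyAcc,
  SubgroupScaledMatrixMultiplyAcc
};

// Hypothetical per-chip registry; the contents are an assumption taken from
// the commit message, not read from the real uArch tables.
inline const std::map<std::string, std::vector<InstructionKind>> &
chipRegistry() {
  static const std::map<std::string, std::vector<InstructionKind>> registry = {
      {"pvc", {InstructionKind::SubgroupMatrixMultiplyAcc}},
      {"cri",
       {InstructionKind::SubgroupMatrixMultiplyAcc,
        InstructionKind::SubgroupScaledMatrixMultiplyAcc}}};
  return registry;
}

// Returns true if the chip registers the instruction kind. A lookup that
// asserts on a miss, as the failing test did, would trip where this returns
// false.
inline bool hasInstruction(const std::string &chip, InstructionKind kind) {
  auto it = chipRegistry().find(chip);
  if (it == chipRegistry().end())
    return false;
  for (InstructionKind k : it->second)
    if (k == kind)
      return true;
  return false;
}
```

Under these assumptions, a dpas_mx test run with chip=pvc misses the lookup,
while the same test with chip=cri finds the instruction.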

Co-Authored-By: Claude Sonnet 4.5 <noreply at anthropic.com>
---
 mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir b/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
index 891aac22eea2c..d1c3b663b52c6 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -xevm-attach-target='chip=pvc' -test-xegpu-propagate-layouts="layout-kind=inst" -split-input-file %s | FileCheck %s
+// RUN: mlir-opt -xevm-attach-target='chip=cri' -test-xegpu-propagate-layouts="layout-kind=inst" -split-input-file %s | FileCheck %s
 
 
 // CHECK-LABEL: func.func @load_store_no_array_len(

>From 99d4b1d01fbdf6c9a6efb9fe792fcf08eb7b0380 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Fri, 24 Apr 2026 22:36:44 +0000
Subject: [PATCH 5/7] fix issues
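
A minimal sketch of the K-dimension rule that getSupportedK implements in
this patch, and of how the scale operand shapes used in the tests follow from
it. It assumes the third constructor argument in dpasMxInst{16, 32, 32} is
the number of K elements that share one scale, which is consistent with the
K/32 rule in the PR description. supportedK, scaleShapeA, and scaleShapeB are
hypothetical names, not functions from the patch.

```cpp
#include <cstdint>
#include <utility>

// Assumed scale block size: one scale value per 32 elements along K.
constexpr uint32_t kScaleFactor = 32;

// Mirrors the bit-width switch in getSupportedK: FP4 -> K=64, FP8 -> K=32.
// Returns 0 for widths where the patch would hit llvm_unreachable.
inline uint32_t supportedK(unsigned bitWidth) {
  switch (bitWidth) {
  case 4:
    return 64;
  case 8:
    return 32;
  default:
    return 0;
  }
}

// Scale shape for an M x K tile of A: K shrinks by kScaleFactor.
inline std::pair<uint32_t, uint32_t> scaleShapeA(uint32_t m, uint32_t k) {
  return {m, k / kScaleFactor};
}

// Scale shape for a K x N tile of B: K shrinks by kScaleFactor.
inline std::pair<uint32_t, uint32_t> scaleShapeB(uint32_t k, uint32_t n) {
  return {k / kScaleFactor, n};
}
```

This matches the shapes in the tests: vector<8x64xf4E2M1FN> pairs with an 8x2
scale, vector<8x32xf8E5M2> pairs with an 8x1 scale, and a subgroup tile with
sg_data = [16, 64] pairs with scale sg_data = [16, 2].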

---
 .../mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h    | 201 +++++++++++++++++-
 .../mlir/Dialect/XeGPU/uArch/uArchBase.h      |  17 +-
 .../XeGPU/Transforms/XeGPULayoutImpl.cpp      | 196 +++++++++--------
 .../XeGPU/propagate-layout-inst-data.mlir     |  14 +-
 4 files changed, 313 insertions(+), 115 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
index 595965d414840..2c76019549c22 100644
--- a/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
@@ -209,12 +209,65 @@ struct SubgroupMatrixMultiplyAcc : public Instruction,
 
   unsigned getPackedFormatBitSizeA() const { return packedFormatBitSizeA; }
   unsigned getPackedFormatBitSizeB() const { return packedFormatBitSizeB; }
+  bool isLaneLayoutRowMajorOrder() const override { return true; }
 
 protected:
   const unsigned packedFormatBitSizeA;
   const unsigned packedFormatBitSizeB;
 };
 
+struct SubgroupScaledMatrixMultiplyAcc : public Instruction,
+                                         public MMAInstructionInterface {
+  SubgroupScaledMatrixMultiplyAcc(unsigned packedFormatBitSizeA,
+                                  unsigned packedFormatBitSizeB,
+                                  unsigned scaleFactor)
+      : Instruction(InstructionKind::SubgroupScaledMatrixMultiplyAcc,
+                    InstructionScope::Subgroup),
+        packedFormatBitSizeA(packedFormatBitSizeA),
+        packedFormatBitSizeB(packedFormatBitSizeB), scaleFactor(scaleFactor) {}
+  static bool classof(const Instruction *B) {
+    return B->getInstructionKind() ==
+           InstructionKind::SubgroupScaledMatrixMultiplyAcc;
+  }
+  // Source:
+  // https://github.com/intel/llvm/blob/sycl/sycl/doc/design/spirv-extensions/SPV_INTEL_subgroup_scaled_matrix_multiply_accumulate.asciidoc
+
+  // Override all virtuals from MMAInstructionInterface
+  virtual llvm::SmallVector<std::pair<uint32_t, uint32_t>, 16>
+  getSupportedShapes(Type dataType, MMAOpndKind matrixType) override;
+  virtual llvm::SmallVector<Type, 8>
+  getSupportedTypes(MLIRContext &context, MMAOpndKind matrixType) override;
+  virtual bool
+  checkSupportedShapesAndTypes(std::pair<uint32_t, uint32_t> AShape,
+                               std::pair<uint32_t, uint32_t> BShape,
+                               std::pair<uint32_t, uint32_t> CShape,
+                               std::pair<uint32_t, uint32_t> DShape, Type AType,
+                               Type BType, Type CType, Type DType) override;
+  virtual bool checkSupportedTypes(Type AType, Type BType, Type CType,
+                                   Type DType) override;
+  virtual bool validate(std::pair<uint32_t, uint32_t> AShape,
+                        std::pair<uint32_t, uint32_t> BShape,
+                        std::pair<uint32_t, uint32_t> CShape,
+                        std::pair<uint32_t, uint32_t> DShape, Type AType,
+                        Type BType, Type CType, Type DType) override;
+  virtual llvm::SmallVector<uint32_t, 8>
+  getSupportedM(Type type) const override;
+  virtual llvm::SmallVector<uint32_t, 8>
+  getSupportedK(Type type) const override;
+  virtual llvm::SmallVector<uint32_t, 8>
+  getSupportedN(Type type) const override;
+
+  unsigned getPackedFormatBitSizeA() const { return packedFormatBitSizeA; }
+  unsigned getPackedFormatBitSizeB() const { return packedFormatBitSizeB; }
+  unsigned getScaleFactor() const { return scaleFactor; }
+  bool isLaneLayoutRowMajorOrder() const override { return true; }
+
+protected:
+  const unsigned packedFormatBitSizeA;
+  const unsigned packedFormatBitSizeB;
+  const unsigned scaleFactor;
+};
+
 struct SpirvLoadGatherInstruction : public LoadGatherInstructionInterface {
   int32_t getMaxLaneLoadSize(int32_t bitWidth) const override { return 16; }
 };
@@ -230,6 +283,7 @@ struct SpirvStoreScatterInstruction : public StoreScatterInstructionInterface {
 struct PVCuArch final : public Xe2Plus {
   static llvm::ArrayRef<const Instruction *> getInstructionRegistryArr() {
     static const SubgroupMatrixMultiplyAcc dpasInst{16, 32};
+    static const SubgroupScaledMatrixMultiplyAcc dpasMxInst{16, 32, 32};
     static const Subgroup2DBlockLoadInstruction loadNdInst;
     static const Subgroup2DBlockStoreInstruction storeNdInst;
     static const Subgroup2DBlockPrefetchInstruction prefetchNdInst;
@@ -282,14 +336,15 @@ struct BMGuArch : public Xe2Plus {
 struct CRIuArch : public Xe2Plus {
   static llvm::ArrayRef<const Instruction *> getInstructionRegistryArr() {
     static const SubgroupMatrixMultiplyAcc dpasInst{16, 32};
+    static const SubgroupScaledMatrixMultiplyAcc dpasMxInst{16, 32, 32};
     static const Subgroup2DBlockLoadInstruction loadNdInst;
     static const Subgroup2DBlockStoreInstruction storeNdInst;
     static const Subgroup2DBlockPrefetchInstruction prefetchNdInst;
     static const SpirvStoreScatterInstruction storeScatterInst;
     static const SpirvLoadGatherInstruction loadGatherInst;
-    static const Instruction *arr[] = {&dpasInst,         &loadNdInst,
-                                       &storeNdInst,      &prefetchNdInst,
-                                       &storeScatterInst, &loadGatherInst};
+    static const Instruction *arr[] = {
+        &dpasInst,       &dpasMxInst,       &loadNdInst,    &storeNdInst,
+        &prefetchNdInst, &storeScatterInst, &loadGatherInst};
     return arr;
   }
 
@@ -477,4 +532,144 @@ SubgroupMatrixMultiplyAcc::getSupportedN(Type type) const {
   return {16};
 }
 
+//===----------------------------------------------------------------------===//
+// SubgroupScaledMatrixMultiplyAcc implementations
+//===----------------------------------------------------------------------===//
+
+inline llvm::SmallVector<std::pair<uint32_t, uint32_t>, 16>
+SubgroupScaledMatrixMultiplyAcc::getSupportedShapes(Type dataType,
+                                                    MMAOpndKind matrixType) {
+  auto combineVectors = [](const llvm::SmallVector<uint32_t, 8> &a,
+                           const llvm::SmallVector<uint32_t, 8> &b)
+      -> llvm::SmallVector<std::pair<uint32_t, uint32_t>, 16> {
+    llvm::SmallVector<std::pair<uint32_t, uint32_t>, 16> result;
+    for (unsigned x : a) {
+      for (unsigned y : b) {
+        result.emplace_back(x, y);
+      }
+    }
+    return result;
+  };
+
+  // Avoid calling getSupportedK for C/D types (which are f32/bf16
+  // and not valid for the K-dimension bit-width calculation).
+  switch (matrixType) {
+  case MMAOpndKind::MatrixA:
+    return combineVectors(getSupportedM(dataType), getSupportedK(dataType));
+  case MMAOpndKind::MatrixB:
+    return combineVectors(getSupportedK(dataType), getSupportedN(dataType));
+  case MMAOpndKind::MatrixC:
+  case MMAOpndKind::MatrixD:
+    return combineVectors(getSupportedM(dataType), getSupportedN(dataType));
+  }
+  return {};
+}
+
+inline llvm::SmallVector<Type, 8>
+SubgroupScaledMatrixMultiplyAcc::getSupportedTypes(MLIRContext &context,
+                                                   MMAOpndKind matrixType) {
+  Type f8E4M3FNType = Float8E4M3FNType::get(&context);
+  Type f8E5M2Type = Float8E5M2Type::get(&context);
+  Type f4E2M1FNType = Float4E2M1FNType::get(&context);
+  Type bf16Type = BFloat16Type::get(&context);
+  Type f32Type = Float32Type::get(&context);
+
+  switch (matrixType) {
+  case MMAOpndKind::MatrixA:
+    return {f8E4M3FNType, f8E5M2Type, f4E2M1FNType};
+  case MMAOpndKind::MatrixB:
+    return {f8E4M3FNType, f8E5M2Type, f4E2M1FNType};
+  case MMAOpndKind::MatrixC:
+    return {bf16Type, f32Type};
+  case MMAOpndKind::MatrixD:
+    return {bf16Type, f32Type};
+  }
+  return {};
+}
+
+inline bool SubgroupScaledMatrixMultiplyAcc::checkSupportedTypes(Type AType,
+                                                                 Type BType,
+                                                                 Type CType,
+                                                                 Type DType) {
+  auto isSupportedLowPrecision = [](Type t) {
+    return t.isF8E4M3FN() || t.isF8E5M2() || llvm::isa<Float4E2M1FNType>(t);
+  };
+  auto isSupportedAccum = [](Type t) { return t.isF32() || t.isBF16(); };
+
+  if (!isSupportedLowPrecision(AType) || !isSupportedLowPrecision(BType)) {
+    LDBG() << "Unsupported scaled dpas: A and B must be FP8 or FP4 types.";
+    return false;
+  }
+
+  // A and B must have the same bit width for K dimension compatibility.
+  if (AType.getIntOrFloatBitWidth() != BType.getIntOrFloatBitWidth()) {
+    LDBG() << "Unsupported scaled dpas: A and B must have the same bit width.";
+    return false;
+  }
+
+  if (CType && !isSupportedAccum(CType)) {
+    LDBG() << "Unsupported scaled dpas: C must be f32 or bf16.";
+    return false;
+  }
+
+  if (!isSupportedAccum(DType)) {
+    LDBG() << "Unsupported scaled dpas: D must be f32 or bf16.";
+    return false;
+  }
+
+  return true;
+}
+
+inline bool SubgroupScaledMatrixMultiplyAcc::checkSupportedShapesAndTypes(
+    std::pair<uint32_t, uint32_t> AShape, std::pair<uint32_t, uint32_t> BShape,
+    std::pair<uint32_t, uint32_t> CShape, std::pair<uint32_t, uint32_t> DShape,
+    Type AType, Type BType, Type CType, Type DType) {
+  auto supportedAShapes = getSupportedShapes(AType, MMAOpndKind::MatrixA);
+  auto supportedBShapes = getSupportedShapes(BType, MMAOpndKind::MatrixB);
+  auto supportedCShapes = getSupportedShapes(CType, MMAOpndKind::MatrixC);
+  auto supportedDShapes = getSupportedShapes(DType, MMAOpndKind::MatrixD);
+  return llvm::is_contained(supportedAShapes, AShape) &&
+         llvm::is_contained(supportedBShapes, BShape) &&
+         llvm::is_contained(supportedCShapes, CShape) &&
+         llvm::is_contained(supportedDShapes, DShape) &&
+         checkSupportedTypes(AType, BType, CType, DType);
+}
+
+inline bool SubgroupScaledMatrixMultiplyAcc::validate(
+    std::pair<uint32_t, uint32_t> AShape, std::pair<uint32_t, uint32_t> BShape,
+    std::pair<uint32_t, uint32_t> CShape, std::pair<uint32_t, uint32_t> DShape,
+    Type AType, Type BType, Type CType, Type DType) {
+  return checkSupportedShapesAndTypes(AShape, BShape, CShape, DShape, AType,
+                                      BType, CType, DType);
+}
+
+inline llvm::SmallVector<uint32_t, 8>
+SubgroupScaledMatrixMultiplyAcc::getSupportedM(Type type) const {
+  return {8};
+}
+
+inline llvm::SmallVector<uint32_t, 8>
+SubgroupScaledMatrixMultiplyAcc::getSupportedK(Type type) const {
+  assert(type.isIntOrFloat() && "Matrix type must be int or float");
+  auto bitWidth = type.getIntOrFloatBitWidth();
+  uint32_t kSize = 0;
+  switch (bitWidth) {
+  case 4:
+    kSize = 64; // FP4: scale K by 4 (base 16-bit K=16 -> 64)
+    break;
+  case 8:
+    kSize = 32; // FP8: scale K by 2 (base 16-bit K=16 -> 32)
+    break;
+  default:
+    llvm_unreachable("Scaled dpas only supports FP8 (8-bit) and FP4 (4-bit) "
+                     "types for A/B matrices");
+  }
+  return {kSize};
+}
+
+inline llvm::SmallVector<uint32_t, 8>
+SubgroupScaledMatrixMultiplyAcc::getSupportedN(Type type) const {
+  return {16};
+}
+
 #endif // MLIR_DIALECT_XEGPU_UARCH_INTELGPUXE2_H
diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
index 0f9d052e11147..147a56a52c188 100644
--- a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
@@ -36,11 +36,14 @@ enum class InstructionScope { Lane, Subgroup, Workgroup, Cluster };
 enum class InstructionKind {
   SubgroupMatrixMultiplyAcc, // Dot Product Accumulate Systolic (DPAS) is a
                              // matrix multiply-add operation
-  Subgroup2DBlockStore,      // Subgroup-level 2D block write instruction
-  Subgroup2DBlockLoad,       // Subgroup-level 2D block load instruction
-  Subgroup2DBlockPrefetch,   // Subgroup-level 2D block prefetch instruction
-  StoreScatter,              // Lane-level store (scalar, vector)
-  LoadGather,                // Lane-level load (scalar, vector)
+  SubgroupScaledMatrixMultiplyAcc, // Scaled Matrix Multiply Accumulate is a
+                                   // DPAS with scaling factor applied to
+                                   // operand A or B before multiplication
+  Subgroup2DBlockStore,            // Subgroup-level 2D block write instruction
+  Subgroup2DBlockLoad,             // Subgroup-level 2D block load instruction
+  Subgroup2DBlockPrefetch, // Subgroup-level 2D block prefetch instruction
+  StoreScatter,            // Lane-level store (scalar, vector)
+  LoadGather,              // Lane-level load (scalar, vector)
   // @TODO: Add more instructions as needed
 };
 
@@ -61,6 +64,8 @@ struct Instruction {
     switch (instKind) {
     case InstructionKind::SubgroupMatrixMultiplyAcc:
       return "dpas";
+    case InstructionKind::SubgroupScaledMatrixMultiplyAcc:
+      return "dpas_mx";
     case InstructionKind::Subgroup2DBlockStore:
       return "store_nd";
     case InstructionKind::Subgroup2DBlockLoad:
@@ -246,7 +251,7 @@ struct MMAInstructionInterface {
   virtual llvm::SmallVector<uint32_t, 8> getSupportedM(Type type) const = 0;
   virtual llvm::SmallVector<uint32_t, 8> getSupportedK(Type type) const = 0;
   virtual llvm::SmallVector<uint32_t, 8> getSupportedN(Type type) const = 0;
-
+  virtual bool isLaneLayoutRowMajorOrder() const = 0;
   virtual ~MMAInstructionInterface() = default;
 };
 
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
index 737596daedb23..12b5c9d66b4ad 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
@@ -1184,11 +1184,19 @@ getValidLayouts(ArrayRef<int64_t> wgShape, ArrayRef<int64_t> instData,
 static std::optional<std::tuple<SmallVector<int64_t>, SmallVector<int64_t>,
                                 SmallVector<int64_t>>>
 getDpasInstDataVectors(VectorType aTy, VectorType bTy, VectorType cdTy,
-                       const xegpu::uArch::uArch *uArch) {
+                       const xegpu::uArch::uArch *uArch,
+                       bool isDpasMx = false) {
   const int subgroupSize = uArch->getSubgroupSize();
-  const auto *uArchInstruction =
-      dyn_cast<xegpu::uArch::SubgroupMatrixMultiplyAcc>(uArch->getInstruction(
-          xegpu::uArch::InstructionKind::SubgroupMatrixMultiplyAcc));
+
+  const xegpu::uArch::MMAInstructionInterface *uArchInstruction;
+  if (isDpasMx)
+    uArchInstruction = dyn_cast<xegpu::uArch::SubgroupScaledMatrixMultiplyAcc>(
+        uArch->getInstruction(
+            xegpu::uArch::InstructionKind::SubgroupScaledMatrixMultiplyAcc));
+  else
+    uArchInstruction =
+        dyn_cast<xegpu::uArch::SubgroupMatrixMultiplyAcc>(uArch->getInstruction(
+            xegpu::uArch::InstructionKind::SubgroupMatrixMultiplyAcc));
 
   const unsigned dataALen = aTy.getShape().front();
   auto supportedALen = uArchInstruction->getSupportedM(aTy.getElementType());
@@ -1206,11 +1214,19 @@ getDpasInstDataVectors(VectorType aTy, VectorType bTy, VectorType cdTy,
   if (maxALen == -1 || maxBLen == -1 || maxCLen == -1)
     return std::nullopt;
 
+  // For DPAS_MX, use getSupportedK to get the scaled K dimension
+  int kDimSize = subgroupSize;
+  if (isDpasMx) {
+    auto supportedKLen = uArchInstruction->getSupportedK(aTy.getElementType());
+    if (!supportedKLen.empty())
+      kDimSize = supportedKLen[0];
+  }
+
   SmallVector<int64_t> instDataA(aTy.getRank(), 1);
   instDataA[aTy.getRank() - 2] = maxALen;
-  instDataA[aTy.getRank() - 1] = subgroupSize;
+  instDataA[aTy.getRank() - 1] = kDimSize;
   SmallVector<int64_t> instDataB(bTy.getRank(), 1);
-  instDataB[bTy.getRank() - 2] = subgroupSize;
+  instDataB[bTy.getRank() - 2] = kDimSize;
   instDataB[bTy.getRank() - 1] = maxBLen;
   SmallVector<int64_t> instDataCD(cdTy.getRank(), 1);
   instDataCD[cdTy.getRank() - 2] = maxALen;
@@ -1220,9 +1236,9 @@ getDpasInstDataVectors(VectorType aTy, VectorType bTy, VectorType cdTy,
 
 /// Helper function to set up subgroup layouts for DPAS operands A, B, and C/D.
 /// Returns the three layouts if successful, nullopt otherwise.
-static std::optional<std::tuple<xegpu::DistributeLayoutAttr,
-                                xegpu::DistributeLayoutAttr,
-                                xegpu::DistributeLayoutAttr>>
+static std::optional<
+    std::tuple<xegpu::DistributeLayoutAttr, xegpu::DistributeLayoutAttr,
+               xegpu::DistributeLayoutAttr>>
 getupDpasSubgroupLayouts(mlir::MLIRContext *context, VectorType aTy,
                          VectorType bTy, VectorType cdTy,
                          xegpu::DistributeLayoutAttr consumerLayout, int numSg,
@@ -1237,8 +1253,7 @@ getupDpasSubgroupLayouts(mlir::MLIRContext *context, VectorType aTy,
 
   std::optional<LayoutRepresentation> consumerSgLayout = std::nullopt;
   if (consumerLayout && consumerLayout.isForWorkgroup()) {
-    SmallVector<int64_t> sgLayoutD =
-        consumerLayout.getEffectiveSgLayoutAsInt();
+    SmallVector<int64_t> sgLayoutD = consumerLayout.getEffectiveSgLayoutAsInt();
     consumerSgLayout = std::make_pair(sgLayoutD[0], sgLayoutD[1]);
   }
 
@@ -1252,7 +1267,7 @@ getupDpasSubgroupLayouts(mlir::MLIRContext *context, VectorType aTy,
   // Pick the best subgroup layout
   llvm::DenseSet<LayoutRepresentation> setA(layoutsA.begin(), layoutsA.end());
   llvm::DenseSet<LayoutRepresentation> setCD(layoutsCD.begin(),
-                                              layoutsCD.end());
+                                             layoutsCD.end());
   std::optional<LayoutRepresentation> bestPick;
   for (auto &sgLayout : layoutsB) {
     if (setA.contains(sgLayout) && setCD.contains(sgLayout)) {
@@ -1279,18 +1294,18 @@ getupDpasSubgroupLayouts(mlir::MLIRContext *context, VectorType aTy,
       static_cast<int>(cdTy.getShape()[0] / sgLayout[0]),
       static_cast<int>(cdTy.getShape()[1] / sgLayout[1])};
 
-  auto dpasALayout = xegpu::LayoutAttr::get(
-      context, DenseI32ArrayAttr::get(context, sgLayout),
-      DenseI32ArrayAttr::get(context, sgDataA), nullptr, nullptr, nullptr,
-      nullptr);
-  auto dpasBLayout = xegpu::LayoutAttr::get(
-      context, DenseI32ArrayAttr::get(context, sgLayout),
-      DenseI32ArrayAttr::get(context, sgDataB), nullptr, nullptr, nullptr,
-      nullptr);
-  auto dpasCDLayout = xegpu::LayoutAttr::get(
-      context, DenseI32ArrayAttr::get(context, sgLayout),
-      DenseI32ArrayAttr::get(context, sgDataCD), nullptr, nullptr, nullptr,
-      nullptr);
+  auto dpasALayout =
+      xegpu::LayoutAttr::get(context, DenseI32ArrayAttr::get(context, sgLayout),
+                             DenseI32ArrayAttr::get(context, sgDataA), nullptr,
+                             nullptr, nullptr, nullptr);
+  auto dpasBLayout =
+      xegpu::LayoutAttr::get(context, DenseI32ArrayAttr::get(context, sgLayout),
+                             DenseI32ArrayAttr::get(context, sgDataB), nullptr,
+                             nullptr, nullptr, nullptr);
+  auto dpasCDLayout =
+      xegpu::LayoutAttr::get(context, DenseI32ArrayAttr::get(context, sgLayout),
+                             DenseI32ArrayAttr::get(context, sgDataCD), nullptr,
+                             nullptr, nullptr, nullptr);
 
   return std::make_tuple(dpasALayout, dpasBLayout, dpasCDLayout);
 }
@@ -1353,109 +1368,91 @@ xegpu::setupDpasMxLayout(xegpu::LayoutKind layoutKind, VectorType aTy,
                          xegpu::DistributeLayoutAttr consumerLayout, int numSg,
                          const xegpu::uArch::uArch *uArch) {
   auto context = aTy.getContext();
-  const int subgroupSize = uArch->getSubgroupSize();
 
-  // Helper to create scale layout from parent layout
-  auto createScaleLayout = [&](VectorType parentTy, VectorType scaleTy,
-                               xegpu::DistributeLayoutAttr parentLayout,
+  // Helper to create scale layout from matrix layout
+  auto createScaleLayout = [&](VectorType matrixTy, VectorType scaleTy,
+                               xegpu::DistributeLayoutAttr matrixLayout,
                                bool isBScale) -> xegpu::DistributeLayoutAttr {
-    if (!scaleTy || !parentLayout)
+    if (!scaleTy || !matrixLayout)
       return nullptr;
 
-    // Calculate scaling factor by dividing parent shape by scale shape
-    ArrayRef<int64_t> parentShape = parentTy.getShape();
+    // Calculate scaling factor by dividing matrix shape by scale shape
+    ArrayRef<int64_t> matrixShape = matrixTy.getShape();
     ArrayRef<int64_t> scaleShape = scaleTy.getShape();
 
     // Scale shapes can be 1D or 2D, handle both cases
     if (scaleShape.empty())
       return nullptr;
 
-    int64_t scaleFactor = parentShape.back() / scaleShape.back();
-    int64_t rank = parentLayout.getRank();
+    auto uArchInstruction = dyn_cast<
+        xegpu::uArch::SubgroupScaledMatrixMultiplyAcc>(uArch->getInstruction(
+        xegpu::uArch::InstructionKind::SubgroupScaledMatrixMultiplyAcc));
+    int64_t scaleFactor = uArchInstruction->getScaleFactor();
+
+    int64_t rank = matrixLayout.getRank();
     assert(rank == 2 && "dpas layouts must be two dimensions");
 
-    SmallVector<int64_t> sgLayout = parentLayout.getEffectiveSgLayoutAsInt();
-    SmallVector<int64_t> sgData = parentLayout.getEffectiveSgDataAsInt();
-    SmallVector<int64_t> instData = parentLayout.getEffectiveInstDataAsInt();
+    SmallVector<int64_t> sgLayout = matrixLayout.getEffectiveSgLayoutAsInt();
+    SmallVector<int64_t> sgData = matrixLayout.getEffectiveSgDataAsInt();
+    SmallVector<int64_t> instData = matrixLayout.getEffectiveInstDataAsInt();
     SmallVector<int64_t> laneLayout =
-        parentLayout.getEffectiveLaneLayoutAsInt();
-    SmallVector<int64_t> laneData = parentLayout.getEffectiveLaneDataAsInt();
-    auto order = parentLayout.getOrder();
+        matrixLayout.getEffectiveLaneLayoutAsInt();
+    SmallVector<int64_t> laneData = matrixLayout.getEffectiveLaneDataAsInt();
+    auto order = matrixLayout.getOrder();
+
+    SmallVector<int> scaleSgLayout(sgLayout.begin(), sgLayout.end());
+    SmallVector<int> scaleSgData(sgData.begin(), sgData.end());
 
-    // For subgroup layouts, compute sg_data based on scale shape / sg_layout
     if (!sgLayout.empty() && !sgData.empty()) {
-      // sg_data = scale_shape / sg_layout
-      if (scaleShape.size() >= 2) {
-        sgData[rank - 2] = scaleShape[rank - 2] / sgLayout[rank - 2];
-        sgData[rank - 1] = scaleShape[rank - 1] / sgLayout[rank - 1];
-      } else if (scaleShape.size() == 1) {
-        sgData[rank - 1] = scaleShape[0] / sgLayout[rank - 1];
-      }
+      scaleSgData[rank - 2] =
+          std::max<int64_t>(scaleShape[rank - 2] / sgLayout[rank - 2], 1);
+      scaleSgData[rank - 1] =
+          std::max<int64_t>(scaleShape[rank - 1] / sgLayout[rank - 1], 1);
     }
 
-    // For inst_data only layouts (no lane info), create a simple inst_data layout for scales
-    // The scale dimensions are much smaller, so we use the scale shape directly
-    if (!instData.empty() && laneLayout.empty() && laneData.empty()) {
-      // For inst_data layout, create simple inst_data for the scale
-      SmallVector<int64_t> scaleInstData(rank, 1);
-      if (scaleShape.size() >= 2) {
-        scaleInstData[rank - 2] = scaleShape[rank - 2];
-        scaleInstData[rank - 1] = scaleShape[rank - 1];
-      } else if (scaleShape.size() == 1) {
-        scaleInstData[rank - 1] = scaleShape[0];
-      }
-      return xegpu::LayoutAttr::get(
-          context, nullptr, nullptr,
-          DenseI32ArrayAttr::get(
-              context, SmallVector<int>(scaleInstData.begin(), scaleInstData.end())),
-          nullptr, nullptr, order);
+    SmallVector<int> scaleInstData(matrixShape.begin(), matrixShape.end());
+    if (!instData.empty()) {
+      if (isBScale)
+        scaleInstData[rank - 2] =
+            std::max<int64_t>(matrixShape[rank - 2] / scaleFactor, 1);
+      else
+        scaleInstData[rank - 1] =
+            std::max<int64_t>(matrixShape[rank - 1] / scaleFactor, 1);
     }
 
-    // Handle lane layout for subgroup/lane layout kinds
+    SmallVector<int> scaleLaneLayout(laneLayout.begin(), laneLayout.end());
+    SmallVector<int> scaleLaneData(laneData.begin(), laneData.end());
+
     if (!laneLayout.empty() && !laneData.empty()) {
-      // For scales, lane_layout should match the scale shape dimensions
-      // and lane_data should be [1, 1] since each lane holds one scale element
-      if (scaleShape.size() >= 2) {
-        laneLayout[rank - 2] = scaleShape[rank - 2];
-        laneLayout[rank - 1] = scaleShape[rank - 1];
-      } else if (scaleShape.size() == 1) {
-        laneLayout[rank - 2] = 1;
-        laneLayout[rank - 1] = scaleShape[0];
-      }
-      laneData[rank - 2] = 1;
-      laneData[rank - 1] = 1;
+      bool order = uArchInstruction->isLaneLayoutRowMajorOrder();
+      if (isBScale ^ order)
+        std::swap(scaleLaneLayout[rank - 2], scaleLaneLayout[rank - 1]);
+
+      scaleLaneData[rank - 2] =
+          std::max<int64_t>(scaleShape[rank - 2] / laneLayout[rank - 2], 1);
+      scaleLaneData[rank - 1] =
+          std::max<int64_t>(scaleShape[rank - 1] / laneLayout[rank - 1], 1);
     }
-
     return xegpu::LayoutAttr::get(
         context,
-        sgLayout.empty()
-            ? nullptr
-            : DenseI32ArrayAttr::get(
-                  context, SmallVector<int>(sgLayout.begin(), sgLayout.end())),
-        sgData.empty()
-            ? nullptr
-            : DenseI32ArrayAttr::get(
-                  context, SmallVector<int>(sgData.begin(), sgData.end())),
-        instData.empty()
+        scaleSgLayout.empty() ? nullptr
+                              : DenseI32ArrayAttr::get(context, scaleSgLayout),
+        scaleSgData.empty() ? nullptr
+                            : DenseI32ArrayAttr::get(context, scaleSgData),
+        scaleInstData.empty() ? nullptr
+                              : DenseI32ArrayAttr::get(context, scaleInstData),
+        scaleLaneLayout.empty()
             ? nullptr
-            : DenseI32ArrayAttr::get(
-                  context, SmallVector<int>(instData.begin(), instData.end())),
-        laneLayout.empty() ? nullptr
-                           : DenseI32ArrayAttr::get(
-                                 context, SmallVector<int>(laneLayout.begin(),
-                                                           laneLayout.end())),
-        laneData.empty()
-            ? nullptr
-            : DenseI32ArrayAttr::get(
-                  context, SmallVector<int>(laneData.begin(), laneData.end())),
+            : DenseI32ArrayAttr::get(context, scaleLaneLayout),
+        scaleLaneData.empty() ? nullptr
+                              : DenseI32ArrayAttr::get(context, scaleLaneData),
         order);
   };
-
   if (layoutKind == xegpu::LayoutKind::Subgroup) {
     assert(numSg > 0 &&
            "Number of subgroups must be provided for sg layout creation.");
-    auto dpasLayouts =
-        getupDpasSubgroupLayouts(context, aTy, bTy, cdTy, consumerLayout, numSg, uArch);
+    auto dpasLayouts = getupDpasSubgroupLayouts(context, aTy, bTy, cdTy,
+                                                consumerLayout, numSg, uArch);
     if (!dpasLayouts)
       return std::nullopt;
 
@@ -1474,7 +1471,8 @@ xegpu::setupDpasMxLayout(xegpu::LayoutKind layoutKind, VectorType aTy,
     return std::make_tuple(dpasALayout, dpasBLayout, dpasCDLayout, aScaleLayout,
                            bScaleLayout);
   } else if (layoutKind == xegpu::LayoutKind::InstData) {
-    auto instDataVecs = getDpasInstDataVectors(aTy, bTy, cdTy, uArch);
+    auto instDataVecs =
+        getDpasInstDataVectors(aTy, bTy, cdTy, uArch, /*isDpasMx=*/true);
     if (!instDataVecs)
       return std::nullopt;
     auto [instDataA, instDataB, instDataCD] = *instDataVecs;
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir b/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
index d1c3b663b52c6..12e975eb1d78d 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
@@ -402,12 +402,12 @@ gpu.module @test{
 // CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<16x64xf8E5M2>, %[[ARG1:[0-9a-zA-Z]+]]: memref<64x32xf8E5M2>, %[[ARG2:[0-9a-zA-Z]+]]: memref<16x32xbf16>
 // CHECK-SAME: %[[ARG3:[0-9a-zA-Z]+]]: memref<16x2xf8E8M0FNU>, %[[ARG4:[0-9a-zA-Z]+]]: memref<2x32xf8E8M0FNU>
 // CHECK: %[[CST:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [8, 16]>} dense<0.000000e+00> : vector<16x32xbf16>
-// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][{{.*}}] : memref<16x64xf8E5M2> -> !xegpu.tensor_desc<16x64xf8E5M2, #xegpu.layout<inst_data = [8, 16]>>
-// CHECK: %[[T1:.*]] = xegpu.create_nd_tdesc %[[ARG1]][{{.*}}] : memref<64x32xf8E5M2> -> !xegpu.tensor_desc<64x32xf8E5M2, #xegpu.layout<inst_data = [16, 16]>>
-// CHECK: %[[T2:.*]] = xegpu.load_nd %[[T0]] <{layout = #xegpu.layout<inst_data = [8, 16]>}> :
-// CHECK-SAME: !xegpu.tensor_desc<16x64xf8E5M2, #xegpu.layout<inst_data = [8, 16]>> -> vector<16x64xf8E5M2>
-// CHECK: %[[T3:.*]] = xegpu.load_nd %[[T1]] <{layout = #xegpu.layout<inst_data = [16, 16]>}> :
-// CHECK-SAME: !xegpu.tensor_desc<64x32xf8E5M2, #xegpu.layout<inst_data = [16, 16]>> -> vector<64x32xf8E5M2>
+// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][{{.*}}] : memref<16x64xf8E5M2> -> !xegpu.tensor_desc<16x64xf8E5M2, #xegpu.layout<inst_data = [8, 32]>>
+// CHECK: %[[T1:.*]] = xegpu.create_nd_tdesc %[[ARG1]][{{.*}}] : memref<64x32xf8E5M2> -> !xegpu.tensor_desc<64x32xf8E5M2, #xegpu.layout<inst_data = [32, 16]>>
+// CHECK: %[[T2:.*]] = xegpu.load_nd %[[T0]] <{layout = #xegpu.layout<inst_data = [8, 32]>}> :
+// CHECK-SAME: !xegpu.tensor_desc<16x64xf8E5M2, #xegpu.layout<inst_data = [8, 32]>> -> vector<16x64xf8E5M2>
+// CHECK: %[[T3:.*]] = xegpu.load_nd %[[T1]] <{layout = #xegpu.layout<inst_data = [32, 16]>}> :
+// CHECK-SAME: !xegpu.tensor_desc<64x32xf8E5M2, #xegpu.layout<inst_data = [32, 16]>> -> vector<64x32xf8E5M2>
 // CHECK: %[[T4:.*]] = xegpu.create_nd_tdesc %[[ARG3]][{{.*}}] : memref<16x2xf8E8M0FNU> -> !xegpu.tensor_desc<16x2xf8E8M0FNU, #xegpu.layout<inst_data = [16, 2]>>
 // CHECK: %[[T5:.*]] = xegpu.load_nd %[[T4]] <{layout = #xegpu.layout<inst_data = [16, 2]>}> :
 // CHECK-SAME: !xegpu.tensor_desc<16x2xf8E8M0FNU, #xegpu.layout<inst_data = [16, 2]>> -> vector<16x2xf8E8M0FNU>
@@ -415,7 +415,7 @@ gpu.module @test{
 // CHECK: %[[T7:.*]] = xegpu.load_nd %[[T6]] <{layout = #xegpu.layout<inst_data = [2, 32]>}> :
 // CHECK-SAME: !xegpu.tensor_desc<2x32xf8E8M0FNU, #xegpu.layout<inst_data = [2, 32]>> -> vector<2x32xf8E8M0FNU>
 // CHECK: %[[T8:.*]] = xegpu.dpas_mx %[[T2]], %[[T3]], %[[CST]] scale_a = %[[T5]] scale_b = %[[T7]]
-// CHECK-SAME: {layout_a = #xegpu.layout<inst_data = [8, 16]>, layout_a_scale = #xegpu.layout<inst_data = [16, 2]>, layout_b = #xegpu.layout<inst_data = [16, 16]>, layout_b_scale = #xegpu.layout<inst_data = [2, 32]>, layout_cd = #xegpu.layout<inst_data = [8, 16]>} :
+// CHECK-SAME: {layout_a = #xegpu.layout<inst_data = [8, 32]>, layout_a_scale = #xegpu.layout<inst_data = [16, 2]>, layout_b = #xegpu.layout<inst_data = [32, 16]>, layout_b_scale = #xegpu.layout<inst_data = [2, 32]>, layout_cd = #xegpu.layout<inst_data = [8, 16]>} :
 // CHECK-SAME: vector<16x64xf8E5M2>, vector<64x32xf8E5M2>, vector<16x32xbf16>, vector<16x2xf8E8M0FNU>, vector<2x32xf8E8M0FNU> -> vector<16x32xbf16>
 // CHECK: %[[T9:.*]] = xegpu.create_nd_tdesc %[[ARG2]][{{.*}}] : memref<16x32xbf16> -> !xegpu.tensor_desc<16x32xbf16, #xegpu.layout<inst_data = [8, 16]>>
 // CHECK: xegpu.store_nd %[[T8]], %[[T9]] <{layout = #xegpu.layout<inst_data = [8, 16]>}> : vector<16x32xbf16>, !xegpu.tensor_desc<16x32xbf16, #xegpu.layout<inst_data = [8, 16]>>

>From 94e78c74971b5aeba8d182ec32b4db7fb7fc14ba Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Sat, 25 Apr 2026 05:04:25 +0000
Subject: [PATCH 6/7] add tests
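
For reference, the scale-layout arithmetic that the updated tests exercise can
be sketched in isolation. This is only an illustration, assuming 2D shapes; the
names Shape2D, scaleTile, scaleInstData and scaleSgData are hypothetical and do
not exist in XeGPULayoutImpl.cpp, and the lane-layout handling is not covered.

```cpp
#include <algorithm>
#include <array>
#include <cassert>
#include <cstdint>

// Hypothetical stand-in for the 2D shapes / layout vectors used by the patch.
using Shape2D = std::array<int64_t, 2>;

// Scale tile size along one dimension: the scale extent divided by the number
// of matrix tiles along that dimension, clamped to at least 1.
int64_t scaleTile(int64_t scaleDim, int64_t matrixDim, int64_t matrixTile) {
  assert(matrixTile > 0 && matrixDim >= matrixTile);
  return std::max<int64_t>(scaleDim / (matrixDim / matrixTile), 1);
}

// inst_data for a scale operand: start from the matrix inst_data and rescale
// only the K dimension (dim 1 for the A scale, dim 0 for the B scale).
Shape2D scaleInstData(Shape2D matrixShape, Shape2D scaleShape,
                      Shape2D instData, bool isBScale) {
  Shape2D result = instData;
  int k = isBScale ? 0 : 1;
  result[k] = scaleTile(scaleShape[k], matrixShape[k], instData[k]);
  return result;
}

// sg_data for a scale operand: both dimensions are rescaled.
Shape2D scaleSgData(Shape2D matrixShape, Shape2D scaleShape, Shape2D sgData) {
  return {scaleTile(scaleShape[0], matrixShape[0], sgData[0]),
          scaleTile(scaleShape[1], matrixShape[1], sgData[1])};
}
```

With the f8E5M2 test shapes (A 16x64 with inst_data [8, 32] and a 16x2 scale,
B 64x32 with inst_data [32, 16] and a 2x32 scale) this sketch yields [8, 1] and
[1, 16], which agrees with the CHECK lines in propagate-layout-inst-data.mlir.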

---
 .../XeGPU/Transforms/XeGPULayoutImpl.cpp      |  61 ++++++-----
 .../XeGPU/propagate-layout-inst-data.mlir     |  57 ++++++++--
 .../XeGPU/propagate-layout-subgroup.mlir      | 102 +++++++++---------
 mlir/test/Dialect/XeGPU/propagate-layout.mlir |  17 ++-
 4 files changed, 144 insertions(+), 93 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
index 12b5c9d66b4ad..2cb8cc159c149 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
@@ -1284,11 +1284,10 @@ getupDpasSubgroupLayouts(mlir::MLIRContext *context, VectorType aTy,
 
   SmallVector<int> sgLayout = {static_cast<int>(bestPick->first),
                                static_cast<int>(bestPick->second)};
-  SmallVector<int> sgDataA = {
-      static_cast<int>(aTy.getShape()[0] / sgLayout[0]),
-      static_cast<int>(aTy.getShape()[1] / sgLayout[1])};
+  SmallVector<int> sgDataA = {static_cast<int>(aTy.getShape()[0] / sgLayout[0]),
+                              static_cast<int>(aTy.getShape()[1])};
   SmallVector<int> sgDataB = {
-      static_cast<int>(bTy.getShape()[0] / sgLayout[0]),
+      static_cast<int>(bTy.getShape()[0]),
       static_cast<int>(bTy.getShape()[1] / sgLayout[1])};
   SmallVector<int> sgDataCD = {
       static_cast<int>(cdTy.getShape()[0] / sgLayout[0]),
@@ -1387,7 +1386,6 @@ xegpu::setupDpasMxLayout(xegpu::LayoutKind layoutKind, VectorType aTy,
     auto uArchInstruction = dyn_cast<
         xegpu::uArch::SubgroupScaledMatrixMultiplyAcc>(uArch->getInstruction(
         xegpu::uArch::InstructionKind::SubgroupScaledMatrixMultiplyAcc));
-    int64_t scaleFactor = uArchInstruction->getScaleFactor();
 
     int64_t rank = matrixLayout.getRank();
     assert(rank == 2 && "dpas layouts must be two dimensions");
@@ -1400,38 +1398,49 @@ xegpu::setupDpasMxLayout(xegpu::LayoutKind layoutKind, VectorType aTy,
     SmallVector<int64_t> laneData = matrixLayout.getEffectiveLaneDataAsInt();
     auto order = matrixLayout.getOrder();
 
-    SmallVector<int> scaleSgLayout(sgLayout.begin(), sgLayout.end());
-    SmallVector<int> scaleSgData(sgData.begin(), sgData.end());
-
+    SmallVector<int> scaleSgLayout;
+    SmallVector<int> scaleSgData;
     if (!sgLayout.empty() && !sgData.empty()) {
-      scaleSgData[rank - 2] =
-          std::max<int64_t>(scaleShape[rank - 2] / sgLayout[rank - 2], 1);
-      scaleSgData[rank - 1] =
-          std::max<int64_t>(scaleShape[rank - 1] / sgLayout[rank - 1], 1);
+      scaleSgLayout.assign(sgLayout.begin(), sgLayout.end());
+      scaleSgData.assign(sgData.begin(), sgData.end());
+      scaleSgData[rank - 2] = std::max<int64_t>(
+          scaleShape[rank - 2] / (matrixShape[rank - 2] / sgData[rank - 2]), 1);
+      scaleSgData[rank - 1] = std::max<int64_t>(
+          scaleShape[rank - 1] / (matrixShape[rank - 1] / sgData[rank - 1]), 1);
     }
 
-    SmallVector<int> scaleInstData(matrixShape.begin(), matrixShape.end());
+    // For DPAS_MX scales: if the matrix has inst_data, the scale needs an
+    // adjusted inst_data, derived from the matrix inst_data divided by the
+    // scale factor along the K dimension.
+    SmallVector<int> scaleInstData;
     if (!instData.empty()) {
+      scaleInstData.assign(instData.begin(), instData.end());
       if (isBScale)
-        scaleInstData[rank - 2] =
-            std::max<int64_t>(matrixShape[rank - 2] / scaleFactor, 1);
+        scaleInstData[rank - 2] = std::max<int64_t>(
+            scaleShape[rank - 2] / (matrixShape[rank - 2] / instData[rank - 2]),
+            1);
       else
-        scaleInstData[rank - 1] =
-            std::max<int64_t>(matrixShape[rank - 1] / scaleFactor, 1);
+        scaleInstData[rank - 1] = std::max<int64_t>(
+            scaleShape[rank - 1] / (matrixShape[rank - 1] / instData[rank - 1]),
+            1);
     }
 
-    SmallVector<int> scaleLaneLayout(laneLayout.begin(), laneLayout.end());
-    SmallVector<int> scaleLaneData(laneData.begin(), laneData.end());
-
+    SmallVector<int> scaleLaneLayout;
+    SmallVector<int> scaleLaneData;
     if (!laneLayout.empty() && !laneData.empty()) {
+
+      scaleLaneLayout.assign(laneLayout.begin(), laneLayout.end());
+      scaleLaneData.assign(laneData.begin(), laneData.end());
       bool order = uArchInstruction->isLaneLayoutRowMajorOrder();
-      if (isBScale ^ order)
+      if (isBScale ^ order) {
         std::swap(scaleLaneLayout[rank - 2], scaleLaneLayout[rank - 1]);
-
-      scaleLaneData[rank - 2] =
-          std::max<int64_t>(scaleShape[rank - 2] / laneLayout[rank - 2], 1);
-      scaleLaneData[rank - 1] =
-          std::max<int64_t>(scaleShape[rank - 1] / laneLayout[rank - 1], 1);
+        scaleLaneLayout[rank - 2] =
+            std::min<int64_t>(scaleShape[rank - 2], scaleLaneLayout[rank - 2]);
+      }
+      scaleLaneData[rank - 2] = std::max<int64_t>(
+          scaleShape[rank - 2] / scaleLaneLayout[rank - 2], 1);
+      scaleLaneData[rank - 1] = std::max<int64_t>(
+          scaleShape[rank - 1] / scaleLaneLayout[rank - 1], 1);
     }
     return xegpu::LayoutAttr::get(
         context,
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir b/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
index 12e975eb1d78d..dd23a946e081b 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
@@ -408,14 +408,14 @@ gpu.module @test{
 // CHECK-SAME: !xegpu.tensor_desc<16x64xf8E5M2, #xegpu.layout<inst_data = [8, 32]>> -> vector<16x64xf8E5M2>
 // CHECK: %[[T3:.*]] = xegpu.load_nd %[[T1]] <{layout = #xegpu.layout<inst_data = [32, 16]>}> :
 // CHECK-SAME: !xegpu.tensor_desc<64x32xf8E5M2, #xegpu.layout<inst_data = [32, 16]>> -> vector<64x32xf8E5M2>
-// CHECK: %[[T4:.*]] = xegpu.create_nd_tdesc %[[ARG3]][{{.*}}] : memref<16x2xf8E8M0FNU> -> !xegpu.tensor_desc<16x2xf8E8M0FNU, #xegpu.layout<inst_data = [16, 2]>>
-// CHECK: %[[T5:.*]] = xegpu.load_nd %[[T4]] <{layout = #xegpu.layout<inst_data = [16, 2]>}> :
-// CHECK-SAME: !xegpu.tensor_desc<16x2xf8E8M0FNU, #xegpu.layout<inst_data = [16, 2]>> -> vector<16x2xf8E8M0FNU>
-// CHECK: %[[T6:.*]] = xegpu.create_nd_tdesc %[[ARG4]][{{.*}}] : memref<2x32xf8E8M0FNU> -> !xegpu.tensor_desc<2x32xf8E8M0FNU, #xegpu.layout<inst_data = [2, 32]>>
-// CHECK: %[[T7:.*]] = xegpu.load_nd %[[T6]] <{layout = #xegpu.layout<inst_data = [2, 32]>}> :
-// CHECK-SAME: !xegpu.tensor_desc<2x32xf8E8M0FNU, #xegpu.layout<inst_data = [2, 32]>> -> vector<2x32xf8E8M0FNU>
+// CHECK: %[[T4:.*]] = xegpu.create_nd_tdesc %[[ARG3]][{{.*}}] : memref<16x2xf8E8M0FNU> -> !xegpu.tensor_desc<16x2xf8E8M0FNU, #xegpu.layout<inst_data = [8, 1]>>
+// CHECK: %[[T5:.*]] = xegpu.load_nd %[[T4]] <{layout = #xegpu.layout<inst_data = [8, 1]>}> :
+// CHECK-SAME: !xegpu.tensor_desc<16x2xf8E8M0FNU, #xegpu.layout<inst_data = [8, 1]>> -> vector<16x2xf8E8M0FNU>
+// CHECK: %[[T6:.*]] = xegpu.create_nd_tdesc %[[ARG4]][{{.*}}] : memref<2x32xf8E8M0FNU> -> !xegpu.tensor_desc<2x32xf8E8M0FNU, #xegpu.layout<inst_data = [1, 16]>>
+// CHECK: %[[T7:.*]] = xegpu.load_nd %[[T6]] <{layout = #xegpu.layout<inst_data = [1, 16]>}> :
+// CHECK-SAME: !xegpu.tensor_desc<2x32xf8E8M0FNU, #xegpu.layout<inst_data = [1, 16]>> -> vector<2x32xf8E8M0FNU>
 // CHECK: %[[T8:.*]] = xegpu.dpas_mx %[[T2]], %[[T3]], %[[CST]] scale_a = %[[T5]] scale_b = %[[T7]]
-// CHECK-SAME: {layout_a = #xegpu.layout<inst_data = [8, 32]>, layout_a_scale = #xegpu.layout<inst_data = [16, 2]>, layout_b = #xegpu.layout<inst_data = [32, 16]>, layout_b_scale = #xegpu.layout<inst_data = [2, 32]>, layout_cd = #xegpu.layout<inst_data = [8, 16]>} :
+// CHECK-SAME: {layout_a = #xegpu.layout<inst_data = [8, 32]>, layout_a_scale = #xegpu.layout<inst_data = [8, 1]>, layout_b = #xegpu.layout<inst_data = [32, 16]>, layout_b_scale = #xegpu.layout<inst_data = [1, 16]>, layout_cd = #xegpu.layout<inst_data = [8, 16]>} :
 // CHECK-SAME: vector<16x64xf8E5M2>, vector<64x32xf8E5M2>, vector<16x32xbf16>, vector<16x2xf8E8M0FNU>, vector<2x32xf8E8M0FNU> -> vector<16x32xbf16>
 // CHECK: %[[T9:.*]] = xegpu.create_nd_tdesc %[[ARG2]][{{.*}}] : memref<16x32xbf16> -> !xegpu.tensor_desc<16x32xbf16, #xegpu.layout<inst_data = [8, 16]>>
 // CHECK: xegpu.store_nd %[[T8]], %[[T9]] <{layout = #xegpu.layout<inst_data = [8, 16]>}> : vector<16x32xbf16>, !xegpu.tensor_desc<16x32xbf16, #xegpu.layout<inst_data = [8, 16]>>
@@ -438,3 +438,46 @@ func.func @dpas_mx_f8e5m2(%arg0: memref<16x64xf8E5M2>, %arg1: memref<64x32xf8E5M
   return
 }
 }
+
+// -----
+// CHECK-LABEL: func.func @dpas_mx_f4e2m1
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<16x128xf4E2M1FN>, %[[ARG1:[0-9a-zA-Z]+]]: memref<128x32xf4E2M1FN>, %[[ARG2:[0-9a-zA-Z]+]]: memref<16x32xbf16>
+// CHECK-SAME: %[[ARG3:[0-9a-zA-Z]+]]: memref<16x4xf8E8M0FNU>, %[[ARG4:[0-9a-zA-Z]+]]: memref<4x32xf8E8M0FNU>
+// CHECK: %[[CST:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [8, 16]>} dense<0.000000e+00> : vector<16x32xbf16>
+// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][{{.*}}] : memref<16x128xf4E2M1FN> -> !xegpu.tensor_desc<16x128xf4E2M1FN, #xegpu.layout<inst_data = [8, 64]>>
+// CHECK: %[[T1:.*]] = xegpu.create_nd_tdesc %[[ARG1]][{{.*}}] : memref<128x32xf4E2M1FN> -> !xegpu.tensor_desc<128x32xf4E2M1FN, #xegpu.layout<inst_data = [64, 16]>>
+// CHECK: %[[T2:.*]] = xegpu.load_nd %[[T0]] <{layout = #xegpu.layout<inst_data = [8, 64]>}> :
+// CHECK-SAME: !xegpu.tensor_desc<16x128xf4E2M1FN, #xegpu.layout<inst_data = [8, 64]>> -> vector<16x128xf4E2M1FN>
+// CHECK: %[[T3:.*]] = xegpu.load_nd %[[T1]] <{layout = #xegpu.layout<inst_data = [64, 16]>}> :
+// CHECK-SAME: !xegpu.tensor_desc<128x32xf4E2M1FN, #xegpu.layout<inst_data = [64, 16]>> -> vector<128x32xf4E2M1FN>
+// CHECK: %[[T4:.*]] = xegpu.create_nd_tdesc %[[ARG3]][{{.*}}] : memref<16x4xf8E8M0FNU> -> !xegpu.tensor_desc<16x4xf8E8M0FNU, #xegpu.layout<inst_data = [8, 2]>>
+// CHECK: %[[T5:.*]] = xegpu.load_nd %[[T4]] <{layout = #xegpu.layout<inst_data = [8, 2]>}> :
+// CHECK-SAME: !xegpu.tensor_desc<16x4xf8E8M0FNU, #xegpu.layout<inst_data = [8, 2]>> -> vector<16x4xf8E8M0FNU>
+// CHECK: %[[T6:.*]] = xegpu.create_nd_tdesc %[[ARG4]][{{.*}}] : memref<4x32xf8E8M0FNU> -> !xegpu.tensor_desc<4x32xf8E8M0FNU, #xegpu.layout<inst_data = [2, 16]>>
+// CHECK: %[[T7:.*]] = xegpu.load_nd %[[T6]] <{layout = #xegpu.layout<inst_data = [2, 16]>}> :
+// CHECK-SAME: !xegpu.tensor_desc<4x32xf8E8M0FNU, #xegpu.layout<inst_data = [2, 16]>> -> vector<4x32xf8E8M0FNU>
+// CHECK: %[[T8:.*]] = xegpu.dpas_mx %[[T2]], %[[T3]], %[[CST]] scale_a = %[[T5]] scale_b = %[[T7]]
+// CHECK-SAME: {layout_a = #xegpu.layout<inst_data = [8, 64]>, layout_a_scale = #xegpu.layout<inst_data = [8, 2]>, layout_b = #xegpu.layout<inst_data = [64, 16]>, layout_b_scale = #xegpu.layout<inst_data = [2, 16]>, layout_cd = #xegpu.layout<inst_data = [8, 16]>} :
+// CHECK-SAME: vector<16x128xf4E2M1FN>, vector<128x32xf4E2M1FN>, vector<16x32xbf16>, vector<16x4xf8E8M0FNU>, vector<4x32xf8E8M0FNU> -> vector<16x32xbf16>
+// CHECK: %[[T9:.*]] = xegpu.create_nd_tdesc %[[ARG2]][{{.*}}] : memref<16x32xbf16> -> !xegpu.tensor_desc<16x32xbf16, #xegpu.layout<inst_data = [8, 16]>>
+// CHECK: xegpu.store_nd %[[T8]], %[[T9]] <{layout = #xegpu.layout<inst_data = [8, 16]>}> : vector<16x32xbf16>, !xegpu.tensor_desc<16x32xbf16, #xegpu.layout<inst_data = [8, 16]>>
+gpu.module @test {
+func.func @dpas_mx_f4e2m1(%arg0: memref<16x128xf4E2M1FN>, %arg1: memref<128x32xf4E2M1FN>, %arg2: memref<16x32xbf16>,
+    %arg3: memref<16x4xf8E8M0FNU>, %arg4: memref<4x32xf8E8M0FNU>) {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant dense<0.000000e+00> : vector<16x32xbf16>
+  %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<16x128xf4E2M1FN> -> !xegpu.tensor_desc<16x128xf4E2M1FN>
+  %1 = xegpu.create_nd_tdesc %arg1[%c0, %c0] : memref<128x32xf4E2M1FN> -> !xegpu.tensor_desc<128x32xf4E2M1FN>
+  %2 = xegpu.load_nd %0  : !xegpu.tensor_desc<16x128xf4E2M1FN> -> vector<16x128xf4E2M1FN>
+  %3 = xegpu.load_nd %1  : !xegpu.tensor_desc<128x32xf4E2M1FN> -> vector<128x32xf4E2M1FN>
+  %4 = xegpu.create_nd_tdesc %arg3[%c0, %c0] : memref<16x4xf8E8M0FNU> -> !xegpu.tensor_desc<16x4xf8E8M0FNU>
+  %5 = xegpu.load_nd %4  : !xegpu.tensor_desc<16x4xf8E8M0FNU> -> vector<16x4xf8E8M0FNU>
+  %6 = xegpu.create_nd_tdesc %arg4[%c0, %c0] : memref<4x32xf8E8M0FNU> -> !xegpu.tensor_desc<4x32xf8E8M0FNU>
+  %7 = xegpu.load_nd %6  : !xegpu.tensor_desc<4x32xf8E8M0FNU> -> vector<4x32xf8E8M0FNU>
+  %8 = xegpu.dpas_mx %2, %3, %cst scale_a = %5 scale_b = %7 : vector<16x128xf4E2M1FN>, vector<128x32xf4E2M1FN>, vector<16x32xbf16>, vector<16x4xf8E8M0FNU>, vector<4x32xf8E8M0FNU> -> vector<16x32xbf16>
+  %9 = xegpu.create_nd_tdesc %arg2[%c0, %c0] : memref<16x32xbf16> -> !xegpu.tensor_desc<16x32xbf16>
+  xegpu.store_nd %8, %9  : vector<16x32xbf16>, !xegpu.tensor_desc<16x32xbf16>
+  return
+}
+}
+
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir b/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir
index 03cf4210700f1..21900c0c470d6 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -xevm-attach-target='chip=pvc' -test-xegpu-propagate-layouts="layout-kind=subgroup" -split-input-file %s | FileCheck %s
+// RUN: mlir-opt -xevm-attach-target='chip=cri' -test-xegpu-propagate-layouts="layout-kind=subgroup" -split-input-file %s | FileCheck %s
 
 gpu.module @test {
   // CHECK-LABEL: store_nd
@@ -88,21 +88,21 @@ gpu.module @test {
   gpu.func @dpas(%a: memref<128x128xf16>, %b: memref<128x128xf16>, %d: memref<128x128xf32>) kernel attributes
       {known_block_size = array<i32: 1, 64, 16>} {
   // CHECK: %[[TDESC_A:.*]] = xegpu.create_nd_tdesc %[[A_MEMREF]] : memref<128x128xf16> ->
-  // CHECK-SAME: !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>>
+  // CHECK-SAME: !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128]>>
 
   // CHECK: %[[A_LOADED:.*]] = xegpu.load_nd %[[TDESC_A]]
-  // CHECK-SAME: <{layout = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>}>
-  // CHECK-SAME: : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>> -> vector<128x128xf16>
+  // CHECK-SAME: <{layout = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128]>}>
+  // CHECK-SAME: : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128]>> -> vector<128x128xf16>
 
   // CHECK: %[[TDESC_B:.*]] = xegpu.create_nd_tdesc %[[B_MEMREF]] : memref<128x128xf16> ->
-  // CHECK-SAME: !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>>
+  // CHECK-SAME: !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16]>>
 
-  // CHECK: %[[B_LOADED:.*]] = xegpu.load_nd %[[TDESC_B]] <{layout = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>}>
-  // CHECK-SAME: : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>> -> vector<128x128xf16>
+  // CHECK: %[[B_LOADED:.*]] = xegpu.load_nd %[[TDESC_B]] <{layout = #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16]>}>
+  // CHECK-SAME: : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16]>> -> vector<128x128xf16>
 
   // CHECK: %[[DPAS_RES:.*]] = xegpu.dpas %[[A_LOADED]], %[[B_LOADED]]
-  // CHECK-SAME: {layout_a = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>,
-  // CHECK-SAME: layout_b = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>,
+  // CHECK-SAME: {layout_a = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128]>,
+  // CHECK-SAME: layout_b = #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16]>,
   // CHECK-SAME: layout_cd = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>} :
   // CHECK-SAME: vector<128x128xf16>, vector<128x128xf16> -> vector<128x128xf32>
 
@@ -243,19 +243,19 @@ gpu.module @test {
     // CHECK: %2 = scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%{{.*}} = %{{.*}}) -> (vector<128x128xf32>) {
     // CHECK-NEXT: xegpu.create_nd_tdesc %{{.*}} : memref<2048x8192xf16> ->
     // CHECK-SAME: !xegpu.tensor_desc<128x128xf16, #xegpu.block_tdesc_attr<boundary_check = false>,
-    // CHECK-SAME: #xegpu.layout<sg_layout = [2, 4], sg_data = [64, 32]>>
+    // CHECK-SAME: #xegpu.layout<sg_layout = [2, 4], sg_data = [64, 128]>>
 
-    // CHECK-NEXT: xegpu.load_nd %{{.*}} <{layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [64, 32]>}>
+    // CHECK-NEXT: xegpu.load_nd %{{.*}} <{layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [64, 128]>}>
 
     // CHECK-NEXT: xegpu.create_nd_tdesc %{{.*}} : memref<8192x4096xf16> ->
     // CHECK-SAME: !xegpu.tensor_desc<128x128xf16, #xegpu.block_tdesc_attr<boundary_check = false>,
-    // CHECK-SAME: #xegpu.layout<sg_layout = [2, 4], sg_data = [64, 32]>>
+    // CHECK-SAME: #xegpu.layout<sg_layout = [2, 4], sg_data = [128, 32]>>
 
-    // CHECK-NEXT: xegpu.load_nd %6[%arg3, %block_id_y] <{layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [64, 32]>}>
+    // CHECK-NEXT: xegpu.load_nd %6[%arg3, %block_id_y] <{layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [128, 32]>}>
 
     // CHECK-NEXT: xegpu.dpas %{{.*}} {
-    // CHECK-SAME: layout_a = #xegpu.layout<sg_layout = [2, 4], sg_data = [64, 32]>,
-    // CHECK-SAME: layout_b = #xegpu.layout<sg_layout = [2, 4], sg_data = [64, 32]>,
+    // CHECK-SAME: layout_a = #xegpu.layout<sg_layout = [2, 4], sg_data = [64, 128]>,
+    // CHECK-SAME: layout_b = #xegpu.layout<sg_layout = [2, 4], sg_data = [128, 32]>,
     // CHECK-SAME: layout_cd = #xegpu.layout<sg_layout = [2, 4], sg_data = [64, 32]>}
     // CHECK-SAME: : vector<128x128xf16>, vector<128x128xf16>, vector<128x128xf32> -> vector<128x128xf32>
 
@@ -355,6 +355,24 @@ gpu.module @test {
 
 gpu.module @test {
   // CHECK-LABEL: dpas_mx
+  // CHECK: %[[CST:.*]] = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 32]>} dense<0.000000e+00> : vector<128x256xbf16>
+  // CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<128x512xf8E5M2> -> !xegpu.tensor_desc<128x512xf8E5M2, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 512]>>
+  // CHECK: %[[T1:.*]] = xegpu.load_nd %[[T0]] <{layout = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 512]>}> :
+  // CHECK-SAME: !xegpu.tensor_desc<128x512xf8E5M2, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 512]>> -> vector<128x512xf8E5M2>
+  // CHECK: %[[T2:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<512x256xf8E5M2> -> !xegpu.tensor_desc<512x256xf8E5M2, #xegpu.layout<sg_layout = [8, 8], sg_data = [512, 32]>>
+  // CHECK: %[[T3:.*]] = xegpu.load_nd %[[T2]] <{layout = #xegpu.layout<sg_layout = [8, 8], sg_data = [512, 32]>}> :
+  // CHECK-SAME: !xegpu.tensor_desc<512x256xf8E5M2, #xegpu.layout<sg_layout = [8, 8], sg_data = [512, 32]>> -> vector<512x256xf8E5M2>
+  // CHECK: %[[T4:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<128x16xf8E8M0FNU> -> !xegpu.tensor_desc<128x16xf8E8M0FNU, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>>
+  // CHECK: %[[T5:.*]] = xegpu.load_nd %[[T4]] <{layout = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>}> :
+  // CHECK-SAME: !xegpu.tensor_desc<128x16xf8E8M0FNU, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>> -> vector<128x16xf8E8M0FNU>
+  // CHECK: %[[T6:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<16x256xf8E8M0FNU> -> !xegpu.tensor_desc<16x256xf8E8M0FNU, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 32]>>
+  // CHECK: %[[T7:.*]] = xegpu.load_nd %[[T6]] <{layout = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 32]>}> :
+  // CHECK-SAME: !xegpu.tensor_desc<16x256xf8E8M0FNU, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 32]>> -> vector<16x256xf8E8M0FNU>
+  // CHECK: %[[T8:.*]] = xegpu.dpas_mx %[[T1]], %[[T3]], %[[CST]] scale_a = %[[T5]] scale_b = %[[T7]]
+  // CHECK-SAME: {layout_a = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 512]>, layout_a_scale = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>, layout_b = #xegpu.layout<sg_layout = [8, 8], sg_data = [512, 32]>, layout_b_scale = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 32]>, layout_cd = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 32]>} :
+  // CHECK-SAME: vector<128x512xf8E5M2>, vector<512x256xf8E5M2>, vector<128x256xbf16>, vector<128x16xf8E8M0FNU>, vector<16x256xf8E8M0FNU> -> vector<128x256xbf16>
+  // CHECK: %[[T9:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<128x256xbf16> -> !xegpu.tensor_desc<128x256xbf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 32]>>
+  // CHECK: xegpu.store_nd %[[T8]], %[[T9]][0, 0] <{layout = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 32]>}> : vector<128x256xbf16>, !xegpu.tensor_desc<128x256xbf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 32]>>
   gpu.func @dpas_mx(%arg0: memref<128x512xf8E5M2>, %arg1: memref<512x256xf8E5M2>, %arg2: memref<128x256xbf16>,
       %arg3: memref<128x16xf8E8M0FNU>, %arg4: memref<16x256xf8E8M0FNU>) kernel attributes
       {known_block_size = array<i32: 1, 64, 16>} {
@@ -372,30 +390,30 @@ gpu.module @test {
     xegpu.store_nd %dpas_mx, %tdesc_cd[0, 0] <{layout = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 32]>}> : vector<128x256xbf16>, !xegpu.tensor_desc<128x256xbf16>
     gpu.return
   }
-  // CHECK: %[[CST:.*]] = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 32]>} dense<0.000000e+00> : vector<128x256xbf16>
-  // CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<128x512xf8E5M2> -> !xegpu.tensor_desc<128x512xf8E5M2, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 64]>>
-  // CHECK: %[[T1:.*]] = xegpu.load_nd %[[T0]] <{layout = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 64]>}> :
-  // CHECK-SAME: !xegpu.tensor_desc<128x512xf8E5M2, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 64]>> -> vector<128x512xf8E5M2>
-  // CHECK: %[[T2:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<512x256xf8E5M2> -> !xegpu.tensor_desc<512x256xf8E5M2, #xegpu.layout<sg_layout = [8, 8], sg_data = [64, 32]>>
-  // CHECK: %[[T3:.*]] = xegpu.load_nd %[[T2]] <{layout = #xegpu.layout<sg_layout = [8, 8], sg_data = [64, 32]>}> :
-  // CHECK-SAME: !xegpu.tensor_desc<512x256xf8E5M2, #xegpu.layout<sg_layout = [8, 8], sg_data = [64, 32]>> -> vector<512x256xf8E5M2>
-  // CHECK: %[[T4:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<128x16xf8E8M0FNU> -> !xegpu.tensor_desc<128x16xf8E8M0FNU, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 2]>>
-  // CHECK: %[[T5:.*]] = xegpu.load_nd %[[T4]] <{layout = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 2]>}> :
-  // CHECK-SAME: !xegpu.tensor_desc<128x16xf8E8M0FNU, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 2]>> -> vector<128x16xf8E8M0FNU>
-  // CHECK: %[[T6:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<16x256xf8E8M0FNU> -> !xegpu.tensor_desc<16x256xf8E8M0FNU, #xegpu.layout<sg_layout = [8, 8], sg_data = [2, 32]>>
-  // CHECK: %[[T7:.*]] = xegpu.load_nd %[[T6]] <{layout = #xegpu.layout<sg_layout = [8, 8], sg_data = [2, 32]>}> :
-  // CHECK-SAME: !xegpu.tensor_desc<16x256xf8E8M0FNU, #xegpu.layout<sg_layout = [8, 8], sg_data = [2, 32]>> -> vector<16x256xf8E8M0FNU>
-  // CHECK: %[[T8:.*]] = xegpu.dpas_mx %[[T1]], %[[T3]], %[[CST]] scale_a = %[[T5]] scale_b = %[[T7]]
-  // CHECK-SAME: {layout_a = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 64]>, layout_a_scale = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 2]>, layout_b = #xegpu.layout<sg_layout = [8, 8], sg_data = [64, 32]>, layout_b_scale = #xegpu.layout<sg_layout = [8, 8], sg_data = [2, 32]>, layout_cd = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 32]>} :
-  // CHECK-SAME: vector<128x512xf8E5M2>, vector<512x256xf8E5M2>, vector<128x256xbf16>, vector<128x16xf8E8M0FNU>, vector<16x256xf8E8M0FNU> -> vector<128x256xbf16>
-  // CHECK: %[[T9:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<128x256xbf16> -> !xegpu.tensor_desc<128x256xbf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 32]>>
-  // CHECK: xegpu.store_nd %[[T8]], %[[T9]][0, 0] <{layout = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 32]>}> : vector<128x256xbf16>, !xegpu.tensor_desc<128x256xbf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 32]>>
 }
 
 // -----
 
 gpu.module @test {
   // CHECK-LABEL: dpas_mx_fp4
+  // CHECK: %[[CST:.*]] = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>} dense<0.000000e+00> : vector<128x128xf32>
+  // CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<128x512xf4E2M1FN> -> !xegpu.tensor_desc<128x512xf4E2M1FN, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 512]>>
+  // CHECK: %[[T1:.*]] = xegpu.load_nd %[[T0]] <{layout = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 512]>}> :
+  // CHECK-SAME: !xegpu.tensor_desc<128x512xf4E2M1FN, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 512]>> -> vector<128x512xf4E2M1FN>
+  // CHECK: %[[T2:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<512x128xf4E2M1FN> -> !xegpu.tensor_desc<512x128xf4E2M1FN, #xegpu.layout<sg_layout = [8, 8], sg_data = [512, 16]>>
+  // CHECK: %[[T3:.*]] = xegpu.load_nd %[[T2]] <{layout = #xegpu.layout<sg_layout = [8, 8], sg_data = [512, 16]>}> :
+  // CHECK-SAME: !xegpu.tensor_desc<512x128xf4E2M1FN, #xegpu.layout<sg_layout = [8, 8], sg_data = [512, 16]>> -> vector<512x128xf4E2M1FN>
+  // CHECK: %[[T4:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<128x16xf8E8M0FNU> -> !xegpu.tensor_desc<128x16xf8E8M0FNU, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>>
+  // CHECK: %[[T5:.*]] = xegpu.load_nd %[[T4]] <{layout = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>}> :
+  // CHECK-SAME: !xegpu.tensor_desc<128x16xf8E8M0FNU, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>> -> vector<128x16xf8E8M0FNU>
+  // CHECK: %[[T6:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<16x128xf8E8M0FNU> -> !xegpu.tensor_desc<16x128xf8E8M0FNU, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>>
+  // CHECK: %[[T7:.*]] = xegpu.load_nd %[[T6]] <{layout = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>}> :
+  // CHECK-SAME: !xegpu.tensor_desc<16x128xf8E8M0FNU, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>> -> vector<16x128xf8E8M0FNU>
+  // CHECK: %[[T8:.*]] = xegpu.dpas_mx %[[T1]], %[[T3]], %[[CST]] scale_a = %[[T5]] scale_b = %[[T7]]
+  // CHECK-SAME: {layout_a = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 512]>, layout_a_scale = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>, layout_b = #xegpu.layout<sg_layout = [8, 8], sg_data = [512, 16]>, layout_b_scale = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>, layout_cd = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>} :
+  // CHECK-SAME: vector<128x512xf4E2M1FN>, vector<512x128xf4E2M1FN>, vector<128x128xf32>, vector<128x16xf8E8M0FNU>, vector<16x128xf8E8M0FNU> -> vector<128x128xf32>
+  // CHECK: %[[T9:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<128x128xf32> -> !xegpu.tensor_desc<128x128xf32, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>>
+  // CHECK: xegpu.store_nd %[[T8]], %[[T9]][0, 0] <{layout = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>}> : vector<128x128xf32>, !xegpu.tensor_desc<128x128xf32, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>>
   gpu.func @dpas_mx_fp4(%arg0: memref<128x512xf4E2M1FN>, %arg1: memref<512x128xf4E2M1FN>, %arg2: memref<128x128xf32>,
       %arg3: memref<128x16xf8E8M0FNU>, %arg4: memref<16x128xf8E8M0FNU>) kernel attributes
       {known_block_size = array<i32: 1, 64, 16>} {
@@ -413,22 +431,4 @@ gpu.module @test {
     xegpu.store_nd %dpas_mx, %tdesc_cd[0, 0] <{layout = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>}> : vector<128x128xf32>, !xegpu.tensor_desc<128x128xf32>
     gpu.return
   }
-  // CHECK: %[[CST:.*]] = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>} dense<0.000000e+00> : vector<128x128xf32>
-  // CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<128x512xf4E2M1FN> -> !xegpu.tensor_desc<128x512xf4E2M1FN, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 64]>>
-  // CHECK: %[[T1:.*]] = xegpu.load_nd %[[T0]] <{layout = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 64]>}> :
-  // CHECK-SAME: !xegpu.tensor_desc<128x512xf4E2M1FN, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 64]>> -> vector<128x512xf4E2M1FN>
-  // CHECK: %[[T2:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<512x128xf4E2M1FN> -> !xegpu.tensor_desc<512x128xf4E2M1FN, #xegpu.layout<sg_layout = [8, 8], sg_data = [64, 16]>>
-  // CHECK: %[[T3:.*]] = xegpu.load_nd %[[T2]] <{layout = #xegpu.layout<sg_layout = [8, 8], sg_data = [64, 16]>}> :
-  // CHECK-SAME: !xegpu.tensor_desc<512x128xf4E2M1FN, #xegpu.layout<sg_layout = [8, 8], sg_data = [64, 16]>> -> vector<512x128xf4E2M1FN>
-  // CHECK: %[[T4:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<128x16xf8E8M0FNU> -> !xegpu.tensor_desc<128x16xf8E8M0FNU, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 2]>>
-  // CHECK: %[[T5:.*]] = xegpu.load_nd %[[T4]] <{layout = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 2]>}> :
-  // CHECK-SAME: !xegpu.tensor_desc<128x16xf8E8M0FNU, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 2]>> -> vector<128x16xf8E8M0FNU>
-  // CHECK: %[[T6:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<16x128xf8E8M0FNU> -> !xegpu.tensor_desc<16x128xf8E8M0FNU, #xegpu.layout<sg_layout = [8, 8], sg_data = [2, 16]>>
-  // CHECK: %[[T7:.*]] = xegpu.load_nd %[[T6]] <{layout = #xegpu.layout<sg_layout = [8, 8], sg_data = [2, 16]>}> :
-  // CHECK-SAME: !xegpu.tensor_desc<16x128xf8E8M0FNU, #xegpu.layout<sg_layout = [8, 8], sg_data = [2, 16]>> -> vector<16x128xf8E8M0FNU>
-  // CHECK: %[[T8:.*]] = xegpu.dpas_mx %[[T1]], %[[T3]], %[[CST]] scale_a = %[[T5]] scale_b = %[[T7]]
-  // CHECK-SAME: {layout_a = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 64]>, layout_a_scale = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 2]>, layout_b = #xegpu.layout<sg_layout = [8, 8], sg_data = [64, 16]>, layout_b_scale = #xegpu.layout<sg_layout = [8, 8], sg_data = [2, 16]>, layout_cd = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>} :
-  // CHECK-SAME: vector<128x512xf4E2M1FN>, vector<512x128xf4E2M1FN>, vector<128x128xf32>, vector<128x16xf8E8M0FNU>, vector<16x128xf8E8M0FNU> -> vector<128x128xf32>
-  // CHECK: %[[T9:.*]] = xegpu.create_nd_tdesc %{{.*}} : memref<128x128xf32> -> !xegpu.tensor_desc<128x128xf32, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>>
-  // CHECK: xegpu.store_nd %[[T8]], %[[T9]][0, 0] <{layout = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>}> : vector<128x128xf32>, !xegpu.tensor_desc<128x128xf32, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>>
 }
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout.mlir b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
index d6437a75ea9df..f6cef5ec0245e 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -xevm-attach-target='chip=pvc' -test-xegpu-propagate-layouts="layout-kind=lane" -split-input-file %s | FileCheck %s
+// RUN: mlir-opt -xevm-attach-target='chip=cri' -test-xegpu-propagate-layouts="layout-kind=lane" -split-input-file %s | FileCheck %s
 
 gpu.module @test {
 // CHECK-LABEL: func.func @dpas_f16(
@@ -1034,7 +1034,6 @@ func.func @dpas_mx_f8e5m2(%arg0: memref<8x32xf8E5M2>, %arg1: memref<32x16xf8E5M2
 }
 
 // -----
-
 gpu.module @test {
 // CHECK-LABEL: func.func @dpas_mx_fp4
 // CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<8x64xf4E2M1FN>, %[[ARG1:[0-9a-zA-Z]+]]: memref<64x16xf4E2M1FN>, %[[ARG2:[0-9a-zA-Z]+]]: memref<8x16xbf16>
@@ -1046,14 +1045,14 @@ gpu.module @test {
 // CHECK-SAME: !xegpu.tensor_desc<8x64xf4E2M1FN, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 4]>> -> vector<8x64xf4E2M1FN>
 // CHECK: %[[T3:.*]] = xegpu.load_nd %[[T1]] <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [8, 1]>}> :
 // CHECK-SAME: !xegpu.tensor_desc<64x16xf4E2M1FN, #xegpu.layout<lane_layout = [1, 16], lane_data = [8, 1]>> -> vector<64x16xf4E2M1FN>
-// CHECK: %[[T4:.*]] = xegpu.create_nd_tdesc %[[ARG3]][{{.*}}] : memref<8x2xf8E8M0FNU> -> !xegpu.tensor_desc<8x2xf8E8M0FNU, #xegpu.layout<lane_layout = [8, 2], lane_data = [1, 1]>>
-// CHECK: %[[T5:.*]] = xegpu.load_nd %[[T4]] <{layout = #xegpu.layout<lane_layout = [8, 2], lane_data = [1, 1]>}> :
-// CHECK-SAME: !xegpu.tensor_desc<8x2xf8E8M0FNU, #xegpu.layout<lane_layout = [8, 2], lane_data = [1, 1]>> -> vector<8x2xf8E8M0FNU>
-// CHECK: %[[T6:.*]] = xegpu.create_nd_tdesc %[[ARG4]][{{.*}}] : memref<2x16xf8E8M0FNU> -> !xegpu.tensor_desc<2x16xf8E8M0FNU, #xegpu.layout<lane_layout = [2, 16], lane_data = [1, 1]>>
-// CHECK: %[[T7:.*]] = xegpu.load_nd %[[T6]] <{layout = #xegpu.layout<lane_layout = [2, 16], lane_data = [1, 1]>}> :
-// CHECK-SAME: !xegpu.tensor_desc<2x16xf8E8M0FNU, #xegpu.layout<lane_layout = [2, 16], lane_data = [1, 1]>> -> vector<2x16xf8E8M0FNU>
+// CHECK: %[[T4:.*]] = xegpu.create_nd_tdesc %[[ARG3]][{{.*}}] : memref<8x2xf8E8M0FNU> -> !xegpu.tensor_desc<8x2xf8E8M0FNU, #xegpu.layout<lane_layout = [8, 1], lane_data = [1, 2]>>
+// CHECK: %[[T5:.*]] = xegpu.load_nd %[[T4]] <{layout = #xegpu.layout<lane_layout = [8, 1], lane_data = [1, 2]>}> :
+// CHECK-SAME: !xegpu.tensor_desc<8x2xf8E8M0FNU, #xegpu.layout<lane_layout = [8, 1], lane_data = [1, 2]>> -> vector<8x2xf8E8M0FNU>
+// CHECK: %[[T6:.*]] = xegpu.create_nd_tdesc %[[ARG4]][{{.*}}] : memref<2x16xf8E8M0FNU> -> !xegpu.tensor_desc<2x16xf8E8M0FNU, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>>
+// CHECK: %[[T7:.*]] = xegpu.load_nd %[[T6]] <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>}> :
+// CHECK-SAME: !xegpu.tensor_desc<2x16xf8E8M0FNU, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>> -> vector<2x16xf8E8M0FNU>
 // CHECK: %[[T8:.*]] = xegpu.dpas_mx %[[T2]], %[[T3]], %[[CST]] scale_a = %[[T5]] scale_b = %[[T7]]
-// CHECK-SAME: {layout_a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 4]>, layout_a_scale = #xegpu.layout<lane_layout = [8, 2], lane_data = [1, 1]>, layout_b = #xegpu.layout<lane_layout = [1, 16], lane_data = [8, 1]>, layout_b_scale = #xegpu.layout<lane_layout = [2, 16], lane_data = [1, 1]>, layout_cd = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} :
+// CHECK-SAME: {layout_a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 4]>, layout_a_scale = #xegpu.layout<lane_layout = [8, 1], lane_data = [1, 2]>, layout_b = #xegpu.layout<lane_layout = [1, 16], lane_data = [8, 1]>, layout_b_scale = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>, layout_cd = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} :
 // CHECK-SAME: vector<8x64xf4E2M1FN>, vector<64x16xf4E2M1FN>, vector<8x16xbf16>, vector<8x2xf8E8M0FNU>, vector<2x16xf8E8M0FNU> -> vector<8x16xbf16>
 // CHECK: %[[T9:.*]] = xegpu.create_nd_tdesc %[[ARG2]][{{.*}}] : memref<8x16xbf16> -> !xegpu.tensor_desc<8x16xbf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
 // CHECK: xegpu.store_nd %[[T8]], %[[T9]] <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}> : vector<8x16xbf16>, !xegpu.tensor_desc<8x16xbf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
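
The updated scale expectations in the propagate-layout.mlir hunk above (lane_layout = [8, 1], lane_data = [1, 2] for scale_a, and lane_layout = [1, 16], lane_data = [2, 1] for scale_b) appear to follow from the lane branch of createScaleLayout in this series. The standalone sketch below reproduces that arithmetic outside MLIR. LaneLayout2D and deriveScaleLaneLayout are hypothetical names used only for illustration; they are not part of the XeGPU API.

```cpp
#include <algorithm>
#include <array>
#include <cassert>
#include <cstdint>
#include <utility>

// Hypothetical stand-in for a rank-2 lane layout; not an XeGPU type.
struct LaneLayout2D {
  std::array<int64_t, 2> laneLayout;
  std::array<int64_t, 2> laneData;
};

// Sketch of the lane branch of createScaleLayout for rank-2 operands.
//   matrixLaneLayout: lane_layout of the A or B matrix operand
//   scaleShape:       shape of the matching scale vector
//   isBScale:         true when deriving the layout for scale_b
//   isRowMajor:       value reported by isLaneLayoutRowMajorOrder()
inline LaneLayout2D
deriveScaleLaneLayout(std::array<int64_t, 2> matrixLaneLayout,
                      std::array<int64_t, 2> scaleShape, bool isBScale,
                      bool isRowMajor) {
  std::array<int64_t, 2> lanes = matrixLaneLayout;
  // For bool operands, `isBScale ^ isRowMajor` in the patch equals `!=`.
  if (isBScale != isRowMajor) {
    std::swap(lanes[0], lanes[1]);
    // Never use more lanes along dim 0 than the scale has rows.
    lanes[0] = std::min(scaleShape[0], lanes[0]);
  }
  assert(lanes[0] > 0 && lanes[1] > 0 && "lane counts must be positive");

  LaneLayout2D result;
  result.laneLayout = lanes;
  // Each lane owns scaleShape / laneLayout elements, clamped to at least 1.
  result.laneData = {std::max<int64_t>(scaleShape[0] / lanes[0], 1),
                     std::max<int64_t>(scaleShape[1] / lanes[1], 1)};
  return result;
}
```

With the dpas_mx_fp4 shapes from the test (matrix lane_layout = [1, 16], scale_a of shape 8x2, scale_b of shape 2x16, row-major lane order), this yields the values in the CHECK lines quoted above.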

>From 7a765756699636cea2412c864bec774b0078ec92 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Sat, 25 Apr 2026 05:48:20 +0000
Subject: [PATCH 7/7] cleanup

---
 .../mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h    |  10 +-
 .../XeGPU/Transforms/XeGPULayoutImpl.cpp      | 228 +++++++++---------
 2 files changed, 120 insertions(+), 118 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
index 2c76019549c22..eca7ffdb890ec 100644
--- a/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
@@ -219,12 +219,11 @@ struct SubgroupMatrixMultiplyAcc : public Instruction,
 struct SubgroupScaledMatrixMultiplyAcc : public Instruction,
                                          public MMAInstructionInterface {
   SubgroupScaledMatrixMultiplyAcc(unsigned packedFormatBitSizeA,
-                                  unsigned packedFormatBitSizeB,
-                                  unsigned scaleFactor)
+                                  unsigned packedFormatBitSizeB)
       : Instruction(InstructionKind::SubgroupScaledMatrixMultiplyAcc,
                     InstructionScope::Subgroup),
         packedFormatBitSizeA(packedFormatBitSizeA),
-        packedFormatBitSizeB(packedFormatBitSizeB), scaleFactor(scaleFactor) {}
+        packedFormatBitSizeB(packedFormatBitSizeB) {}
   static bool classof(const Instruction *B) {
     return B->getInstructionKind() ==
            InstructionKind::SubgroupScaledMatrixMultiplyAcc;
@@ -259,13 +258,11 @@ struct SubgroupScaledMatrixMultiplyAcc : public Instruction,
 
   unsigned getPackedFormatBitSizeA() const { return packedFormatBitSizeA; }
   unsigned getPackedFormatBitSizeB() const { return packedFormatBitSizeB; }
-  unsigned getScaleFactor() const { return scaleFactor; }
   bool isLaneLayoutRowMajorOrder() const override { return true; }
 
 protected:
   const unsigned packedFormatBitSizeA;
   const unsigned packedFormatBitSizeB;
-  const unsigned scaleFactor;
 };
 
 struct SpirvLoadGatherInstruction : public LoadGatherInstructionInterface {
@@ -283,7 +280,6 @@ struct SpirvStoreScatterInstruction : public StoreScatterInstructionInterface {
 struct PVCuArch final : public Xe2Plus {
   static llvm::ArrayRef<const Instruction *> getInstructionRegistryArr() {
     static const SubgroupMatrixMultiplyAcc dpasInst{16, 32};
-    static const SubgroupScaledMatrixMultiplyAcc dpasMxInst{16, 32, 32};
     static const Subgroup2DBlockLoadInstruction loadNdInst;
     static const Subgroup2DBlockStoreInstruction storeNdInst;
     static const Subgroup2DBlockPrefetchInstruction prefetchNdInst;
@@ -336,7 +332,7 @@ struct BMGuArch : public Xe2Plus {
 struct CRIuArch : public Xe2Plus {
   static llvm::ArrayRef<const Instruction *> getInstructionRegistryArr() {
     static const SubgroupMatrixMultiplyAcc dpasInst{16, 32};
-    static const SubgroupScaledMatrixMultiplyAcc dpasMxInst{16, 32, 32};
+    static const SubgroupScaledMatrixMultiplyAcc dpasMxInst{16, 32};
     static const Subgroup2DBlockLoadInstruction loadNdInst;
     static const Subgroup2DBlockStoreInstruction storeNdInst;
     static const Subgroup2DBlockPrefetchInstruction prefetchNdInst;
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
index 2cb8cc159c149..518f7b181ff5d 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
@@ -1353,6 +1353,99 @@ xegpu::setupDpasLayout(xegpu::LayoutKind layoutKind, VectorType aTy,
   return std::nullopt;
 }
 
+/// Helper to create a scale layout derived from a matrix operand layout.
+/// The scale layout is computed by mapping each dimension of the matrix layout
+/// to the corresponding scale tensor dimension using the ratio between the
+/// matrix and scale shapes.
+static xegpu::DistributeLayoutAttr
+createScaleLayout(mlir::MLIRContext *context, VectorType matrixTy,
+                  VectorType scaleTy, xegpu::DistributeLayoutAttr matrixLayout,
+                  bool isBScale, const xegpu::uArch::uArch *uArch) {
+  if (!scaleTy || !matrixLayout)
+    return nullptr;
+
+  // Calculate scaling factor by dividing matrix shape by scale shape
+  ArrayRef<int64_t> matrixShape = matrixTy.getShape();
+  ArrayRef<int64_t> scaleShape = scaleTy.getShape();
+
+  // Reject rank-0 scale shapes; the indexing below assumes a 2D scale.
+  if (scaleShape.empty())
+    return nullptr;
+
+  auto uArchInstruction =
+      dyn_cast<xegpu::uArch::SubgroupScaledMatrixMultiplyAcc>(
+          uArch->getInstruction(
+              xegpu::uArch::InstructionKind::SubgroupScaledMatrixMultiplyAcc));
+
+  int64_t rank = matrixLayout.getRank();
+  assert(rank == 2 && "dpas layouts must be two dimensions");
+
+  SmallVector<int64_t> sgLayout = matrixLayout.getEffectiveSgLayoutAsInt();
+  SmallVector<int64_t> sgData = matrixLayout.getEffectiveSgDataAsInt();
+  SmallVector<int64_t> instData = matrixLayout.getEffectiveInstDataAsInt();
+  SmallVector<int64_t> laneLayout = matrixLayout.getEffectiveLaneLayoutAsInt();
+  SmallVector<int64_t> laneData = matrixLayout.getEffectiveLaneDataAsInt();
+  auto order = matrixLayout.getOrder();
+
+  SmallVector<int> scaleSgLayout;
+  SmallVector<int> scaleSgData;
+  if (!sgLayout.empty() && !sgData.empty()) {
+    scaleSgLayout.assign(sgLayout.begin(), sgLayout.end());
+    scaleSgData.assign(sgData.begin(), sgData.end());
+    scaleSgData[rank - 2] = std::max<int64_t>(
+        scaleShape[rank - 2] / (matrixShape[rank - 2] / sgData[rank - 2]), 1);
+    scaleSgData[rank - 1] = std::max<int64_t>(
+        scaleShape[rank - 1] / (matrixShape[rank - 1] / sgData[rank - 1]), 1);
+  }
+
+  // For dpas_mx scales: if the matrix layout has inst_data, derive the scale
+  // inst_data along K as scaleShape / (matrixShape / instData), clamped to at
+  // least 1. The non-K dimension is copied from the matrix inst_data.
+  SmallVector<int> scaleInstData;
+  if (!instData.empty()) {
+    scaleInstData.assign(instData.begin(), instData.end());
+    if (isBScale)
+      scaleInstData[rank - 2] = std::max<int64_t>(
+          scaleShape[rank - 2] / (matrixShape[rank - 2] / instData[rank - 2]),
+          1);
+    else
+      scaleInstData[rank - 1] = std::max<int64_t>(
+          scaleShape[rank - 1] / (matrixShape[rank - 1] / instData[rank - 1]),
+          1);
+  }
+
+  SmallVector<int> scaleLaneLayout;
+  SmallVector<int> scaleLaneData;
+  if (!laneLayout.empty() && !laneData.empty()) {
+    scaleLaneLayout.assign(laneLayout.begin(), laneLayout.end());
+    scaleLaneData.assign(laneData.begin(), laneData.end());
+    bool isRowMajor = uArchInstruction->isLaneLayoutRowMajorOrder();
+    if (isBScale ^ isRowMajor) {
+      std::swap(scaleLaneLayout[rank - 2], scaleLaneLayout[rank - 1]);
+      scaleLaneLayout[rank - 2] =
+          std::min<int64_t>(scaleShape[rank - 2], scaleLaneLayout[rank - 2]);
+    }
+    scaleLaneData[rank - 2] =
+        std::max<int64_t>(scaleShape[rank - 2] / scaleLaneLayout[rank - 2], 1);
+    scaleLaneData[rank - 1] =
+        std::max<int64_t>(scaleShape[rank - 1] / scaleLaneLayout[rank - 1], 1);
+  }
+  return xegpu::LayoutAttr::get(
+      context,
+      scaleSgLayout.empty() ? nullptr
+                            : DenseI32ArrayAttr::get(context, scaleSgLayout),
+      scaleSgData.empty() ? nullptr
+                          : DenseI32ArrayAttr::get(context, scaleSgData),
+      scaleInstData.empty() ? nullptr
+                            : DenseI32ArrayAttr::get(context, scaleInstData),
+      scaleLaneLayout.empty()
+          ? nullptr
+          : DenseI32ArrayAttr::get(context, scaleLaneLayout),
+      scaleLaneData.empty() ? nullptr
+                            : DenseI32ArrayAttr::get(context, scaleLaneData),
+      order);
+}
+
 /// Sets up the anchor layouts for dpas_mx operands (A, B, C/D, A_scale, and
 /// B_scale). The numSg and consumerLayout (optional) are only used by sg layout
 /// creation.
@@ -1368,95 +1461,6 @@ xegpu::setupDpasMxLayout(xegpu::LayoutKind layoutKind, VectorType aTy,
                          const xegpu::uArch::uArch *uArch) {
   auto context = aTy.getContext();
 
-  // Helper to create scale layout from matrix layout
-  auto createScaleLayout = [&](VectorType matrixTy, VectorType scaleTy,
-                               xegpu::DistributeLayoutAttr matrixLayout,
-                               bool isBScale) -> xegpu::DistributeLayoutAttr {
-    if (!scaleTy || !matrixLayout)
-      return nullptr;
-
-    // Calculate scaling factor by dividing matrix shape by scale shape
-    ArrayRef<int64_t> matrixShape = matrixTy.getShape();
-    ArrayRef<int64_t> scaleShape = scaleTy.getShape();
-
-    // Scale shapes can be 1D or 2D, handle both cases
-    if (scaleShape.empty())
-      return nullptr;
-
-    auto uArchInstruction = dyn_cast<
-        xegpu::uArch::SubgroupScaledMatrixMultiplyAcc>(uArch->getInstruction(
-        xegpu::uArch::InstructionKind::SubgroupScaledMatrixMultiplyAcc));
-
-    int64_t rank = matrixLayout.getRank();
-    assert(rank == 2 && "dpas layouts must be two dimensions");
-
-    SmallVector<int64_t> sgLayout = matrixLayout.getEffectiveSgLayoutAsInt();
-    SmallVector<int64_t> sgData = matrixLayout.getEffectiveSgDataAsInt();
-    SmallVector<int64_t> instData = matrixLayout.getEffectiveInstDataAsInt();
-    SmallVector<int64_t> laneLayout =
-        matrixLayout.getEffectiveLaneLayoutAsInt();
-    SmallVector<int64_t> laneData = matrixLayout.getEffectiveLaneDataAsInt();
-    auto order = matrixLayout.getOrder();
-
-    SmallVector<int> scaleSgLayout;
-    SmallVector<int> scaleSgData;
-    if (!sgLayout.empty() && !sgData.empty()) {
-      scaleSgLayout.assign(sgLayout.begin(), sgLayout.end());
-      scaleSgData.assign(sgData.begin(), sgData.end());
-      scaleSgData[rank - 2] = std::max<int64_t>(
-          scaleShape[rank - 2] / (matrixShape[rank - 2] / sgData[rank - 2]), 1);
-      scaleSgData[rank - 1] = std::max<int64_t>(
-          scaleShape[rank - 1] / (matrixShape[rank - 1] / sgData[rank - 1]), 1);
-    }
-
-    // For DPAS_MX scales: if matrix has inst_data, scale needs adjusted
-    // inst_data Scale inst_data is derived from matrix inst_data divided by
-    // scale factor
-    SmallVector<int> scaleInstData;
-    if (!instData.empty()) {
-      scaleInstData.assign(instData.begin(), instData.end());
-      if (isBScale)
-        scaleInstData[rank - 2] = std::max<int64_t>(
-            scaleShape[rank - 2] / (matrixShape[rank - 2] / instData[rank - 2]),
-            1);
-      else
-        scaleInstData[rank - 1] = std::max<int64_t>(
-            scaleShape[rank - 1] / (matrixShape[rank - 1] / instData[rank - 1]),
-            1);
-    }
-
-    SmallVector<int> scaleLaneLayout;
-    SmallVector<int> scaleLaneData;
-    if (!laneLayout.empty() && !laneData.empty()) {
-
-      scaleLaneLayout.assign(laneLayout.begin(), laneLayout.end());
-      scaleLaneData.assign(laneData.begin(), laneData.end());
-      bool order = uArchInstruction->isLaneLayoutRowMajorOrder();
-      if (isBScale ^ order) {
-        std::swap(scaleLaneLayout[rank - 2], scaleLaneLayout[rank - 1]);
-        scaleLaneLayout[rank - 2] =
-            std::min<int64_t>(scaleShape[rank - 2], scaleLaneLayout[rank - 2]);
-      }
-      scaleLaneData[rank - 2] = std::max<int64_t>(
-          scaleShape[rank - 2] / scaleLaneLayout[rank - 2], 1);
-      scaleLaneData[rank - 1] = std::max<int64_t>(
-          scaleShape[rank - 1] / scaleLaneLayout[rank - 1], 1);
-    }
-    return xegpu::LayoutAttr::get(
-        context,
-        scaleSgLayout.empty() ? nullptr
-                              : DenseI32ArrayAttr::get(context, scaleSgLayout),
-        scaleSgData.empty() ? nullptr
-                            : DenseI32ArrayAttr::get(context, scaleSgData),
-        scaleInstData.empty() ? nullptr
-                              : DenseI32ArrayAttr::get(context, scaleInstData),
-        scaleLaneLayout.empty()
-            ? nullptr
-            : DenseI32ArrayAttr::get(context, scaleLaneLayout),
-        scaleLaneData.empty() ? nullptr
-                              : DenseI32ArrayAttr::get(context, scaleLaneData),
-        order);
-  };
   if (layoutKind == xegpu::LayoutKind::Subgroup) {
     assert(numSg > 0 &&
            "Number of subgroups must be provided for sg layout creation.");
@@ -1468,14 +1472,14 @@ xegpu::setupDpasMxLayout(xegpu::LayoutKind layoutKind, VectorType aTy,
     auto [dpasALayout, dpasBLayout, dpasCDLayout] = *dpasLayouts;
 
     // Create scale layouts
-    auto aScaleLayout =
-        aScaleTy.has_value()
-            ? createScaleLayout(aTy, *aScaleTy, dpasALayout, false)
-            : nullptr;
-    auto bScaleLayout =
-        bScaleTy.has_value()
-            ? createScaleLayout(bTy, *bScaleTy, dpasBLayout, true)
-            : nullptr;
+    auto aScaleLayout = aScaleTy.has_value()
+                            ? createScaleLayout(context, aTy, *aScaleTy,
+                                                dpasALayout, false, uArch)
+                            : nullptr;
+    auto bScaleLayout = bScaleTy.has_value()
+                            ? createScaleLayout(context, bTy, *bScaleTy,
+                                                dpasBLayout, true, uArch)
+                            : nullptr;
 
     return std::make_tuple(dpasALayout, dpasBLayout, dpasCDLayout, aScaleLayout,
                            bScaleLayout);
@@ -1494,14 +1498,14 @@ xegpu::setupDpasMxLayout(xegpu::LayoutKind layoutKind, VectorType aTy,
         context, SmallVector<int>(instDataCD.begin(), instDataCD.end()));
 
     // Create scale layouts
-    auto aScaleLayout =
-        aScaleTy.has_value()
-            ? createScaleLayout(aTy, *aScaleTy, dpasALayout, false)
-            : nullptr;
-    auto bScaleLayout =
-        bScaleTy.has_value()
-            ? createScaleLayout(bTy, *bScaleTy, dpasBLayout, true)
-            : nullptr;
+    auto aScaleLayout = aScaleTy.has_value()
+                            ? createScaleLayout(context, aTy, *aScaleTy,
+                                                dpasALayout, false, uArch)
+                            : nullptr;
+    auto bScaleLayout = bScaleTy.has_value()
+                            ? createScaleLayout(context, bTy, *bScaleTy,
+                                                dpasBLayout, true, uArch)
+                            : nullptr;
 
     return std::make_tuple(dpasALayout, dpasBLayout, dpasCDLayout, aScaleLayout,
                            bScaleLayout);
@@ -1516,12 +1520,14 @@ xegpu::setupDpasMxLayout(xegpu::LayoutKind layoutKind, VectorType aTy,
     auto cdLayout = getDefaultLaneLayout2DBlockIo(cdTy, uArch);
 
     // Create scale layouts
-    auto aScaleLayout = aScaleTy.has_value()
-                            ? createScaleLayout(aTy, *aScaleTy, aLayout, false)
-                            : nullptr;
-    auto bScaleLayout = bScaleTy.has_value()
-                            ? createScaleLayout(bTy, *bScaleTy, bLayout, true)
-                            : nullptr;
+    auto aScaleLayout =
+        aScaleTy.has_value()
+            ? createScaleLayout(context, aTy, *aScaleTy, aLayout, false, uArch)
+            : nullptr;
+    auto bScaleLayout =
+        bScaleTy.has_value()
+            ? createScaleLayout(context, bTy, *bScaleTy, bLayout, true, uArch)
+            : nullptr;
 
     return std::make_tuple(aLayout, bLayout, cdLayout, aScaleLayout,
                            bScaleLayout);
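
The sg_data and inst_data branches of createScaleLayout both reduce to one per-dimension formula: scaleDim / (matrixDim / tileDim), clamped to at least 1. The sketch below restates that formula in standalone C++ so it can be checked against the sg_data values in the sg-to-wi test expectations earlier in this series. Shape2D, scaleTileExtent, deriveScaleSgData and deriveScaleInstData are hypothetical helper names, not XeGPU functions, and the sketch assumes rank-2 operands whose tile size does not exceed the matrix size.

```cpp
#include <algorithm>
#include <array>
#include <cassert>
#include <cstdint>

using Shape2D = std::array<int64_t, 2>;

// Scale elements covered by one tile along a single dimension:
// scaleDim / (matrixDim / tileDim), clamped to at least 1.
inline int64_t scaleTileExtent(int64_t matrixDim, int64_t scaleDim,
                               int64_t tileDim) {
  // Assumption of this sketch: the tile fits in the matrix, so the tile
  // count below is never zero.
  assert(tileDim > 0 && matrixDim >= tileDim && "tile must fit in matrix");
  int64_t numTiles = matrixDim / tileDim;
  return std::max<int64_t>(scaleDim / numTiles, 1);
}

// Subgroup case: both dimensions of sg_data are re-derived for the scale.
inline Shape2D deriveScaleSgData(Shape2D matrixShape, Shape2D scaleShape,
                                 Shape2D sgData) {
  return {scaleTileExtent(matrixShape[0], scaleShape[0], sgData[0]),
          scaleTileExtent(matrixShape[1], scaleShape[1], sgData[1])};
}

// InstData case: only the K dimension is re-derived (dim 0 for scale_b,
// dim 1 for scale_a); the other dimension is copied from the matrix.
inline Shape2D deriveScaleInstData(Shape2D matrixShape, Shape2D scaleShape,
                                   Shape2D instData, bool isBScale) {
  Shape2D result = instData;
  int k = isBScale ? 0 : 1;
  result[k] = scaleTileExtent(matrixShape[k], scaleShape[k], instData[k]);
  return result;
}
```

For the dpas_mx test, A is 128x512 with sg_data = [16, 512] and scale_a is 128x16, which gives a scale sg_data of [16, 16]. B is 512x256 with sg_data = [512, 32] and scale_b is 16x256, which gives [16, 32]. Both match the updated CHECK lines.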


