[Mlir-commits] [mlir] [MLIR][XeGPU] Add wg-to-sg distribution for dpasmx, bitcast, interleave, and deinterleave (PR #194985)

Jianhui Li llvmlistbot at llvm.org
Wed May 6 09:18:41 PDT 2026


https://github.com/Jianhui-Li updated https://github.com/llvm/llvm-project/pull/194985

>From 618b5fe929e70aada84a6baa8fafde7bb4f1697e Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Thu, 23 Apr 2026 23:33:41 +0000
Subject: [PATCH 1/8] [mlir][XeGPU] Add DpasMx op definition and layout support

This patch extends the DpasMx operation with scale layout attributes and
implements layout setup and propagation for MXFP (microscaling floating point)
operations.

Op definition changes (XeGPUOps.td):
- Add layout_a_scale and layout_b_scale attributes to DpasMx op
- Remove restrictive AllElementTypesMatch trait to allow different types for
  A and B operands with scale factors

Layout support (XeGPULayoutImpl.cpp/h):
- setupDpasMxLayout: Creates anchor layouts for all DpasMx operands (A, B, C/D,
  scale_a, scale_b) across different layout kinds (Subgroup, InstData, Lane)
- Derives each scale layout from its parent A/B layout by dividing the last
  dimension of sg_data/inst_data by the scaling factor (parent size / scale
  size along that dimension, e.g. 32 for K/32 scales)

Layout propagation (XeGPUPropagateLayout.cpp):
- visitDpasMxOp: Propagates layout attributes from DpasMx op to its operands
  during dataflow analysis

This infrastructure enables proper layout tracking for mixed-precision matrix
operations with separate scale factors during workgroup-to-subgroup distribution.

Co-Authored-By: Claude Sonnet 4.5 <noreply at anthropic.com>
---
 .../include/mlir/Dialect/XeGPU/IR/XeGPUOps.td |   5 +-
 .../XeGPU/Transforms/XeGPULayoutImpl.h        |  18 ++-
 .../XeGPU/Transforms/XeGPULayoutImpl.cpp      | 151 ++++++++++++++++++
 .../XeGPU/Transforms/XeGPUPropagateLayout.cpp |  39 +++++
 4 files changed, 205 insertions(+), 8 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 31fe93d209a6d..8b6763e234091 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -1554,7 +1554,6 @@ def XeGPU_TruncfOp
 }
 
 def XeGPU_DpasMxOp : XeGPU_Op<"dpas_mx", [Pure, AttrSizedOperandSegments,
-                                          AllElementTypesMatch<["a", "b"]>,
                                           AnchorLayoutInterface]> {
   let summary = "It performs scaled mma computation";
 
@@ -1601,7 +1600,9 @@ def XeGPU_DpasMxOp : XeGPU_Op<"dpas_mx", [Pure, AttrSizedOperandSegments,
                           VectorOfRankAndType<[1, 2], [F8E8M0FNU]>]>>:$scale_b,
       OptionalAttr<DistributeLayoutAttr>:$layout_a,
       OptionalAttr<DistributeLayoutAttr>:$layout_b,
-      OptionalAttr<DistributeLayoutAttr>:$layout_cd);
+      OptionalAttr<DistributeLayoutAttr>:$layout_cd,
+      OptionalAttr<DistributeLayoutAttr>:$layout_a_scale,
+      OptionalAttr<DistributeLayoutAttr>:$layout_b_scale);
   let results = (outs XeGPU_DpasResType:$result);
   let extraClassDeclaration = [{
 
diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h b/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h
index 83eb939cf1bec..c68e7334434d6 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h
@@ -39,12 +39,6 @@ LogicalResult propagateLayouts(OpBuilder &builder, Operation *target,
 
 LogicalResult resolveLayoutConflicts(Operation *target);
 
-/// [to-be-deprecated] Set the DistributeLayoutAttr for each OpOperand and
-/// OpResult of of the given operation. If the operation contains regions, it is
-/// also applied recursively to the contained operations operation.
-/// TODO: To be replaced by recoverTemporaryLayouts()
-void recoverTemporaryLayoutsDeprecated(Operation *op);
-
 /// Attach layout attributes to all vector-type operands of operations within
 /// the given operation's nested region. Reports an error if any vector operand
 /// lacks a layout attribute.
@@ -199,6 +193,18 @@ setupDpasLayout(LayoutKind layoutKind, VectorType aTy, VectorType bTy,
                 VectorType cdTy, DistributeLayoutAttr consumerLayout, int numSg,
                 const uArch::uArch *uArch);
 
+/// Sets up anchor layouts for dpas_mx operands (A, B, C/D, A_scale, B_scale).
+/// numSg/consumerLayout (optional) only affect sg layouts. A scale layout is
+/// null if its scale type is absent; std::nullopt is returned on failure.
+std::optional<std::tuple<DistributeLayoutAttr, DistributeLayoutAttr,
+                         DistributeLayoutAttr, DistributeLayoutAttr,
+                         DistributeLayoutAttr>>
+setupDpasMxLayout(LayoutKind layoutKind, VectorType aTy, VectorType bTy,
+                  VectorType cdTy, std::optional<VectorType> aScaleTy,
+                  std::optional<VectorType> bScaleTy,
+                  DistributeLayoutAttr consumerLayout, int numSg,
+                  const uArch::uArch *uArch);
+
 /// Gets the expected layout for a given consumer operand. This will check if
 /// the owning operation of the consumer operand is one of the special layout
 /// users and determine the expected layout accordingly.
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
index 7d48315eec6ff..8f07ddf09e9a4 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
@@ -1426,3 +1426,154 @@ xegpu::DistributeLayoutAttr xegpu::getConsumerLayoutAt(OpOperand &operand) {
   // the operand.
   return xegpu::getDistributeLayoutAttr(operand.get());
 }
+/// Sets up anchor layouts for dpas_mx operands (A, B, C/D, A_scale, B_scale).
+/// numSg/consumerLayout are only used for sg layouts. Scale layouts derive
+/// from the A/B layouts and are null when the scale type is absent.
+std::optional<
+    std::tuple<xegpu::DistributeLayoutAttr, xegpu::DistributeLayoutAttr,
+               xegpu::DistributeLayoutAttr, xegpu::DistributeLayoutAttr,
+               xegpu::DistributeLayoutAttr>>
+xegpu::setupDpasMxLayout(xegpu::LayoutKind layoutKind, VectorType aTy,
+                         VectorType bTy, VectorType cdTy,
+                         std::optional<VectorType> aScaleTy,
+                         std::optional<VectorType> bScaleTy,
+                         xegpu::DistributeLayoutAttr consumerLayout, int numSg,
+                         const xegpu::uArch::uArch *uArch) {
+  auto context = aTy.getContext();
+  const int subgroupSize = uArch->getSubgroupSize();
+
+  // Derives a scale layout from its A/B parent layout; null if either is null.
+  auto createScaleLayout = [&](VectorType parentTy, VectorType scaleTy,
+                               xegpu::DistributeLayoutAttr parentLayout,
+                               bool isBScale) -> xegpu::DistributeLayoutAttr {
+    if (!scaleTy || !parentLayout)
+      return nullptr;
+
+    // Scaling factor: parent/scale size ratio on the last dim (assumed >= 1).
+    ArrayRef<int64_t> parentShape = parentTy.getShape();
+    ArrayRef<int64_t> scaleShape = scaleTy.getShape();
+    int64_t scaleFactor = parentShape.back() / scaleShape.back();
+    int64_t rank = parentLayout.getRank();
+    assert(rank == 2 && "dpas layouts must be two dimensions");
+
+    SmallVector<int64_t> sgLayout = parentLayout.getEffectiveSgLayoutAsInt();
+    SmallVector<int64_t> sgData = parentLayout.getEffectiveSgDataAsInt();
+    SmallVector<int64_t> instData = parentLayout.getEffectiveInstDataAsInt();
+    SmallVector<int64_t> laneLayout =
+        parentLayout.getEffectiveLaneLayoutAsInt();
+    SmallVector<int64_t> laneData = parentLayout.getEffectiveLaneDataAsInt();
+    auto order = parentLayout.getOrder();
+
+    // Divide the last dim of sg_data/inst_data by the scaling factor.
+    if (!sgData.empty())
+      sgData.back() = sgData.back() / scaleFactor;
+    if (!instData.empty())
+      instData.back() = instData.back() / scaleFactor;
+
+    if (isBScale && !laneLayout.empty() && !laneData.empty()) {
+      // B scale: lane_layout = [min(subgroupSize, scaleShape[0]), 1] and
+      // lane_data = [1, scaleShape[1]] (only if the parent has lane fields).
+      laneLayout[rank - 2] =
+          std::min(static_cast<int64_t>(subgroupSize), scaleShape[rank-2]);
+      laneLayout[rank - 1] = 1;
+      laneData[rank - 2] = 1;
+      laneData[rank - 1] = scaleShape.back();
+    } else if (!isBScale && !laneData.empty()) {
+      laneData.back() = scaleShape.back();
+    }
+
+    return xegpu::LayoutAttr::get(
+        context,
+        sgLayout.empty()
+            ? nullptr
+            : DenseI32ArrayAttr::get(
+                  context, SmallVector<int>(sgLayout.begin(), sgLayout.end())),
+        sgData.empty()
+            ? nullptr
+            : DenseI32ArrayAttr::get(
+                  context, SmallVector<int>(sgData.begin(), sgData.end())),
+        instData.empty()
+            ? nullptr
+            : DenseI32ArrayAttr::get(
+                  context, SmallVector<int>(instData.begin(), instData.end())),
+        laneLayout.empty() ? nullptr
+                           : DenseI32ArrayAttr::get(
+                                 context, SmallVector<int>(laneLayout.begin(),
+                                                           laneLayout.end())),
+        laneData.empty()
+            ? nullptr
+            : DenseI32ArrayAttr::get(
+                  context, SmallVector<int>(laneData.begin(), laneData.end())),
+        order);
+  };
+
+  if (layoutKind == xegpu::LayoutKind::Subgroup) {
+    assert(numSg > 0 &&
+           "Number of subgroups must be provided for sg layout creation.");
+    auto dpasLayouts =
+        getupDpasSubgroupLayouts(context, aTy, bTy, cdTy, consumerLayout, numSg, uArch);
+    if (!dpasLayouts)
+      return std::nullopt;
+
+    auto [dpasALayout, dpasBLayout, dpasCDLayout] = *dpasLayouts;
+
+    // Derive the scale layouts from the subgroup-level A/B layouts.
+    auto aScaleLayout =
+        aScaleTy.has_value()
+            ? createScaleLayout(aTy, *aScaleTy, dpasALayout, false)
+            : nullptr;
+    auto bScaleLayout =
+        bScaleTy.has_value()
+            ? createScaleLayout(bTy, *bScaleTy, dpasBLayout, true)
+            : nullptr;
+
+    return std::make_tuple(dpasALayout, dpasBLayout, dpasCDLayout, aScaleLayout,
+                           bScaleLayout);
+  } else if (layoutKind == xegpu::LayoutKind::InstData) {
+    auto instDataVecs = getDpasInstDataVectors(aTy, bTy, cdTy, uArch);
+    if (!instDataVecs)
+      return std::nullopt;
+    auto [instDataA, instDataB, instDataCD] = *instDataVecs;
+
+    auto dpasALayout = xegpu::LayoutAttr::get(
+        context, SmallVector<int>(instDataA.begin(), instDataA.end()));
+    auto dpasBLayout = xegpu::LayoutAttr::get(
+        context, SmallVector<int>(instDataB.begin(), instDataB.end()));
+    auto dpasCDLayout = xegpu::LayoutAttr::get(
+        context, SmallVector<int>(instDataCD.begin(), instDataCD.end()));
+
+    // Derive the scale layouts from the inst_data-only A/B layouts.
+    auto aScaleLayout =
+        aScaleTy.has_value()
+            ? createScaleLayout(aTy, *aScaleTy, dpasALayout, false)
+            : nullptr;
+    auto bScaleLayout =
+        bScaleTy.has_value()
+            ? createScaleLayout(bTy, *bScaleTy, dpasBLayout, true)
+            : nullptr;
+
+    return std::make_tuple(dpasALayout, dpasBLayout, dpasCDLayout, aScaleLayout,
+                           bScaleLayout);
+  } else if (layoutKind == xegpu::LayoutKind::Lane) {
+    const auto *uArchInstruction =
+        dyn_cast<xegpu::uArch::SubgroupMatrixMultiplyAcc>(uArch->getInstruction(
+            xegpu::uArch::InstructionKind::SubgroupMatrixMultiplyAcc));
+    auto aLayout = getDefaultLaneLayout2DBlockIo(
+        aTy, uArch, uArchInstruction->getPackedFormatBitSizeA());
+    auto bLayout = getDefaultLaneLayout2DBlockIo(
+        bTy, uArch, uArchInstruction->getPackedFormatBitSizeB(), true);
+    auto cdLayout = getDefaultLaneLayout2DBlockIo(cdTy, uArch);
+
+    // Derive the scale layouts from the lane-level A/B layouts.
+    auto aScaleLayout = aScaleTy.has_value()
+                            ? createScaleLayout(aTy, *aScaleTy, aLayout, false)
+                            : nullptr;
+    auto bScaleLayout = bScaleTy.has_value()
+                            ? createScaleLayout(bTy, *bScaleTy, bLayout, true)
+                            : nullptr;
+
+    return std::make_tuple(aLayout, bLayout, cdLayout, aScaleLayout,
+                           bScaleLayout);
+  }
+  return std::nullopt;
+}
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 43998ed41f7aa..f1709e95d34a9 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -316,6 +316,9 @@ class LayoutInfoPropagation
   void visitDpasOp(xegpu::DpasOp dpas, ArrayRef<LayoutInfoLattice *> operands,
                    ArrayRef<const LayoutInfoLattice *> results);
 
+  void visitDpasMxOp(xegpu::DpasMxOp dpasMx, ArrayRef<LayoutInfoLattice *> operands,
+                     ArrayRef<const LayoutInfoLattice *> results);
+
   void visitStoreNdOp(xegpu::StoreNdOp store,
                       ArrayRef<LayoutInfoLattice *> operands,
                       ArrayRef<const LayoutInfoLattice *> results);
@@ -426,6 +429,9 @@ LogicalResult LayoutInfoPropagation::visitOperation(
   TypeSwitch<Operation *>(op)
       .Case(
           [&](xegpu::DpasOp dpasOp) { visitDpasOp(dpasOp, operands, results); })
+      .Case([&](xegpu::DpasMxOp dpasMxOp) {
+        visitDpasMxOp(dpasMxOp, operands, results);
+      })
       .Case([&](xegpu::StoreNdOp storeNdOp) {
         visitStoreNdOp(storeNdOp, operands, results);
       })
@@ -804,6 +810,39 @@ void LayoutInfoPropagation::visitDpasOp(
     std::tie(requiredALayout, requiredBLayout, requiredCDLayoutAttr) = *layouts;
 
     dpas.setLayoutAAttr(requiredALayout);
+
+/// Propagates the anchor layouts of a DpasMxOp to its operands: layout_a to
+/// a, layout_b to b, layout_cd to acc, layout_a_scale to scale_a and
+/// layout_b_scale to scale_b. Operands without an anchor layout are skipped.
+/// NOTE(review): in this hunk the definition sits inside visitDpasOp's body
+/// (between setLayoutAAttr and setLayoutBAttr); it has to be moved after
+/// visitDpasOp's closing brace to compile.
+void LayoutInfoPropagation::visitDpasMxOp(
+    xegpu::DpasMxOp dpasMx, ArrayRef<LayoutInfoLattice *> operands,
+    ArrayRef<const LayoutInfoLattice *> results) {
+  // Anchor layouts in ODS operand-group order: a, b, acc, scale_a, scale_b.
+  SmallVector<xegpu::DistributeLayoutAttr, 5> groupLayouts = {
+      dpasMx.getLayoutAAttr(), dpasMx.getLayoutBAttr(),
+      dpasMx.getLayoutCdAttr(), dpasMx.getLayoutAScaleAttr(),
+      dpasMx.getLayoutBScaleAttr()};
+
+  // acc, scale_a and scale_b are optional, so fixed indices are wrong: with
+  // no acc, operands[2] is scale_a (if present). Derive each group's actual
+  // position and length from the operand segment sizes instead.
+  for (auto [group, layout] : llvm::enumerate(groupLayouts)) {
+    if (!layout)
+      continue;
+    // An absent optional operand yields an empty range (length 0).
+    auto [start, length] = dpasMx.getODSOperandIndexAndLength(group);
+    for (unsigned idx = start, end = start + length; idx < end; ++idx) {
+      // Defensive: operands is expected to hold one lattice per op operand.
+      if (idx >= operands.size())
+        break;
+      propagateIfChanged(operands[idx],
+                         operands[idx]->meet(LayoutInfo(layout)));
+    }
+  }
+}
     dpas.setLayoutBAttr(requiredBLayout);
     dpas.setLayoutCdAttr(requiredCDLayoutAttr);
     dpasALayout = LayoutInfo(requiredALayout);

>From bbc8611ee119e46a6bb23adb1678f1b258bb40a8 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Thu, 23 Apr 2026 23:44:35 +0000
Subject: [PATCH 2/8] [mlir][XeGPU] Add workgroup-to-subgroup distribution for
 DpasMx ops

This patch implements workgroup-to-subgroup distribution for DpasMx
(mixed-precision DPAS with scale factors) operations.

Key changes:
- WgToSgDpasMxOp pattern in XeGPUWgToSgDistribute.cpp: Distributes DpasMx
  operations from workgroup to subgroup level by independently distributing
  each operand (A matrix, B matrix, accumulator C/D, scale_a, scale_b) based
  on their layout attributes
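- XeGPUUtils.cpp: getDistributeLayoutAttr returns the matching anchor layout
  (layout_a, layout_b, layout_cd, layout_a_scale or layout_b_scale) for each
  DpasMx operand
- WgToSgVectorBitCastOp, WgToSgVectorInterleaveOp and
  WgToSgVectorDeinterleaveOp patterns that distribute vector.bitcast,
  vector.interleave and vector.deinterleave
- Tests in xegpu-wg-to-sg-unify-ops.mlir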

The pattern handles optional operands (accumulator and scale factors) and
ensures proper layout propagation for MXFP matrix multiply operations across
subgroups.

Depends on: [PATCH 1/8] of this series (DpasMx op definition and layout support)

Co-Authored-By: Claude Sonnet 4.5 <noreply at anthropic.com>
---
 .../Transforms/XeGPUWgToSgDistribute.cpp      | 235 ++++++++++++++++--
 mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp   |  26 ++
 .../XeGPU/xegpu-wg-to-sg-unify-ops.mlir       |  42 ++++
 3 files changed, 287 insertions(+), 16 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index e083507173d31..951b323af6b05 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -22,8 +22,11 @@
 #include "mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h"
 #include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "llvm/Support/Debug.h"
 #include <optional>
 
+#define DEBUG_TYPE "xegpu-wg-to-sg-distribute"
+
 namespace mlir {
 namespace xegpu {
 #define GEN_PASS_DEF_XEGPUWGTOSGDISTRIBUTE
@@ -455,6 +458,93 @@ struct WgToSgDpasOp : public OpConversionPattern<xegpu::DpasOp> {
   }
 };
 
+/// This pattern transforms the DpasMxOp to work at subgroup level. Each A tile
+/// is paired with each B tile (A-major; acc tiles are consumed in that order),
+/// scale_a[i] accompanies a[i] and scale_b[j] accompanies b[j].
+struct WgToSgDpasMxOp : public OpConversionPattern<xegpu::DpasMxOp> {
+  using OpConversionPattern<xegpu::DpasMxOp>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(xegpu::DpasMxOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    VectorType resultTy = op.getResult().getType();
+    LLVM_DEBUG(llvm::dbgs() << "WgToSgDpasMxOp: original op: " << op << "\n");
+
+    if (resultTy.getRank() != 2)
+      return failure();
+
+    auto layoutCd = op.getLayoutCdAttr();
+    auto layoutA = op.getLayoutAAttr();
+    auto layoutB = op.getLayoutBAttr();
+    auto layoutAScale = op.getLayoutAScaleAttr();
+    auto layoutBScale = op.getLayoutBScaleAttr();
+
+    // A, B and C/D layouts are mandatory; scale_a/scale_b are optional, so
+    // their layouts are only required when the scale operand is present.
+    if (!layoutCd || !layoutA || !layoutB)
+      return failure();
+    if ((op.getScaleA() && !layoutAScale) || (op.getScaleB() && !layoutBScale))
+      return failure();
+
+    // Make sure the indexing below stays in range: one acc tile per (A, B)
+    // tile pair and one scale tile per A (resp. B) tile.
+    size_t numA = adaptor.getA().size();
+    size_t numB = adaptor.getB().size();
+    if ((op.getAcc() && adaptor.getAcc().size() < numA * numB) ||
+        (op.getScaleA() && adaptor.getScaleA().size() < numA) ||
+        (op.getScaleB() && adaptor.getScaleB().size() < numB))
+      return failure();
+
+    // Drop the sg-level fields; a null (absent) scale layout stays null.
+    auto dropSg =
+        [](xegpu::DistributeLayoutAttr layout) -> xegpu::DistributeLayoutAttr {
+      if (!layout)
+        return layout;
+      return layout.dropSgLayoutAndData();
+    };
+
+    size_t index_c = 0;
+    SmallVector<Value> newDpasMxOps;
+    for (auto [index_a, aVec] : llvm::enumerate(adaptor.getA())) {
+      for (auto [index_b, bVec] : llvm::enumerate(adaptor.getB())) {
+        LLVM_DEBUG(llvm::dbgs() << "  index_a=" << index_a << " aVec: "
+                                << aVec.getType() << ", index_b=" << index_b
+                                << " bVec: " << bVec.getType() << "\n");
+
+        // Absent optional operands stay null on the new op as well.
+        Value accVal;
+        if (op.getAcc())
+          accVal = adaptor.getAcc()[index_c++];
+        Value scaleAVal;
+        if (op.getScaleA())
+          scaleAVal = adaptor.getScaleA()[index_a];
+        Value scaleBVal;
+        if (op.getScaleB())
+          scaleBVal = adaptor.getScaleB()[index_b];
+
+        // Result tile shape: rows of the A tile x columns of the B tile.
+        ArrayRef<int64_t> aVecShape =
+            llvm::cast<VectorType>(aVec.getType()).getShape();
+        ArrayRef<int64_t> bVecShape =
+            llvm::cast<VectorType>(bVec.getType()).getShape();
+        VectorType resTy = VectorType::get({aVecShape[0], bVecShape[1]},
+                                           resultTy.getElementType());
+        auto newDpasMxOp = xegpu::DpasMxOp::create(
+            rewriter, loc, resTy, aVec, bVec, accVal, scaleAVal, scaleBVal,
+            dropSg(layoutA), dropSg(layoutB), dropSg(layoutCd),
+            dropSg(layoutAScale), dropSg(layoutBScale));
+        LLVM_DEBUG(llvm::dbgs() << "    created: " << newDpasMxOp << "\n");
+
+        newDpasMxOps.push_back(newDpasMxOp);
+      }
+    }
+    LLVM_DEBUG(llvm::dbgs()
+               << "  total new DpasMxOps: " << newDpasMxOps.size() << "\n");
+    rewriter.replaceOpWithMultiple(op, {newDpasMxOps});
+    return success();
+  }
+};
+
 /// This pattern transforms the PrefetchNdOp to prefetch the subgroup data.
 struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
   using OpConversionPattern<xegpu::PrefetchNdOp>::OpConversionPattern;
@@ -1547,23 +1637,122 @@ struct WgToSgVectorMaskOp : public OpConversionPattern<MaskOpType> {
 
 using WgToSgVectorConstantMaskOp = WgToSgVectorMaskOp<vector::ConstantMaskOp>;
 using WgToSgVectorCreateMaskOp = WgToSgVectorMaskOp<vector::CreateMaskOp>;
+
+// Distributes vector.bitcast to subgroup level: each distributed source tile
+// is bitcast to the per-subgroup type given by the result's workgroup layout.
+struct WgToSgVectorBitCastOp : public OpConversionPattern<vector::BitCastOp> {
+  using OpConversionPattern<vector::BitCastOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::BitCastOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    xegpu::DistributeLayoutAttr resultLayout =
+        xegpu::getTemporaryLayout(dyn_cast<OpResult>(op.getResult()));
+    // Only results that carry a workgroup-level layout are distributed.
+    bool isWgLayout = resultLayout && resultLayout.isForWorkgroup();
+    if (!isWgLayout)
+      return failure();
+
+    VectorType wgResultTy = op.getResultVectorType();
+    SmallVector<int64_t> sgShape =
+        getSgShapeAndCount(wgResultTy.getShape(), resultLayout).first;
+    auto sgResultTy = VectorType::get(sgShape, wgResultTy.getElementType());
+
+    SmallVector<Value> sgResults;
+    for (Value sgSource : adaptor.getSource())
+      sgResults.push_back(vector::BitCastOp::create(rewriter, op.getLoc(),
+                                                    sgResultTy, sgSource)
+                              .getResult());
+
+    rewriter.replaceOpWithMultiple(op, {sgResults});
+    return success();
+  }
+};
+
+// This pattern transforms vector.interleave ops to work at subgroup level.
+struct WgToSgVectorInterleaveOp
+    : public OpConversionPattern<vector::InterleaveOp> {
+  using OpConversionPattern<vector::InterleaveOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::InterleaveOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    VectorType resultType = op.getResultVectorType();
+    ArrayRef<int64_t> wgShape = resultType.getShape();
+    xegpu::DistributeLayoutAttr layout =
+        xegpu::getTemporaryLayout(dyn_cast<OpResult>(op.getResult()));
+    if (!layout || !layout.isForWorkgroup())
+      return failure();
+
+    // lhs tile i is interleaved with rhs tile i. llvm::zip would silently
+    // drop unpaired tiles, so require equal tile counts on both sides.
+    if (adaptor.getLhs().size() != adaptor.getRhs().size())
+      return failure();
+
+    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
+    VectorType newResultType =
+        VectorType::get(sgShape, resultType.getElementType());
+    SmallVector<Value> newInterleaveOps;
+    for (auto [lhs, rhs] : llvm::zip(adaptor.getLhs(), adaptor.getRhs())) {
+      auto newInterleave = vector::InterleaveOp::create(
+          rewriter, op.getLoc(), newResultType, lhs, rhs);
+      newInterleaveOps.push_back(newInterleave.getResult());
+    }
+    rewriter.replaceOpWithMultiple(op, {newInterleaveOps});
+    return success();
+  }
+};
+
+// This pattern transforms vector.deinterleave ops to work at subgroup level.
+// Each distributed source tile yields one res1 tile and one res2 tile.
+struct WgToSgVectorDeinterleaveOp
+    : public OpConversionPattern<vector::DeinterleaveOp> {
+  using OpConversionPattern<vector::DeinterleaveOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::DeinterleaveOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // The workgroup layout of res1 decides whether the op is distributed.
+    // No explicit sg result type is needed: DeinterleaveOp::create infers it
+    // from each distributed source tile.
+    xegpu::DistributeLayoutAttr layout =
+        xegpu::getTemporaryLayout(dyn_cast<OpResult>(op.getRes1()));
+    if (!layout || !layout.isForWorkgroup())
+      return failure();
+
+    SmallVector<Value> newRes1Ops;
+    SmallVector<Value> newRes2Ops;
+    for (Value src : adaptor.getSource()) {
+      auto newDeinterleave =
+          vector::DeinterleaveOp::create(rewriter, op.getLoc(), src);
+      newRes1Ops.push_back(newDeinterleave.getRes1());
+      newRes2Ops.push_back(newDeinterleave.getRes2());
+    }
+
+    // One replacement range per original result (res1, res2), passed as an
+    // initializer list like the sibling patterns do.
+    rewriter.replaceOpWithMultiple(op, {newRes1Ops, newRes2Ops});
+    return success();
+  }
+};
+
 } // namespace
 
 namespace mlir {
 namespace xegpu {
 void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
-  patterns
-      .add<WgToSgCreateNdOp, WgToSgCreateNdOpNoOffset, WgToSgLoadNdOp,
-           WgToSgLoadNdOpWithOffset, WgToSgStoreNdOp, WgToSgStoreNdOpWithOffset,
-           WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
-           WgToSgPrefetchNdOpWithOffset, UnrealizedConversionCastOpPattern,
-           WgToSgElementwiseOp, WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp,
-           WgToSgArithConstantOp, WgToSgLoadGatherOpWithOffset,
-           WgToSgStoreScatterOpWithOffset, WgToSgLoadMatrixOp,
-           WgToSgStoreMatrixOp, WgToSgVectorStepOp, WgToSgVectorShapeCastOp,
-           WgToSgMultiDimReductionOp, WgToSgVectorTransposeOp,
-           WgToSgVectorConstantMaskOp, WgToSgVectorCreateMaskOp>(
-          patterns.getContext());
+  patterns.add< // Now also registers DpasMx, BitCast and (De)Interleave.
+      WgToSgCreateNdOp, WgToSgCreateNdOpNoOffset, WgToSgLoadNdOp,
+      WgToSgLoadNdOpWithOffset, WgToSgStoreNdOp, WgToSgStoreNdOpWithOffset,
+      WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgDpasMxOp, WgToSgPrefetchNdOp,
+      WgToSgPrefetchNdOpWithOffset, UnrealizedConversionCastOpPattern,
+      WgToSgElementwiseOp, WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp,
+      WgToSgArithConstantOp, WgToSgLoadGatherOpWithOffset,
+      WgToSgStoreScatterOpWithOffset, WgToSgLoadMatrixOp, WgToSgStoreMatrixOp,
+      WgToSgVectorStepOp, WgToSgVectorShapeCastOp, WgToSgMultiDimReductionOp,
+      WgToSgVectorTransposeOp, WgToSgVectorConstantMaskOp,
+      WgToSgVectorCreateMaskOp, WgToSgVectorBitCastOp, WgToSgVectorInterleaveOp,
+      WgToSgVectorDeinterleaveOp>(patterns.getContext());
 }
 } // namespace xegpu
 } // namespace mlir
@@ -1687,6 +1876,12 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
     return isLegal(layout);
   });
 
+  target.addDynamicallyLegalOp<xegpu::DpasMxOp>(
+      [=](xegpu::DpasMxOp op) -> bool {
+        auto layout = op.getLayoutCdAttr();
+        return isLegal(layout);
+      });
+
   target.addDynamicallyLegalOp<xegpu::LoadMatrixOp>(
       [=](xegpu::LoadMatrixOp op) -> bool {
         return isLegal(op.getLayoutAttr());
@@ -1708,10 +1903,10 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
         return isLegal(layout);
       });
 
-  target.addDynamicallyLegalOp<vector::ShapeCastOp, vector::StepOp,
-                               vector::TransposeOp, vector::BroadcastOp,
-                               vector::MultiDimReductionOp,
-                               vector::ConstantMaskOp, vector::CreateMaskOp>(
+  target.addDynamicallyLegalOp<
+      vector::ShapeCastOp, vector::StepOp, vector::TransposeOp,
+      vector::BroadcastOp, vector::MultiDimReductionOp, vector::ConstantMaskOp,
+      vector::CreateMaskOp, vector::BitCastOp, vector::InterleaveOp>(
       [=](Operation *op) -> bool {
         // Check for either a SliceAttr or LayoutAttr on the result.
         auto layout =
@@ -1719,6 +1914,14 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
         return isLegal(layout);
       });
 
+  target.addDynamicallyLegalOp<vector::DeinterleaveOp>(
+      [=](vector::DeinterleaveOp op) -> bool {
+        // DeinterleaveOp has two results, check the first one
+        auto layout =
+            xegpu::getTemporaryLayout(dyn_cast<OpResult>(op.getRes1()));
+        return isLegal(layout);
+      });
+
   target.addDynamicallyLegalOp<xegpu::LoadGatherOp>(
       [=](xegpu::LoadGatherOp op) -> bool {
         auto layout = op.getLayoutAttr();
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index 2d1ce6eea17aa..fb0b5849fa46e 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -183,6 +183,32 @@ xegpu::getDistributeLayoutAttr(const OpOperand &opr) {
         return dpasOp.getLayoutCdAttr();
       }
     }
+    if (auto dpasMxOp = dyn_cast<xegpu::DpasMxOp>(op)) {
+      // DpasMxOp has operands: a, b, optional acc, optional scale_a, optional scale_b
+      // Use AttrSizedOperandSegments to determine which operand this is
+      auto segmentSizesAttr = dpasMxOp->getAttrOfType<DenseI32ArrayAttr>(
+          dpasMxOp.getOperandSegmentSizesAttrName());
+      if (!segmentSizesAttr)
+        return nullptr;
+
+      auto segmentSizes = segmentSizesAttr.asArrayRef();
+      unsigned aSize = segmentSizes[0];
+      unsigned bSize = segmentSizes[1];
+      unsigned accSize = segmentSizes[2];
+      unsigned scaleASize = segmentSizes[3];
+
+      if (idx < aSize) {
+        return dpasMxOp.getLayoutAAttr();
+      } else if (idx < aSize + bSize) {
+        return dpasMxOp.getLayoutBAttr();
+      } else if (idx < aSize + bSize + accSize) {
+        return dpasMxOp.getLayoutCdAttr();
+      } else if (idx < aSize + bSize + accSize + scaleASize) {
+        return dpasMxOp.getLayoutAScaleAttr();
+      } else {
+        return dpasMxOp.getLayoutBScaleAttr();
+      }
+    }
     if (auto convertOp = dyn_cast<xegpu::ConvertLayoutOp>(op)) {
       return convertOp.getInputLayoutAttr();
     }
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
index 3bc43b780ade2..80dd6d7b94151 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -960,4 +960,46 @@ gpu.module @test_distribution {
     gpu.return
   }
 
+  // CHECK-LABEL: @bitcast_distribution
+  gpu.func @bitcast_distribution(%src: memref<256x128xf32>) {
+    %tdesc = xegpu.create_nd_tdesc %src : memref<256x128xf32>
+      -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
+    %load =  xegpu.load_nd %tdesc[0, 0] {layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>}
+      : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
+      -> vector<256x128xf32>
+    // CHECK: vector.bitcast {{.*}} : vector<32x32xf32> to vector<32x64xi16>
+    %bitcast = vector.bitcast %load {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], lane_layout = [1, 16], lane_data = [1, 1]>}
+      : vector<256x128xf32> to vector<256x256xi16>
+    gpu.return
+  }
+
+  // CHECK-LABEL: @interleave_distribution
+  gpu.func @interleave_distribution(%src: memref<256x128xf32>) {
+    %tdesc = xegpu.create_nd_tdesc %src : memref<256x128xf32>
+      -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
+    %load1 =  xegpu.load_nd %tdesc[0, 0] {layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>}
+      : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
+      -> vector<256x128xf32>
+    %load2 =  xegpu.load_nd %tdesc[0, 0] {layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>}
+      : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
+      -> vector<256x128xf32>
+    // CHECK: vector.interleave {{.*}}, {{.*}} : vector<32x32xf32>
+    %interleave = vector.interleave %load1, %load2 {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], lane_layout = [1, 16], lane_data = [1, 1]>}
+      : vector<256x128xf32> -> vector<256x256xf32>
+    gpu.return
+  }
+
+  // CHECK-LABEL: @deinterleave_distribution
+  gpu.func @deinterleave_distribution(%src: memref<256x256xf32>) {
+    %tdesc = xegpu.create_nd_tdesc %src : memref<256x256xf32>
+      -> !xegpu.tensor_desc<256x256xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], lane_layout = [1, 16], lane_data = [1, 1]>>
+    %load =  xegpu.load_nd %tdesc[0, 0] {layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], lane_layout = [1, 16], lane_data = [1, 1]>}
+      : !xegpu.tensor_desc<256x256xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], lane_layout = [1, 16], lane_data = [1, 1]>>
+      -> vector<256x256xf32>
+    // CHECK: {{.*}} = vector.deinterleave {{.*}} : vector<32x64xf32> -> vector<32x32xf32>
+    %deinterleave:2 = vector.deinterleave %load {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>, layout_result_1 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>}
+      : vector<256x256xf32> -> vector<256x128xf32>
+    gpu.return
+  }
+
 }

>From 159040f0267c0e48341ffec11fa37695516716c7 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Wed, 29 Apr 2026 22:49:17 +0000
Subject: [PATCH 3/8] fix compilation issue

---
 .../XeGPU/Transforms/XeGPUPropagateLayout.cpp | 22 +++++++++----------
 1 file changed, 11 insertions(+), 11 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index f1709e95d34a9..9f4227e694688 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -810,6 +810,17 @@ void LayoutInfoPropagation::visitDpasOp(
     std::tie(requiredALayout, requiredBLayout, requiredCDLayoutAttr) = *layouts;
 
     dpas.setLayoutAAttr(requiredALayout);
+    dpas.setLayoutBAttr(requiredBLayout);
+    dpas.setLayoutCdAttr(requiredCDLayoutAttr);
+    dpasALayout = LayoutInfo(requiredALayout);
+    dpasBLayout = LayoutInfo(requiredBLayout);
+    dpasCDLayout = LayoutInfo(requiredCDLayoutAttr);
+  }
+  propagateIfChanged(operands[0], operands[0]->meet(dpasALayout));
+  propagateIfChanged(operands[1], operands[1]->meet(dpasBLayout));
+  if (operands.size() > 2)
+    propagateIfChanged(operands[2], operands[2]->meet(dpasCDLayout));
+}
 
 /// Propagate layout for DpasMxOp operands using the layout attributes.
 /// DpasMxOp has operands: a, b, acc (optional), scale_a (optional), scale_b (optional)
@@ -843,17 +854,6 @@ void LayoutInfoPropagation::visitDpasMxOp(
   if (layoutBScale && operands.size() > 4)
     propagateIfChanged(operands[4], operands[4]->meet(LayoutInfo(layoutBScale)));
 }
-    dpas.setLayoutBAttr(requiredBLayout);
-    dpas.setLayoutCdAttr(requiredCDLayoutAttr);
-    dpasALayout = LayoutInfo(requiredALayout);
-    dpasBLayout = LayoutInfo(requiredBLayout);
-    dpasCDLayout = LayoutInfo(requiredCDLayoutAttr);
-  }
-  propagateIfChanged(operands[0], operands[0]->meet(dpasALayout));
-  propagateIfChanged(operands[1], operands[1]->meet(dpasBLayout));
-  if (operands.size() > 2)
-    propagateIfChanged(operands[2], operands[2]->meet(dpasCDLayout));
-}
 
 /// Set the layout for the value and tensor descriptor operands in StoreNdOp.
 void LayoutInfoPropagation::visitStoreNdOp(

>From 76a0fc2b33124532038b7bc56c99f2f558264041 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Wed, 29 Apr 2026 23:46:58 +0000
Subject: [PATCH 4/8] add test

---
 .../Transforms/XeGPUWgToSgDistribute.cpp      | 34 +------------------
 mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir   | 34 +++++++++++++++++++
 2 files changed, 35 insertions(+), 33 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 6e6ae9909139b..555e1badf6661 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -356,9 +356,6 @@ struct WgToSgDpasMxOp : public OpConversionPattern<xegpu::DpasMxOp> {
     Location loc = op.getLoc();
     VectorType resultTy = op.getResult().getType();
 
-    LLVM_DEBUG(llvm::dbgs() << "WgToSgDpasMxOp: original op: " << op << "\n");
-    LLVM_DEBUG(llvm::dbgs() << "  resultTy: " << resultTy << "\n");
-
     if (resultTy.getRank() != 2)
       return failure();
 
@@ -368,17 +365,6 @@ struct WgToSgDpasMxOp : public OpConversionPattern<xegpu::DpasMxOp> {
     auto layoutAScale = op.getLayoutAScaleAttr();
     auto layoutBScale = op.getLayoutBScaleAttr();
 
-    LLVM_DEBUG(llvm::dbgs()
-               << "  adaptor.getA() size: " << adaptor.getA().size() << "\n");
-    LLVM_DEBUG(llvm::dbgs()
-               << "  adaptor.getB() size: " << adaptor.getB().size() << "\n");
-    LLVM_DEBUG(llvm::dbgs() << "  adaptor.getAcc() size: "
-                            << adaptor.getAcc().size() << "\n");
-    LLVM_DEBUG(llvm::dbgs() << "  adaptor.getScaleA() size: "
-                            << adaptor.getScaleA().size() << "\n");
-    LLVM_DEBUG(llvm::dbgs() << "  adaptor.getScaleB() size: "
-                            << adaptor.getScaleB().size() << "\n");
-
     if (!layoutCd || !layoutA || !layoutB || !layoutAScale || !layoutBScale)
       return failure();
 
@@ -386,27 +372,17 @@ struct WgToSgDpasMxOp : public OpConversionPattern<xegpu::DpasMxOp> {
     SmallVector<Value> newDpasMxOps;
     for (auto [index_a, aVec] : llvm::enumerate(adaptor.getA())) {
       for (auto [index_b, bVec] : llvm::enumerate(adaptor.getB())) {
-        LLVM_DEBUG(llvm::dbgs() << "  index_a=" << index_a << " aVec: "
-                                << aVec.getType() << ", index_b=" << index_b
-                                << " bVec: " << bVec.getType() << "\n");
-
         Value accVal;
         if (op.getAcc()) {
           accVal = adaptor.getAcc()[index_c++];
-          LLVM_DEBUG(llvm::dbgs() << "    acc[" << (index_c - 1)
-                                  << "]: " << accVal.getType() << "\n");
         }
         Value scaleAVal;
         if (op.getScaleA()) {
           scaleAVal = adaptor.getScaleA()[index_a];
-          LLVM_DEBUG(llvm::dbgs() << "    scaleA[" << index_a
-                                  << "]: " << scaleAVal.getType() << "\n");
         }
         Value scaleBVal;
         if (op.getScaleB()) {
           scaleBVal = adaptor.getScaleB()[index_b];
-          LLVM_DEBUG(llvm::dbgs() << "    scaleB[" << index_b
-                                  << "]: " << scaleBVal.getType() << "\n");
         }
 
         ArrayRef<int64_t> aVecShape =
@@ -415,19 +391,15 @@ struct WgToSgDpasMxOp : public OpConversionPattern<xegpu::DpasMxOp> {
             llvm::cast<VectorType>(bVec.getType()).getShape();
         VectorType resTy = VectorType::get({aVecShape[0], bVecShape[1]},
                                            resultTy.getElementType());
-        LLVM_DEBUG(llvm::dbgs() << "    resTy: " << resTy << "\n");
         auto newDpasMxOp = xegpu::DpasMxOp::create(
             rewriter, loc, resTy, aVec, bVec, accVal, scaleAVal, scaleBVal,
             layoutA.dropSgLayoutAndData(), layoutB.dropSgLayoutAndData(),
             layoutCd.dropSgLayoutAndData(), layoutAScale.dropSgLayoutAndData(),
             layoutBScale.dropSgLayoutAndData());
-        LLVM_DEBUG(llvm::dbgs() << "    created: " << newDpasMxOp << "\n");
 
         newDpasMxOps.push_back(newDpasMxOp);
       }
     }
-    LLVM_DEBUG(llvm::dbgs()
-               << "  total new DpasMxOps: " << newDpasMxOps.size() << "\n");
     rewriter.replaceOpWithMultiple(op, {newDpasMxOps});
     return success();
   }
@@ -1567,9 +1539,6 @@ struct WgToSgVectorDeinterleaveOp
   LogicalResult
   matchAndRewrite(vector::DeinterleaveOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    VectorType resultType = op.getResultVectorType();
-
-    ArrayRef<int64_t> wgShape = resultType.getShape();
     xegpu::DistributeLayoutAttr layout =
         xegpu::getTemporaryLayout(dyn_cast<OpResult>(op.getRes1()));
     if (!layout || !layout.isForWorkgroup())
@@ -1577,7 +1546,7 @@ struct WgToSgVectorDeinterleaveOp
 
     SmallVector<Value> newRes1Ops;
     SmallVector<Value> newRes2Ops;
-    // Deinterleave produces two results from each source
+
     for (auto src : adaptor.getSource()) {
       auto newDeinterleave =
           vector::DeinterleaveOp::create(rewriter, op.getLoc(), src);
@@ -1585,7 +1554,6 @@ struct WgToSgVectorDeinterleaveOp
       newRes2Ops.push_back(newDeinterleave.getRes2());
     }
 
-    // Combine both result sets for replacement
     SmallVector<SmallVector<Value>> results = {newRes1Ops, newRes2Ops};
     rewriter.replaceOpWithMultiple(op, results);
     return success();
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
index b43a387aa2809..88d2010876ce6 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
@@ -93,6 +93,40 @@ gpu.module @test_distribution {
     gpu.return
   }
 
+  // CHECK-LABEL: dpas_mx
+  gpu.func @dpas_mx(%a: memref<128x128xf8E5M2>, %b: memref<128x128xf8E5M2>, %a_scale: memref<128x4xf8E8M0FNU>, %b_scale: memref<4x128xf8E8M0FNU>) {
+    // CHECK: %[[DPAS_MX:.*]] = xegpu.dpas_mx %{{.*}}, %{{.*}}, %{{.*}} scale_a = %{{.*}} scale_b = %{{.*}} {layout_a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, layout_a_scale = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, layout_b = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>, layout_b_scale = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, layout_cd = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16x128xf8E5M2>, vector<128x16xf8E5M2>, vector<16x16xbf16>, vector<16x4xf8E8M0FNU>, vector<4x16xf8E8M0FNU> -> vector<16x16xbf16>
+    %tdesc_a = xegpu.create_nd_tdesc %a : memref<128x128xf8E5M2>
+      -> !xegpu.tensor_desc<128x128xf8E5M2, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128], lane_layout = [1, 16], lane_data = [1, 1]>>
+    %load_a =  xegpu.load_nd %tdesc_a[0, 0] {layout = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128], lane_layout = [1, 16], lane_data = [1, 1]>}
+      : !xegpu.tensor_desc<128x128xf8E5M2, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128], lane_layout = [1, 16], lane_data = [1, 1]>>
+      -> vector<128x128xf8E5M2>
+    %tdesc_b = xegpu.create_nd_tdesc %b : memref<128x128xf8E5M2>
+      -> !xegpu.tensor_desc<128x128xf8E5M2, #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16], lane_layout = [1, 16], lane_data = [2, 1]>>
+    %load_b =  xegpu.load_nd %tdesc_b[0, 0] {layout = #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16], lane_layout = [1, 16], lane_data = [2, 1]>}
+      : !xegpu.tensor_desc<128x128xf8E5M2, #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16], lane_layout = [1, 16], lane_data = [2, 1]>>
+      -> vector<128x128xf8E5M2>
+    %tdesc_a_scale = xegpu.create_nd_tdesc %a_scale : memref<128x4xf8E8M0FNU>
+      -> !xegpu.tensor_desc<128x4xf8E8M0FNU, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 4], lane_layout = [1, 16], lane_data = [1, 1]>>
+    %load_a_scale =  xegpu.load_nd %tdesc_a_scale[0, 0] {layout = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 4], lane_layout = [1, 16], lane_data = [1, 1]>}
+      : !xegpu.tensor_desc<128x4xf8E8M0FNU, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 4], lane_layout = [1, 16], lane_data = [1, 1]>>
+      -> vector<128x4xf8E8M0FNU>
+    %tdesc_b_scale = xegpu.create_nd_tdesc %b_scale : memref<4x128xf8E8M0FNU>
+      -> !xegpu.tensor_desc<4x128xf8E8M0FNU, #xegpu.layout<sg_layout = [8, 8], sg_data = [4, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
+    %load_b_scale =  xegpu.load_nd %tdesc_b_scale[0, 0] {layout = #xegpu.layout<sg_layout = [8, 8], sg_data = [4, 16], lane_layout = [1, 16], lane_data = [1, 1]>}
+      : !xegpu.tensor_desc<4x128xf8E8M0FNU, #xegpu.layout<sg_layout = [8, 8], sg_data = [4, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
+      -> vector<4x128xf8E8M0FNU>
+    %cst = arith.constant dense<0.0> : vector<128x128xbf16>
+    %dpas_mx = xegpu.dpas_mx %load_a, %load_b, %cst scale_a = %load_a_scale scale_b = %load_b_scale
+       {layout_a = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128], lane_layout = [1, 16], lane_data = [1, 1]>,
+        layout_b = #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16], lane_layout = [1, 16], lane_data = [2, 1]>,
+        layout_cd =  #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>,
+        layout_a_scale = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 4], lane_layout = [1, 16], lane_data = [1, 1]>,
+        layout_b_scale = #xegpu.layout<sg_layout = [8, 8], sg_data = [4, 16], lane_layout = [1, 16], lane_data = [1, 1]>}
+      : vector<128x128xf8E5M2>, vector<128x128xf8E5M2>, vector<128x128xbf16>, vector<128x4xf8E8M0FNU>, vector<4x128xf8E8M0FNU> -> vector<128x128xbf16>
+    gpu.return
+  }
+
   // CHECK-LABEL: dpas_no_sg_data
   gpu.func @dpas_no_sg_data(%a: memref<128x128xf16>, %b: memref<128x128xf16>) {
     // CHECK: %[[DPAS:.*]] = xegpu.dpas %{{.*}}, %{{.*}} {layout_a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1], order = [1, 0]>, layout_b = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1], order = [1, 0]>, layout_cd = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1], order = [1, 0]>} : vector<16x128xf16>, vector<128x16xf16> -> vector<16x16xf32>

>From a9d7a495bb0caf9a8ffcd958a0c9533f9478de26 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Thu, 30 Apr 2026 00:06:40 +0000
Subject: [PATCH 5/8] fix test

---
 mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir | 20 ++++++++++----------
 1 file changed, 10 insertions(+), 10 deletions(-)

diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
index 88d2010876ce6..fe5e9d014fb87 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
@@ -93,13 +93,13 @@ gpu.module @test_distribution {
     gpu.return
   }
 
-  // CHECK-LABEL: dpas_mx
+   // CHECK-LABEL: dpas_mx
   gpu.func @dpas_mx(%a: memref<128x128xf8E5M2>, %b: memref<128x128xf8E5M2>, %a_scale: memref<128x4xf8E8M0FNU>, %b_scale: memref<4x128xf8E8M0FNU>) {
-    // CHECK: %[[DPAS_MX:.*]] = xegpu.dpas_mx %{{.*}}, %{{.*}}, %{{.*}} scale_a = %{{.*}} scale_b = %{{.*}} {layout_a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, layout_a_scale = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, layout_b = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>, layout_b_scale = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, layout_cd = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16x128xf8E5M2>, vector<128x16xf8E5M2>, vector<16x16xbf16>, vector<16x4xf8E8M0FNU>, vector<4x16xf8E8M0FNU> -> vector<16x16xbf16>
+    // CHECK: %[[DPAS_MX:.*]] = xegpu.dpas_mx %{{.*}}, %{{.*}}, %{{.*}} scale_a = %{{.*}} scale_b = %{{.*}} {layout_a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>, layout_a_scale = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>, layout_b = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>, layout_b_scale = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, layout_cd = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16x128xf8E5M2>, vector<128x16xf8E5M2>, vector<16x16xbf16>, vector<16x4xf8E8M0FNU>, vector<4x16xf8E8M0FNU> -> vector<16x16xbf16>
     %tdesc_a = xegpu.create_nd_tdesc %a : memref<128x128xf8E5M2>
-      -> !xegpu.tensor_desc<128x128xf8E5M2, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128], lane_layout = [1, 16], lane_data = [1, 1]>>
-    %load_a =  xegpu.load_nd %tdesc_a[0, 0] {layout = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128], lane_layout = [1, 16], lane_data = [1, 1]>}
-      : !xegpu.tensor_desc<128x128xf8E5M2, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128], lane_layout = [1, 16], lane_data = [1, 1]>>
+      -> !xegpu.tensor_desc<128x128xf8E5M2, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128], lane_layout = [1, 16], lane_data = [1, 2]>>
+    %load_a =  xegpu.load_nd %tdesc_a[0, 0] {layout = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128], lane_layout = [1, 16], lane_data = [1, 2]>}
+      : !xegpu.tensor_desc<128x128xf8E5M2, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128], lane_layout = [1, 16], lane_data = [1, 2]>>
       -> vector<128x128xf8E5M2>
     %tdesc_b = xegpu.create_nd_tdesc %b : memref<128x128xf8E5M2>
       -> !xegpu.tensor_desc<128x128xf8E5M2, #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16], lane_layout = [1, 16], lane_data = [2, 1]>>
@@ -107,9 +107,9 @@ gpu.module @test_distribution {
       : !xegpu.tensor_desc<128x128xf8E5M2, #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16], lane_layout = [1, 16], lane_data = [2, 1]>>
       -> vector<128x128xf8E5M2>
     %tdesc_a_scale = xegpu.create_nd_tdesc %a_scale : memref<128x4xf8E8M0FNU>
-      -> !xegpu.tensor_desc<128x4xf8E8M0FNU, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 4], lane_layout = [1, 16], lane_data = [1, 1]>>
-    %load_a_scale =  xegpu.load_nd %tdesc_a_scale[0, 0] {layout = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 4], lane_layout = [1, 16], lane_data = [1, 1]>}
-      : !xegpu.tensor_desc<128x4xf8E8M0FNU, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 4], lane_layout = [1, 16], lane_data = [1, 1]>>
+      -> !xegpu.tensor_desc<128x4xf8E8M0FNU, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 4], lane_layout = [16, 1], lane_data = [1, 1]>>
+    %load_a_scale =  xegpu.load_nd %tdesc_a_scale[0, 0] {layout = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 4], lane_layout = [16, 1], lane_data = [1, 1]>}
+      : !xegpu.tensor_desc<128x4xf8E8M0FNU, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 4], lane_layout = [16, 1], lane_data = [1, 1]>>
       -> vector<128x4xf8E8M0FNU>
     %tdesc_b_scale = xegpu.create_nd_tdesc %b_scale : memref<4x128xf8E8M0FNU>
       -> !xegpu.tensor_desc<4x128xf8E8M0FNU, #xegpu.layout<sg_layout = [8, 8], sg_data = [4, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
@@ -118,10 +118,10 @@ gpu.module @test_distribution {
       -> vector<4x128xf8E8M0FNU>
     %cst = arith.constant dense<0.0> : vector<128x128xbf16>
     %dpas_mx = xegpu.dpas_mx %load_a, %load_b, %cst scale_a = %load_a_scale scale_b = %load_b_scale
-       {layout_a = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128], lane_layout = [1, 16], lane_data = [1, 1]>,
+       {layout_a = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128], lane_layout = [1, 16], lane_data = [1, 2]>,
         layout_b = #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16], lane_layout = [1, 16], lane_data = [2, 1]>,
         layout_cd =  #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>,
-        layout_a_scale = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 4], lane_layout = [1, 16], lane_data = [1, 1]>,
+        layout_a_scale = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 4], lane_layout = [16, 1], lane_data = [1, 1]>,
         layout_b_scale = #xegpu.layout<sg_layout = [8, 8], sg_data = [4, 16], lane_layout = [1, 16], lane_data = [1, 1]>}
       : vector<128x128xf8E5M2>, vector<128x128xf8E5M2>, vector<128x128xbf16>, vector<128x4xf8E8M0FNU>, vector<4x128xf8E8M0FNU> -> vector<128x128xbf16>
     gpu.return

>From 815478213affd55886bd48477835c23e1bd22433 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Thu, 30 Apr 2026 00:10:44 +0000
Subject: [PATCH 6/8] remove debug support

---
 mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 555e1badf6661..1706bab27fe29 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -22,11 +22,8 @@
 #include "mlir/Dialect/XeGPU/Transforms/XeGPULayoutImpl.h"
 #include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
 #include "mlir/Transforms/DialectConversion.h"
-#include "llvm/Support/Debug.h"
 #include <optional>
 
-#define DEBUG_TYPE "xegpu-wg-to-sg-distribute"
-
 namespace mlir {
 namespace xegpu {
 #define GEN_PASS_DEF_XEGPUWGTOSGDISTRIBUTE

>From 02c2094be97c65eadd3de7efe529db0e255b3d41 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Thu, 30 Apr 2026 21:54:42 +0000
Subject: [PATCH 7/8] address formatting issue

---
 mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index fb0b5849fa46e..41c4b2173eb38 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -184,8 +184,8 @@ xegpu::getDistributeLayoutAttr(const OpOperand &opr) {
       }
     }
     if (auto dpasMxOp = dyn_cast<xegpu::DpasMxOp>(op)) {
-      // DpasMxOp has operands: a, b, optional acc, optional scale_a, optional scale_b
-      // Use AttrSizedOperandSegments to determine which operand this is
+      // DpasMxOp has operands: a, b, optional acc, optional scale_a, optional
+      // scale_b. Use AttrSizedOperandSegments to identify this operand.
       auto segmentSizesAttr = dpasMxOp->getAttrOfType<DenseI32ArrayAttr>(
           dpasMxOp.getOperandSegmentSizesAttrName());
       if (!segmentSizesAttr)

>From 19630839cc0e981ec8576860582a7a1c724c0f3f Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Wed, 6 May 2026 16:18:20 +0000
Subject: [PATCH 8/8] fix deinterleave two-result handling and test issues

---
 .../XeGPU/Transforms/XeGPULayoutImpl.cpp      |  5 +-
 .../Transforms/XeGPUWgToSgDistribute.cpp      | 30 +++++++++++-
 mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir   | 46 +++++++++++--------
 3 files changed, 59 insertions(+), 22 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
index 820ec11ee4e95..4cab1e24bf9e6 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
@@ -127,9 +127,10 @@ static xegpu::DistributeLayoutAttr getLayoutFromUsePoints(Value result) {
 // For regular operations: First the result layouts are propagated from uses.
 // Then the result layouts are propagated to uses (operands).
 static void propagateResultsToRegularOperands(Operation *op) {
-  if (op->getNumResults() == 0 || op->getNumResults() > 1)
+  if (op->getNumResults() == 0)
+    return;
+  if (op->getNumResults() > 1 && !isa<vector::DeinterleaveOp>(op))
     return;
-
   OpResult result = op->getResult(0);
   xegpu::DistributeLayoutAttr resLayout = getLayoutFromUsePoints(result);
   Type resultType = result.getType();
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 1706bab27fe29..c44768d7a03d2 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -1536,10 +1536,28 @@ struct WgToSgVectorDeinterleaveOp
   LogicalResult
   matchAndRewrite(vector::DeinterleaveOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    llvm::dbgs() << "[DEBUG WgToSgVectorDeinterleaveOp] ENTRY\n";
+    llvm::dbgs() << "[DEBUG WgToSgVectorDeinterleaveOp] deinterleave op: " << op
+                 << "\n";
+
     xegpu::DistributeLayoutAttr layout =
         xegpu::getTemporaryLayout(dyn_cast<OpResult>(op.getRes1()));
-    if (!layout || !layout.isForWorkgroup())
+    llvm::dbgs() << "[DEBUG WgToSgVectorDeinterleaveOp] layout: " << layout
+                 << "\n";
+    llvm::dbgs() << "[DEBUG WgToSgVectorDeinterleaveOp] layout is null: "
+                 << (!layout) << "\n";
+    if (layout)
+      llvm::dbgs()
+          << "[DEBUG WgToSgVectorDeinterleaveOp] layout.isForWorkgroup(): "
+          << layout.isForWorkgroup() << "\n";
+
+    if (!layout || !layout.isForWorkgroup()) {
+      llvm::dbgs() << "[DEBUG WgToSgVectorDeinterleaveOp] FAILURE: no "
+                      "workgroup layout\n";
       return failure();
+    }
+    llvm::dbgs() << "[DEBUG WgToSgVectorDeinterleaveOp] About to process "
+                 << adaptor.getSource().size() << " sources\n";
 
     SmallVector<Value> newRes1Ops;
     SmallVector<Value> newRes2Ops;
@@ -1736,9 +1754,17 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
   target.addDynamicallyLegalOp<vector::DeinterleaveOp>(
       [=](vector::DeinterleaveOp op) -> bool {
         // DeinterleaveOp has two results, check the first one
+        llvm::dbgs()
+            << "[DEBUG Legality Check] Checking vector.deinterleave legality\n";
+        llvm::dbgs() << "[DEBUG Legality Check] deinterleave op: " << op
+                     << "\n";
         auto layout =
             xegpu::getTemporaryLayout(dyn_cast<OpResult>(op.getRes1()));
-        return isLegal(layout);
+        llvm::dbgs() << "[DEBUG Legality Check] layout: " << layout << "\n";
+        bool legal = isLegal(layout);
+        llvm::dbgs() << "[DEBUG Legality Check] isLegal(layout) = " << legal
+                     << "\n";
+        return legal;
       });
 
   target.addDynamicallyLegalOp<xegpu::LoadGatherOp>(
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
index f61fc0f2684fc..e2b6cc37e1c2a 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
@@ -1303,42 +1303,52 @@ gpu.module @test_distribution {
   // CHECK-LABEL: @bitcast_distribution
   gpu.func @bitcast_distribution(%src: memref<256x128xf32>) {
     %tdesc = xegpu.create_nd_tdesc %src : memref<256x128xf32>
-      -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
-    %load =  xegpu.load_nd %tdesc[0, 0] {layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>}
-      : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
+      -> !xegpu.tensor_desc<256x128xf32>
+    %load =  xegpu.load_nd %tdesc[0, 0] {layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>}
+      : !xegpu.tensor_desc<256x128xf32>
       -> vector<256x128xf32>
     // CHECK: vector.bitcast {{.*}} : vector<32x32xf32> to vector<32x64xi16>
-    %bitcast = vector.bitcast %load {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], lane_layout = [1, 16], lane_data = [1, 1]>}
-      : vector<256x128xf32> to vector<256x256xi16>
+    %bitcast = vector.bitcast %load : vector<256x128xf32> to vector<256x256xi16>
+    %anchor = xegpu.convert_layout %bitcast
+      <{
+        input_layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64]>,
+        target_layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64]>
+      }> : vector<256x256xi16>
     gpu.return
   }
 
   // CHECK-LABEL: @interleave_distribution
   gpu.func @interleave_distribution(%src: memref<256x128xf32>) {
     %tdesc = xegpu.create_nd_tdesc %src : memref<256x128xf32>
-      -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
-    %load1 =  xegpu.load_nd %tdesc[0, 0] {layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>}
-      : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
+      -> !xegpu.tensor_desc<256x128xf32>
+    %load1 =  xegpu.load_nd %tdesc[0, 0] {layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>}
+      : !xegpu.tensor_desc<256x128xf32>
       -> vector<256x128xf32>
-    %load2 =  xegpu.load_nd %tdesc[0, 0] {layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>}
-      : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
+    %load2 =  xegpu.load_nd %tdesc[0, 0] {layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>}
+      : !xegpu.tensor_desc<256x128xf32>
       -> vector<256x128xf32>
     // CHECK: vector.interleave {{.*}}, {{.*}} : vector<32x32xf32>
-    %interleave = vector.interleave %load1, %load2 {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], lane_layout = [1, 16], lane_data = [1, 1]>}
+    %interleave = vector.interleave %load1, %load2
       : vector<256x128xf32> -> vector<256x256xf32>
+    %anchor = xegpu.convert_layout %interleave
+      <{
+        input_layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64]>,
+        target_layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64]>
+      }> : vector<256x256xf32>
     gpu.return
   }
 
   // CHECK-LABEL: @deinterleave_distribution
   gpu.func @deinterleave_distribution(%src: memref<256x256xf32>) {
-    %tdesc = xegpu.create_nd_tdesc %src : memref<256x256xf32>
-      -> !xegpu.tensor_desc<256x256xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], lane_layout = [1, 16], lane_data = [1, 1]>>
-    %load =  xegpu.load_nd %tdesc[0, 0] {layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], lane_layout = [1, 16], lane_data = [1, 1]>}
-      : !xegpu.tensor_desc<256x256xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], lane_layout = [1, 16], lane_data = [1, 1]>>
-      -> vector<256x256xf32>
+    %tdesc = xegpu.create_nd_tdesc %src : memref<256x256xf32> -> !xegpu.tensor_desc<256x256xf32>
+    %load =  xegpu.load_nd %tdesc[0, 0] {layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64]>} : !xegpu.tensor_desc<256x256xf32> -> vector<256x256xf32>
     // CHECK: {{.*}} = vector.deinterleave {{.*}} : vector<32x64xf32> -> vector<32x32xf32>
-    %deinterleave:2 = vector.deinterleave %load {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>, layout_result_1 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>}
-      : vector<256x256xf32> -> vector<256x128xf32>
+    %deinterleave:2 = vector.deinterleave %load : vector<256x256xf32> -> vector<256x128xf32>
+    %anchor = xegpu.convert_layout %deinterleave#0
+      <{
+        input_layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>,
+        target_layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>
+      }> : vector<256x128xf32>
     gpu.return
   }
 



More information about the Mlir-commits mailing list