[Mlir-commits] [mlir] [mlir][XeGPU] add support for SCF control ops in workgroup to subgroup distribution (PR #142612)

Chao Chen llvmlistbot at llvm.org
Tue Jun 3 07:26:21 PDT 2025


https://github.com/chencha3 created https://github.com/llvm/llvm-project/pull/142612

None

>From 1ed4cb5b381898728f850da43a10826493fce94b Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Sat, 10 May 2025 17:04:39 +0000
Subject: [PATCH 01/55] Add XeGPUWgToSg pass

---
 .../mlir/Dialect/XeGPU/Transforms/Passes.td   |  31 +-
 .../Dialect/XeGPU/Transforms/Transforms.h     |   4 +
 .../Dialect/XeGPU/Transforms/CMakeLists.txt   |   1 +
 .../Dialect/XeGPU/Transforms/XeGPUWgToSg.cpp  | 374 ++++++++++++++++++
 .../test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir |  65 +++
 mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir   |  81 ++++
 6 files changed, 544 insertions(+), 12 deletions(-)
 create mode 100644 mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSg.cpp
 create mode 100644 mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
 create mode 100644 mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir

diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
index 3e81f2d0ed786..bdea88cfd7022 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
@@ -6,7 +6,6 @@
 //
 //===----------------------------------------------------------------------===//
 
-
 #ifndef MLIR_DIALECT_XEGPU_TRANSFORMS_PASSES_TD
 #define MLIR_DIALECT_XEGPU_TRANSFORMS_PASSES_TD
 
@@ -18,9 +17,7 @@ def XeGPUFoldAliasOps : Pass<"xegpu-fold-alias-ops"> {
     The pass folds aliasing ops into XeGPU ops that they operate on the original
     source references.
   }];
-  let dependentDialects = [
-      "memref::MemRefDialect", "xegpu::XeGPUDialect"
-  ];
+  let dependentDialects = ["memref::MemRefDialect", "xegpu::XeGPUDialect"];
 }
 
 def XeGPUSubgroupDistribute : Pass<"xegpu-subgroup-distribute"> {
@@ -28,14 +25,24 @@ def XeGPUSubgroupDistribute : Pass<"xegpu-subgroup-distribute"> {
   let description = [{
     The pass distributes subgroup level (SIMD) XeGPU ops to work items.
   }];
-  let dependentDialects = [
-      "memref::MemRefDialect", "xegpu::XeGPUDialect", "vector::VectorDialect"
-  ];
-  let options = [
-    Option<"printOnly", "print-analysis-only", "bool",
-         /*default=*/"false",
-         "Print the result of the subgroup map propagation analysis and exit.">
-  ];
+  let dependentDialects = ["memref::MemRefDialect", "xegpu::XeGPUDialect",
+                           "vector::VectorDialect"];
+  let options = [Option<
+      "printOnly", "print-analysis-only", "bool",
+      /*default=*/"false",
+      "Print the result of the subgroup map propagation analysis and exit.">];
+}
+
+def XeGPUWgToSg : Pass<"xegpu-wg-to-sg", "::mlir::gpu::GPUModuleOp"> {
+  let summary = "Transform WorkGroup level XeGPU code to SubGroup level";
+  let description = [{
+    This transform pass distributes the workgroup level computation to
+    multiple subgroups based on the sg_layout and sg_data attributes.
+  }];
+  // The rewrite patterns create arith, index and gpu (subgroup_id) ops.
+  let dependentDialects = ["memref::MemRefDialect", "xegpu::XeGPUDialect",
+                           "vector::VectorDialect", "arith::ArithDialect",
+                           "gpu::GPUDialect", "index::IndexDialect"];
 }
 
 #endif // MLIR_DIALECT_XEGPU_TRANSFORMS_PASSES_TD
diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h b/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
index 3e94021c7a1ea..388ba32e1eebb 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
@@ -9,6 +9,8 @@
 #ifndef MLIR_DIALECT_XEGPU_TRANSFORMS_TRANSFORMS_H
 #define MLIR_DIALECT_XEGPU_TRANSFORMS_TRANSFORMS_H
 
+#include "mlir/Transforms/DialectConversion.h"
+
 namespace mlir {
 class RewritePatternSet;
 
@@ -18,6 +20,8 @@ namespace xegpu {
 void populateXeGPUFoldAliasOpsPatterns(RewritePatternSet &patterns);
 /// Appends patterns for XeGPU SIMT distribution into `patterns`.
 void populateXeGPUSubgroupDistributePatterns(RewritePatternSet &patterns);
+/// Appends XeGPU workgroup-to-subgroup distribution patterns into `patterns`.
+void populateXeGPUWgToSgPatterns(RewritePatternSet &patterns);
 
 } // namespace xegpu
 } // namespace mlir
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
index 901e02d3c9cf5..b258921cc87fd 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_mlir_dialect_library(MLIRXeGPUTransforms
   XeGPUFoldAliasOps.cpp
   XeGPUSubgroupDistribute.cpp
+  XeGPUWgToSg.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/XeGPU
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSg.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSg.cpp
new file mode 100644
index 0000000000000..7969d37d67f04
--- /dev/null
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSg.cpp
@@ -0,0 +1,374 @@
+//===- XeGPUWgToSg.cpp - XeGPU WorkGroup to Subgroup Pass -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
+
+#include "mlir/Dialect/Index/IR/IndexDialect.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/Support/Debug.h"
+#include <mlir/Dialect/GPU/IR/GPUDialect.h>
+#include <mlir/Dialect/Index/IR/IndexOps.h>
+#include <numeric>
+
+namespace mlir {
+namespace xegpu {
+#define GEN_PASS_DEF_XEGPUWGTOSG
+#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
+} // namespace xegpu
+} // namespace mlir
+
+#define DEBUG_TYPE "xegpu-wg-to-sg"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
+
+using namespace mlir;
+
+namespace {
+
+// clang-format off
+/// This pattern transform the CreateNdDescOp to create a subgroup descriptor
+/// from a workgroup descriptor. It replaces the offsets and sizes with
+/// appropriate values for the subgroup.
+/// It uses round-robin distribution to create the subgroup descriptor.
+
+/// Following create_nd_desc operation:,
+///    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x24xf32>
+///       -> !xegpu.tensor_desc<24x24xf32, #xegpu.layout<sg_layout = [4, 4],
+///           sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+/// is converted to 9 subgroup level operations based on the sg_layout & sg_data:
+///    %tdesc = xegpu.create_nd_tdesc %src[off1, off2] : memref<24x24xf32> ->
+///           !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
+///
+/// The sg_layout and sg_data are dropped from the layout attribute as they are no longer needed.
+///
+/// 24x24 matrix distribution example:
+/// sg_layout = [4, 4], sg_data = [2, 2]
+/// Each 8x8 matrix within the 24x24 matrix is called a distribution unit.
+/// dist_unit_shape = [8, 8] --> sg_layout[i] * sg_data[i]
+///
+/// +------------------------+
+/// | 8x8 | 8x8 | 8x8 |      <- 3 tiles across
+/// |-----+-----+-----|
+/// | 8x8 | 8x8 | 8x8 |      <- 3 tiles down
+/// |-----+-----+-----|
+/// | 8x8 | 8x8 | 8x8 |
+/// +------------------------+
+///
+/// Each 8x8 tile is further subdivided among subgroups:
+/// +------------------------+
+/// | 2x2 2x2 2x2 2x2 |  <- 4 subgroups across (each handles 2 columns)
+/// | 2x2 2x2 2x2 2x2 |  <- 4 subgroups down (each handles 2 rows)
+/// | 2x2 2x2 2x2 2x2 |
+/// | 2x2 2x2 2x2 2x2 |
+/// +------------------------+
+///
+/// Since the 24x24 matrix is divided into 8x8 distribution units, there will be 9
+/// distribution units (3x3) in total. Hence the 9 subgroup level operations.
+/// Each 8x8 matrix within the 24x24 matrix is called a distribution unit.
+// clang-format on
+struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
+  using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
+
+  // Helper to extract mixed offsets into a Value array
+  SmallVector<Value> extractOffsets(ConversionPatternRewriter &rewriter,
+                                    xegpu::CreateNdDescOp op) const {
+    llvm::SmallVector<Value> offsets;
+    auto staticOffsets = op.getStaticOffsets();
+    auto dynamicOffsets = op.getOffsets();
+
+    for (size_t i = 0, j = 0; i != staticOffsets.size(); i++) {
+      if (ShapedType::isDynamic(staticOffsets[i])) {
+        offsets.push_back(dynamicOffsets[j++]);
+      } else {
+        offsets.push_back(rewriter.create<arith::ConstantIndexOp>(
+            op.getLoc(), staticOffsets[i]));
+      }
+    }
+    return offsets;
+  }
+
+  // Convert linear subgroup ID to 2D coordinates. TODO: Delinearize for nD
+  SmallVector<Value> delinearizeSubgroupId(ConversionPatternRewriter &rewriter,
+                                           Location loc, Value sgID,
+                                           Value sgDimX, Value sgDimY) const {
+    return {rewriter.create<index::DivUOp>(loc, sgID, sgDimY),
+            rewriter.create<index::RemUOp>(loc, sgID, sgDimY)};
+  }
+
+  // Create a constant index value
+  Value createConstantIndex(ConversionPatternRewriter &rewriter, Location loc,
+                            int64_t value) const {
+    return rewriter.create<arith::ConstantIndexOp>(loc, value);
+  }
+
+  // Calculate global offset for each subgroup
+  SmallVector<OpFoldResult>
+  calculateGlobalOffsets(ConversionPatternRewriter &rewriter, Location loc,
+                         const SmallVector<Value> &originalOffsets,
+                         const SmallVector<Value> &localOffset,
+                         const SmallVector<int64_t> &distUnitBaseAddr) const {
+    Value baseX = createConstantIndex(rewriter, loc, distUnitBaseAddr[0]);
+    Value baseY = createConstantIndex(rewriter, loc, distUnitBaseAddr[1]);
+
+    // Compute offsets within entire tile
+    Value offsetX =
+        rewriter.createOrFold<index::AddOp>(loc, localOffset[0], baseX);
+    Value offsetY =
+        rewriter.createOrFold<index::AddOp>(loc, localOffset[1], baseY);
+
+    // Add to global offsets
+    size_t lastDimIndex = originalOffsets.size() - 1;
+    size_t secondLastDimIndex = lastDimIndex - 1;
+
+    Value globalOffsetX = rewriter.createOrFold<index::AddOp>(
+        loc, originalOffsets[secondLastDimIndex], offsetX);
+    Value globalOffsetY = rewriter.createOrFold<index::AddOp>(
+        loc, originalOffsets[lastDimIndex], offsetY);
+
+    // Create final offset list
+    SmallVector<OpFoldResult> globalOffsets(originalOffsets.begin(),
+                                            originalOffsets.end());
+    globalOffsets[secondLastDimIndex] = globalOffsetX;
+    globalOffsets[lastDimIndex] = globalOffsetY;
+
+    return globalOffsets;
+  }
+
+  LogicalResult
+  matchAndRewrite(xegpu::CreateNdDescOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    MLIRContext *ctx = op.getContext();
+    xegpu::TensorDescType tdescTy = op.getType();
+    auto layout = dyn_cast_or_null<xegpu::LayoutAttr>(tdescTy.getLayout());
+    if (!layout || !layout.getSgLayout() || !layout.getSgData())
+      return rewriter.notifyMatchFailure(op, "expected sg_layout and sg_data");
+    Type elemTy = tdescTy.getElementType();
+    ArrayRef<int64_t> wgShape = tdescTy.getShape();
+    // Own the converted values: an ArrayRef to the temporary would dangle.
+    SmallVector<int64_t> sgShape =
+        llvm::to_vector_of<int64_t>(layout.getSgData().asArrayRef());
+    SmallVector<int64_t> sgLayout =
+        llvm::to_vector_of<int64_t>(layout.getSgLayout().asArrayRef());
+    // Only the 2D case is implemented (see delinearizeSubgroupId).
+    if (sgLayout.size() != 2 || sgShape.size() != 2 || wgShape.size() != 2)
+      return rewriter.notifyMatchFailure(op, "expected 2D sg_layout/sg_data");
+
+    // Get the subgroup ID
+    auto linearSgId = rewriter.create<gpu::SubgroupIdOp>(loc, nullptr);
+
+    // Create constants for layout dimensions
+    SmallVector<Value> sgLayoutDim(sgLayout.size());
+    SmallVector<Value> sgDataDim(sgShape.size());
+    for (size_t i = 0; i < sgLayout.size(); i++) {
+      sgLayoutDim[i] = createConstantIndex(rewriter, loc, sgLayout[i]);
+      sgDataDim[i] = createConstantIndex(rewriter, loc, sgShape[i]);
+    }
+
+    // Delinearize the 1D subgroup id into nd coordinates
+    SmallVector<Value> sgIds = delinearizeSubgroupId(
+        rewriter, loc, linearSgId, sgLayoutDim[0], sgLayoutDim[1]);
+
+    // Calculate distribution unit shape and local offsets for subgroup
+    SmallVector<int64_t> distUnitShape(sgLayout.size());
+    SmallVector<Value> localOffset(sgLayout.size());
+    for (size_t i = 0; i < sgLayout.size(); i++) {
+      distUnitShape[i] = sgLayout[i] * sgShape[i];
+      localOffset[i] =
+          rewriter.createOrFold<index::MulOp>(loc, sgIds[i], sgDataDim[i]);
+    }
+
+    SmallVector<Value> originalOffsets = extractOffsets(rewriter, op);
+
+    xegpu::TensorDescType newTdescTy =
+        xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(),
+                                   layout.dropSgLayoutAndData());
+    SmallVector<Value> newCreateNdOps;
+    for (const SmallVector<int64_t> &distUnitBaseAddr :
+         StaticTileOffsetRange(wgShape, distUnitShape)) {
+      SmallVector<OpFoldResult> globalOffsets = calculateGlobalOffsets(
+          rewriter, loc, originalOffsets, localOffset, distUnitBaseAddr);
+      auto newCreateNdOp = rewriter.create<xegpu::CreateNdDescOp>(
+          loc, newTdescTy, op.getSource(), globalOffsets, op.getMixedSizes(),
+          op.getMixedStrides());
+      newCreateNdOps.push_back(newCreateNdOp);
+    }
+
+    rewriter.replaceOpWithMultiple(op, {newCreateNdOps});
+    return success();
+  }
+};
+
+/// Transforms a workgroup-level LoadNdOp into one LoadNdOp per subgroup
+/// tensor descriptor, copying all attributes of the original op.
+/// NOTE(review): the result shape ignores any packed/transpose attributes.
+struct WgToSgLoadNdOp : public OpConversionPattern<xegpu::LoadNdOp> {
+  using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(xegpu::LoadNdOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    SmallVector<Value> newLoadOps;
+    for (auto src : adaptor.getTensorDesc()) {
+      auto tdescTy = cast<xegpu::TensorDescType>(src.getType());
+      ArrayRef<int64_t> srcShape = tdescTy.getShape();
+      VectorType newResTy = VectorType::get(srcShape, tdescTy.getElementType());
+      auto newLoadOp = rewriter.create<xegpu::LoadNdOp>(op.getLoc(), newResTy,
+                                                        src, op->getAttrs());
+      newLoadOps.push_back(newLoadOp);
+    }
+    rewriter.replaceOpWithMultiple(op, {newLoadOps});
+    return mlir::success();
+  }
+};
+
+/// Transforms a StoreNdOp into one StoreNdOp per (value, descriptor) pair.
+struct WgToSgStoreNdOp : public OpConversionPattern<xegpu::StoreNdOp> {
+  using OpConversionPattern<xegpu::StoreNdOp>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(xegpu::StoreNdOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // llvm::zip stops at the shorter range, which would silently drop stores.
+    if (adaptor.getValue().size() != adaptor.getTensorDesc().size())
+      return rewriter.notifyMatchFailure(op, "mismatched value/tdesc counts");
+    for (auto [v, t] : llvm::zip(adaptor.getValue(), adaptor.getTensorDesc()))
+      rewriter.create<xegpu::StoreNdOp>(op.getLoc(), v, t, op.getL1HintAttr(),
+                                        op.getL2HintAttr(), op.getL3HintAttr());
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+/// Distributes an UpdateNdOffsetOp: every subgroup tensor descriptor produced
+/// for the source descriptor gets its own UpdateNdOffsetOp that applies the
+/// same static and dynamic offsets as the original op.
+struct WgToSgUpdateNdOffsetOp
+    : public OpConversionPattern<xegpu::UpdateNdOffsetOp> {
+  using OpConversionPattern<xegpu::UpdateNdOffsetOp>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(xegpu::UpdateNdOffsetOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    SmallVector<Value> updated;
+    for (Value sgTdesc : adaptor.getTensorDesc())
+      updated.push_back(rewriter.create<xegpu::UpdateNdOffsetOp>(
+          loc, sgTdesc.getType(), sgTdesc, op.getOffsets(),
+          op.getConstOffsets()));
+
+    // The op's single result maps to all distributed descriptors (1:N).
+    rewriter.replaceOpWithMultiple(op, {updated});
+    return success();
+  }
+};
+
+/// This pattern transforms the DpasOp to work at subgroup level.
+struct WgToSgDpasOp : public OpConversionPattern<xegpu::DpasOp> {
+  using OpConversionPattern<xegpu::DpasOp>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(xegpu::DpasOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    VectorType resultTy = op.getResult().getType();
+    if (resultTy.getRank() != 2)
+      return failure();
+
+    auto originalLayout =
+        llvm::dyn_cast_or_null<xegpu::LayoutAttr>(op->getAttr("layout"));
+    if (!originalLayout)
+      return failure();
+
+    // Each (lhs, rhs) pair emits one DpasOp and consumes one acc value.
+    size_t numPairs = adaptor.getLhs().size() * adaptor.getRhs().size();
+    if (op.getAcc() && adaptor.getAcc().size() != numPairs)
+      return rewriter.notifyMatchFailure(op, "unexpected acc count");
+
+    SmallVector<Value> newDpasOps;
+    size_t i = 0;
+    for (auto aVec : adaptor.getLhs()) {
+      for (auto bVec : adaptor.getRhs()) {
+        llvm::SmallVector<Value> operands({aVec, bVec});
+        if (op.getAcc())
+          operands.push_back(adaptor.getAcc()[i++]);
+        ArrayRef<int64_t> aVecShape =
+            llvm::cast<VectorType>(aVec.getType()).getShape();
+        ArrayRef<int64_t> bVecShape =
+            llvm::cast<VectorType>(bVec.getType()).getShape();
+        VectorType resTy = VectorType::get({aVecShape[0], bVecShape[1]},
+                                           resultTy.getElementType());
+        Value newDpas = rewriter.create<xegpu::DpasOp>(
+            loc, resTy, operands,
+            llvm::ArrayRef<NamedAttribute>(
+                {"layout", originalLayout.dropSgLayoutAndData()}));
+        newDpasOps.push_back(newDpas);
+      }
+    }
+    rewriter.replaceOpWithMultiple(op, {newDpasOps});
+    return mlir::success();
+  }
+};
+
+} // namespace
+
+namespace mlir::xegpu {
+/// Registers the WgToSg* conversion patterns of this file into `patterns`.
+void populateXeGPUWgToSgPatterns(RewritePatternSet &patterns) {
+  MLIRContext *ctx = patterns.getContext();
+  patterns.add<WgToSgCreateNdOp, WgToSgLoadNdOp, WgToSgStoreNdOp,
+               WgToSgUpdateNdOffsetOp, WgToSgDpasOp>(ctx);
+}
+} // namespace mlir::xegpu
+
+namespace {
+/// Distributes workgroup-level XeGPU ops (those whose layout still carries
+/// sg_layout) to subgroup-level ops via a partial dialect conversion.
+struct XeGPUWgToSgPass : public xegpu::impl::XeGPUWgToSgBase<XeGPUWgToSgPass> {
+  void runOnOperation() override;
+};
+} // namespace
+
+void XeGPUWgToSgPass::runOnOperation() {
+  MLIRContext *ctx = &getContext();
+
+  // A layout attribute is already subgroup-level once sg_layout is gone.
+  auto hasNoSgLayout = [](Attribute attr) -> bool {
+    auto layout = dyn_cast_or_null<xegpu::LayoutAttr>(attr);
+    return !layout || layout.getSgLayout() == nullptr;
+  };
+
+  // Tensor descriptor type whose layout decides legality of a memory op.
+  auto tdescOf = [](Operation *op) -> xegpu::TensorDescType {
+    if (auto create = dyn_cast<xegpu::CreateNdDescOp>(op))
+      return create.getType();
+    if (auto load = dyn_cast<xegpu::LoadNdOp>(op))
+      return load.getTensorDescType();
+    if (auto store = dyn_cast<xegpu::StoreNdOp>(op))
+      return store.getTensorDescType();
+    if (auto update = dyn_cast<xegpu::UpdateNdOffsetOp>(op))
+      return update.getType();
+    return xegpu::TensorDescType();
+  };
+
+  // Ops whose layout still has sg_layout are illegal and must be rewritten.
+  ConversionTarget target(*ctx);
+  target.addDynamicallyLegalOp<xegpu::CreateNdDescOp, xegpu::LoadNdOp,
+                               xegpu::StoreNdOp, xegpu::UpdateNdOffsetOp>(
+      [=](Operation *op) { return hasNoSgLayout(tdescOf(op).getLayout()); });
+  target.addDynamicallyLegalOp<xegpu::DpasOp>([=](xegpu::DpasOp op) {
+    return hasNoSgLayout(op->getAttr("layout"));
+  });
+  // Every other op is considered legal.
+  target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
+
+  RewritePatternSet patterns(ctx);
+  xegpu::populateXeGPUWgToSgPatterns(patterns);
+  if (failed(
+          applyPartialConversion(getOperation(), target, std::move(patterns))))
+    return signalPassFailure();
+}
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
new file mode 100644
index 0000000000000..d0f225c3e7304
--- /dev/null
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
@@ -0,0 +1,65 @@
+// RUN: mlir-opt --xegpu-wg-to-sg -split-input-file %s | FileCheck %s
+
+gpu.module @test_round_robin_assignment {
+  // CHECK: test_create_nd_tdesc
+  // CHECK: %[[ARG_0:.*]]: memref<24x32xf32>
+  gpu.func @test_create_nd_tdesc(%src: memref<24x32xf32>) {
+    // CHECK-COUNT-12: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][%{{.*}}, %{{.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
+    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+    gpu.return
+  }
+
+  // CHECK: test_load_nd_tdesc
+  // CHECK: %[[ARG_0:.*]]: memref<24x32xf32>
+  gpu.func @test_load_nd_tdesc(%src: memref<24x32xf32>) {
+    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+    // CHECK-COUNT-12: %[[LOAD:.*]] = xegpu.load_nd %{{.*}} : !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>> -> vector<2x2xf32>
+    %load = xegpu.load_nd %tdesc: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>> -> vector<24x32xf32>
+    gpu.return
+  }
+
+  // CHECK: test_store_nd
+  // CHECK: %[[ARG_0:.*]]: memref<24x32xf32>
+  gpu.func @test_store_nd(%src: memref<24x32xf32>) {
+    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+    // CHECK-COUNT-12: xegpu.store_nd %{{.*}}, %{{.*}} : vector<2x2xf32>, !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
+    %load = xegpu.load_nd %tdesc: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>> -> vector<24x32xf32>
+    xegpu.store_nd %load, %tdesc: vector<24x32xf32>, !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+    gpu.return
+  }
+
+  // CHECK: test_update_nd
+  // CHECK: %[[ARG_0:.*]]: memref<24x32xf32>
+  gpu.func @test_update_nd(%src: memref<24x32xf32>) {
+    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+    // CHECK-COUNT-12: %[[UPDATE:.*]] = xegpu.update_nd_offset %{{.*}}, [0, 16] : !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
+    %update = xegpu.update_nd_offset %tdesc, [0, 16] : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+    gpu.return
+  }
+
+  // CHECK: test_dpas
+  // CHECK: %[[ARG_0:.*]]: memref<24x32xf32>
+  // CHECK: %[[ARG_1:.*]]: memref<32x24xf32>
+  // CHECK: %[[ARG_2:.*]]: memref<24x24xf32>
+  gpu.func @test_dpas(%a: memref<24x32xf32>, %b: memref<32x24xf32>, %c: memref<24x24xf32>) {
+    // sg_layout = [4, 4] and sg_data = [2, 2] give 8x8 distribution units;
+    // each subgroup gets one 2x2 fragment per unit (round-robin).
+    // A is 24x32: 3 x 4 = 12 fragments per subgroup.
+    // CHECK-COUNT-12: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][%{{.*}}, %{{.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
+    // CHECK-COUNT-12: %{{.*}} = xegpu.load_nd %{{.*}} : !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>> -> vector<2x2xf32>
+    // B is 32x24: 4 x 3 = 12 fragments per subgroup.
+    // CHECK-COUNT-12: %[[TDESC1:.*]] = xegpu.create_nd_tdesc %[[ARG_1]][%{{.*}}, %{{.*}}] : memref<32x24xf32> -> !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
+    // CHECK-COUNT-12: %{{.*}} = xegpu.load_nd %{{.*}} : !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>> -> vector<2x2xf32>
+    // C is 24x24: 3 x 3 = 9 fragments per subgroup.
+    // CHECK-COUNT-9: %[[TDESC2:.*]] = xegpu.create_nd_tdesc %{{.*}}[%{{.*}}, %{{.*}}] : memref<24x24xf32> -> !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
+    // Every lhs fragment is paired with every rhs fragment: 12 x 12 = 144.
+    // CHECK-COUNT-144: %[[DPAS:.*]] = xegpu.dpas %{{.*}}, %{{.*}} {layout = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>, vector<2x2xf32> -> vector<2x2xf32>
+    %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+    %load_a = xegpu.load_nd %tdesc_a: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>> -> vector<24x32xf32>
+    %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<32x24xf32> -> !xegpu.tensor_desc<32x24xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+    %load_b = xegpu.load_nd %tdesc_b: !xegpu.tensor_desc<32x24xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>> -> vector<32x24xf32>
+    %tdesc_c = xegpu.create_nd_tdesc %c[0, 0] : memref<24x24xf32> -> !xegpu.tensor_desc<24x24xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+    %dpas = xegpu.dpas %load_a, %load_b {layout = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>} : vector<24x32xf32>, vector<32x24xf32> -> vector<24x24xf32>
+    gpu.return
+  }
+}
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
new file mode 100644
index 0000000000000..c4c8881e65597
--- /dev/null
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
@@ -0,0 +1,81 @@
+// RUN: mlir-opt --xegpu-wg-to-sg -split-input-file %s | FileCheck %s
+
+gpu.module @test_1_1_assignment {
+  // CHECK: test_create_nd_tdesc
+  // CHECK: %[[ARG_0:.*]]: memref<24x32xf32>
+  gpu.func @test_create_nd_tdesc(%src: memref<24x32xf32>) {
+    // CHECK: %[[SGID:.*]] = gpu.subgroup_id
+    // CHECK: %[[C12:.*]] = arith.constant 12 : index
+    // CHECK: %[[C4:.*]] = arith.constant 4 : index
+    // CHECK: %[[C8:.*]] = arith.constant 8 : index
+    // CHECK: %[[DIV:.*]] = index.divu %[[SGID]], %[[C4]]
+    // CHECK: %[[REM:.*]] = index.remu %[[SGID]], %[[C4]]
+    // CHECK: %[[MUL1:.*]] = index.mul %[[DIV]], %[[C12]]
+    // CHECK: %[[MUL2:.*]] = index.mul %[[REM]], %[[C8]]
+    // CHECK: %[[C0:.*]] = arith.constant 0 : index
+    // CHECK: %[[ADD1:.*]] = index.add %[[MUL1]], %[[C0]]
+    // CHECK: %[[ADD2:.*]] = index.add %[[MUL2]], %[[C0]]
+    // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][%[[ADD1]], %[[ADD2]]] : memref<24x32xf32> -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
+    // CHECK: gpu.return
+    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+    gpu.return
+  }
+
+  // CHECK: test_load_nd_tdesc
+  // CHECK: %[[ARG_0:.*]]: memref<24x32xf32>
+  gpu.func @test_load_nd_tdesc(%src: memref<24x32xf32>) {
+    // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
+    // CHECK: %[[LOAD:.*]] = xegpu.load_nd %[[TDESC]] : !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>> -> vector<12x8xf32>
+    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+    %load = xegpu.load_nd %tdesc: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>> -> vector<24x32xf32>
+    gpu.return
+  }
+
+  // CHECK: test_store_nd
+  // CHECK: %[[ARG_0:.*]]: memref<24x32xf32>
+  gpu.func @test_store_nd(%src: memref<24x32xf32>) {
+    // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
+    // CHECK: %[[LOAD:.*]] = xegpu.load_nd %[[TDESC]] : !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>> -> vector<12x8xf32>
+    // CHECK: xegpu.store_nd %[[LOAD]], %[[TDESC]] : vector<12x8xf32>, !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
+    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+    %load = xegpu.load_nd %tdesc: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>> -> vector<24x32xf32>
+    xegpu.store_nd %load, %tdesc: vector<24x32xf32>, !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+    gpu.return
+  }
+
+  // CHECK: test_update_nd
+  // CHECK: %[[ARG_0:.*]]: memref<24x32xf32>
+  gpu.func @test_update_nd(%src: memref<24x32xf32>) {
+    // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
+    // CHECK: %[[UPDATE:.*]] = xegpu.update_nd_offset %[[TDESC]], [0, 16] : !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
+    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+    %update = xegpu.update_nd_offset %tdesc, [0, 16] : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+    gpu.return
+  }
+
+  // CHECK: test_dpas
+  // CHECK: %[[ARG_0:.*]]: memref<24x32xf32>
+  // CHECK: %[[ARG_1:.*]]: memref<32x24xf32>
+  gpu.func @test_dpas(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
+    // A (24x32) with sg_layout = [2, 4], sg_data = [12, 8] gives one 12x8 tile.
+    // CHECK: %[[TDESC_A:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<24x32xf32>
+    // CHECK-SAME: -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
+    // CHECK: %[[LOAD_A:.*]] = xegpu.load_nd %[[TDESC_A]]
+    // CHECK-SAME: : !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>> -> vector<12x8xf32>
+    // B (32x24) with sg_layout = [4, 2], sg_data = [8, 12] gives one 8x12 tile.
+    // CHECK: %[[TDESC_B:.*]] = xegpu.create_nd_tdesc %[[ARG_1]][{{%.*}}, {{%.*}}] : memref<32x24xf32>
+    // CHECK-SAME: -> !xegpu.tensor_desc<8x12xf32, #xegpu.layout<lane_layout = [8, 2], lane_data = [1, 1]>>
+    // CHECK: %[[LOAD_B:.*]] = xegpu.load_nd %[[TDESC_B]]
+    // CHECK-SAME: : !xegpu.tensor_desc<8x12xf32, #xegpu.layout<lane_layout = [8, 2], lane_data = [1, 1]>> -> vector<8x12xf32>
+    // The dpas result is one 12x12 tile, matching sg_data = [12, 12].
+    // CHECK: %[[DPAS:.*]] = xegpu.dpas %[[LOAD_A]], %[[LOAD_B]]
+    // CHECK-SAME: {layout = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>}
+    // CHECK-SAME: : vector<12x8xf32>, vector<8x12xf32> -> vector<12x12xf32>
+    %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+    %load_a = xegpu.load_nd %tdesc_a: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>> -> vector<24x32xf32>
+    %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<32x24xf32> -> !xegpu.tensor_desc<32x24xf32, #xegpu.layout<sg_layout = [4, 2], sg_data = [8, 12], lane_layout = [8, 2], lane_data = [1, 1]>>
+    %load_b = xegpu.load_nd %tdesc_b: !xegpu.tensor_desc<32x24xf32, #xegpu.layout<sg_layout = [4, 2], sg_data = [8, 12], lane_layout = [8, 2], lane_data = [1, 1]>> -> vector<32x24xf32>
+    %dpas = xegpu.dpas %load_a, %load_b {layout = #xegpu.layout<sg_layout = [2, 2], sg_data = [12, 12], lane_layout = [2, 2], lane_data = [1, 1]>} : vector<24x32xf32>, vector<32x24xf32> -> vector<24x24xf32>
+    gpu.return
+  }
+}

>From b3bf12f082eb08aa3f82503142140fc686e0e950 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Sun, 11 May 2025 15:49:35 +0000
Subject: [PATCH 02/55] Add prefetch_nd op

---
 .../Dialect/XeGPU/Transforms/XeGPUWgToSg.cpp  | 52 ++++++++++++-------
 .../test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir | 25 ++++-----
 mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir   | 29 ++++++-----
 3 files changed, 60 insertions(+), 46 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSg.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSg.cpp
index 7969d37d67f04..5eabb04e3b858 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSg.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSg.cpp
@@ -34,11 +34,10 @@ using namespace mlir;
 namespace {
 
 // clang-format off
-/// This pattern transform the CreateNdDescOp to create a subgroup descriptor
+/// This pattern transforms the CreateNdDescOp to create a subgroup descriptor
 /// from a workgroup descriptor. It replaces the offsets and sizes with
 /// appropriate values for the subgroup.
-/// It uses round-robin distribution to create the subgroup descriptor.
-
+/// It uses round-robin assignment to distribute the work to the subgroups.
 /// Following create_nd_desc operation:,
 ///    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x24xf32>
 ///       -> !xegpu.tensor_desc<24x24xf32, #xegpu.layout<sg_layout = [4, 4],
@@ -47,7 +46,7 @@ namespace {
 ///    %tdesc = xegpu.create_nd_tdesc %src[off1, off2] : memref<24x24xf32> ->
 ///           !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
 ///
-/// The sg_layout and sg_data are dropped from the layout attribute as they are no longer needed.
+/// The sg_layout and sg_data attributes are dropped after the pass as they are no longer needed.
 ///
 /// 24x24 matrix distribution example:
 /// sg_layout = [4, 4], sg_data = [2, 2]
@@ -72,7 +71,6 @@ namespace {
 ///
 /// Since the 24x24 matrix is divided into 8x8 distribution units, there will be 9
 /// distribution units (3x3) in total. Hence the 9 subgroup level operations.
-/// Each 8x8 matrix within the 24x24 matrix is called a distribution unit.
 // clang-format on
 struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
   using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
@@ -110,7 +108,7 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
     return rewriter.create<arith::ConstantIndexOp>(loc, value);
   }
 
-  // Calculate global offset for each subgroup
+  // Calculate offset for each subgroup
   SmallVector<OpFoldResult>
   calculateGlobalOffsets(ConversionPatternRewriter &rewriter, Location loc,
                          const SmallVector<Value> &originalOffsets,
@@ -122,13 +120,11 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
     Value constOffsetY =
         createConstantIndex(rewriter, loc, distUnitBaseAddr[1]);
 
-    // Compute offsets within entire tile
     Value offsetX =
         rewriter.createOrFold<index::AddOp>(loc, localOffset[0], constOffsetX);
     Value offsetY =
         rewriter.createOrFold<index::AddOp>(loc, localOffset[1], constOffsetY);
 
-    // Add to global offsets
     size_t lastDimIndex = originalOffsets.size() - 1;
     size_t secondLastDimIndex = lastDimIndex - 1;
 
@@ -137,7 +133,6 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
     Value globalOffsetY = rewriter.createOrFold<index::AddOp>(
         loc, originalOffsets[lastDimIndex], offsetY);
 
-    // Create final offset list
     SmallVector<OpFoldResult> globalOffsets(originalOffsets.begin(),
                                             originalOffsets.end());
     globalOffsets[secondLastDimIndex] = globalOffsetX;
@@ -172,7 +167,7 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
       sgDataDim[i] = createConstantIndex(rewriter, loc, sgShape[i]);
     }
 
-    // Delinearize the 1D subgroup id into nd coordinates
+    // Delinearize the 1D subgroup id into 2d
     SmallVector<Value> sgIds = delinearizeSubgroupId(
         rewriter, loc, linearSgId, sgLayoutDim[0], sgLayoutDim[1]);
 
@@ -207,8 +202,7 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
   }
 };
 
-/// This pattern transforms the LoadNdOp to load from a subgroup descriptor
-/// It creates a LoadNdOp op to load the new subgroup src tensor descriptors.
+/// This pattern transforms the LoadNdOp to load subgroup data.
 struct WgToSgLoadNdOp : public OpConversionPattern<xegpu::LoadNdOp> {
   using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
   LogicalResult
@@ -310,7 +304,22 @@ struct WgToSgDpasOp : public OpConversionPattern<xegpu::DpasOp> {
       }
     }
     rewriter.replaceOpWithMultiple(op, {newDpasOps});
-    return mlir::success();
+    return success();
+  }
+};
+
+/// This pattern transforms the PrefetchNdOp to prefetch the subgroup data.
+struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
+  using OpConversionPattern<xegpu::PrefetchNdOp>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(xegpu::PrefetchNdOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    for (auto src : adaptor.getTensorDesc()) {
+      rewriter.create<xegpu::PrefetchNdOp>(op.getLoc(), TypeRange(), src,
+                                           op->getAttrs());
+    }
+    rewriter.eraseOp(op);
+    return success();
   }
 };
 
@@ -320,7 +329,8 @@ namespace mlir {
 namespace xegpu {
 void populateXeGPUWgToSgPatterns(RewritePatternSet &patterns) {
   patterns.add<WgToSgCreateNdOp, WgToSgLoadNdOp, WgToSgStoreNdOp,
-               WgToSgUpdateNdOffsetOp, WgToSgDpasOp>(patterns.getContext());
+               WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp>(
+      patterns.getContext());
 }
 } // namespace xegpu
 } // namespace mlir
@@ -345,6 +355,8 @@ void XeGPUWgToSgPass::runOnOperation() {
       return storeOp.getTensorDescType();
     if (auto updateOp = dyn_cast<xegpu::UpdateNdOffsetOp>(op))
       return updateOp.getType();
+    if (auto prefetchOp = dyn_cast<xegpu::PrefetchNdOp>(op))
+      return prefetchOp.getTensorDescType();
     return xegpu::TensorDescType();
   };
 
@@ -353,12 +365,12 @@ void XeGPUWgToSgPass::runOnOperation() {
   };
 
   target.addDynamicallyLegalOp<xegpu::CreateNdDescOp, xegpu::LoadNdOp,
-                               xegpu::StoreNdOp, xegpu::UpdateNdOffsetOp>(
-      [=](Operation *op) -> bool {
-        auto tdescTy = getTensorDescType(op);
-        auto layout = dyn_cast_or_null<xegpu::LayoutAttr>(tdescTy.getLayout());
-        return isLegal(layout);
-      });
+                               xegpu::StoreNdOp, xegpu::UpdateNdOffsetOp,
+                               xegpu::PrefetchNdOp>([=](Operation *op) -> bool {
+    auto tdescTy = getTensorDescType(op);
+    auto layout = dyn_cast_or_null<xegpu::LayoutAttr>(tdescTy.getLayout());
+    return isLegal(layout);
+  });
 
   target.addDynamicallyLegalOp<xegpu::DpasOp>([=](xegpu::DpasOp op) -> bool {
     auto layout = dyn_cast_or_null<xegpu::LayoutAttr>(op->getAttr("layout"));
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
index d0f225c3e7304..de2c548ec7ebb 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
@@ -42,18 +42,10 @@ gpu.module @test_round_robin_assignment {
   // CHECK: %[[ARG_1:.*]]: memref<32x24xf32>
   // CHECK: %[[ARG_2:.*]]: memref<24x24xf32>
   gpu.func @test_dpas(%a: memref<24x32xf32>, %b: memref<32x24xf32>, %c: memref<24x24xf32>) {
-    // CHECK-COUNT-12: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][%{{.*}},
-    // %{{.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<2x2xf32,
-    // #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>> CHECK-COUNT-12:
-    // %[[TDESC1:.*]] = xegpu.create_nd_tdesc %[[ARG_1]][%{{.*}}, %{{.*}}] :
-    // memref<32x24xf32> -> !xegpu.tensor_desc<2x2xf32,
-    // #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>> CHECK-COUNT-9:
-    // %[[TDESC2:.*]] = xegpu.create_nd_tdesc %{{.*}}[%{{.*}}, %{{.*}}] :
-    // memref<24x24xf32> -> !xegpu.tensor_desc<2x2xf32,
-    // #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>> CHECK-COUNT-144:
-    // %[[DPAS:.*]] = xegpu.dpas %{{.*}}, %{{.*}} {layout =
-    // #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} :
-    // vector<2x2xf32>, vector<2x2xf32> -> vector<2x2xf32>
+    // CHECK-COUNT-12: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][%{{.*}}, %{{.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
+    // CHECK-COUNT-12: %[[TDESC1:.*]] = xegpu.create_nd_tdesc %[[ARG_1]][%{{.*}}, %{{.*}}] : memref<32x24xf32> -> !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
+    // CHECK-COUNT-9: %[[TDESC2:.*]] = xegpu.create_nd_tdesc %{{.*}}[%{{.*}}, %{{.*}}] : memref<24x24xf32> -> !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
+    // CHECK-COUNT-144: %[[DPAS:.*]] = xegpu.dpas %{{.*}}, %{{.*}} {layout = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>, vector<2x2xf32> -> vector<2x2xf32>
     %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
     %load_a =  xegpu.load_nd %tdesc_a: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>> -> vector<24x32xf32>
     %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<32x24xf32> -> !xegpu.tensor_desc<32x24xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
@@ -62,4 +54,13 @@ gpu.module @test_round_robin_assignment {
     %dpas = xegpu.dpas %load_a, %load_b {layout =  #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>} : vector<24x32xf32>, vector<32x24xf32> -> vector<24x24xf32>
     gpu.return
   }
+
+  // CHECK: test_prefetch_nd_tdesc
+  // CHECK: %[[ARG_0:.*]]: memref<24x32xf32>
+  gpu.func @test_prefetch_nd_tdesc(%src: memref<24x32xf32>) {
+    // CHECK-COUNT-12: xegpu.prefetch_nd %{{.*}} : !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
+    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+    xegpu.prefetch_nd %tdesc: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+    gpu.return
+  }
 }
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
index c4c8881e65597..1cae2c822d826 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
@@ -57,20 +57,11 @@ gpu.func @test_update_nd(%src: memref<24x32xf32>){
 // CHECK: %[[ARG_0:.*]]: memref<24x32xf32>
 // CHECK: %[[ARG_1:.*]]: memref<32x24xf32>
 gpu.func @test_dpas(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
-    // CHECK: %[[TDESC_A:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}},
-    // {{%.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<12x8xf32,
-    // #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>> CHECK:
-    // %[[LOAD_A:.*]] = xegpu.load_nd %[[TDESC_A]] :
-    // !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8],
-    // lane_data = [1, 1]>> -> vector<12x8xf32> CHECK: %[[TDESC_B:.*]] =
-    // xegpu.create_nd_tdesc %[[ARG_1]][{{%.*}}, {{%.*}}] : memref<32x24xf32> ->
-    // !xegpu.tensor_desc<8x12xf32, #xegpu.layout<lane_layout = [8, 2],
-    // lane_data = [1, 1]>> CHECK: %[[LOAD_B:.*]] = xegpu.load_nd %[[TDESC_B]] :
-    // !xegpu.tensor_desc<8x12xf32, #xegpu.layout<lane_layout = [8, 2],
-    // lane_data = [1, 1]>> -> vector<8x12xf32> CHECK: %[[DPAS:.*]] = xegpu.dpas
-    // %[[LOAD_A]], %[[LOAD_B]] {layout =  #xegpu.layout<lane_layout = [2, 2],
-    // lane_data = [1, 1]>} : vector<12x8xf32>, vector<8x12xf32> ->
-    // vector<12x12xf32>
+    // CHECK: %[[TDESC_A:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>> 
+    // CHECK: %[[LOAD_A:.*]] = xegpu.load_nd %[[TDESC_A]] : !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>> -> vector<12x8xf32>
+    // CHECK: %[[TDESC_B:.*]] = xegpu.create_nd_tdesc %[[ARG_1]][{{%.*}}, {{%.*}}] : memref<32x24xf32> -> !xegpu.tensor_desc<8x12xf32, #xegpu.layout<lane_layout = [8, 2], lane_data = [1, 1]>> 
+    // CHECK: %[[LOAD_B:.*]] = xegpu.load_nd %[[TDESC_B]] : !xegpu.tensor_desc<8x12xf32, #xegpu.layout<lane_layout = [8, 2], lane_data = [1, 1]>> -> vector<8x12xf32>
+    // CHECK: %[[DPAS:.*]] = xegpu.dpas %[[LOAD_A]], %[[LOAD_B]] {layout =  #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<12x8xf32>, vector<8x12xf32> -> vector<12x12xf32>
     %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
     %load_a =  xegpu.load_nd %tdesc_a: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>> -> vector<24x32xf32>
     %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<32x24xf32> -> !xegpu.tensor_desc<32x24xf32, #xegpu.layout<sg_layout = [4, 2], sg_data = [8, 12], lane_layout = [8, 2], lane_data = [1, 1]>>
@@ -78,4 +69,14 @@ gpu.func @test_dpas(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
     %dpas = xegpu.dpas %load_a, %load_b {layout =  #xegpu.layout<sg_layout = [2, 2], sg_data = [12, 12], lane_layout = [2, 2], lane_data = [1, 1]>} : vector<24x32xf32>, vector<32x24xf32> -> vector<24x24xf32>
     gpu.return
   }
+
+  // CHECK: test_prefetch_nd_tdesc
+  // CHECK: %[[ARG_0:.*]]: memref<24x32xf32>
+  gpu.func @test_prefetch_nd_tdesc(%src: memref<24x32xf32>) {
+    // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
+    // CHECK: xegpu.prefetch_nd %[[TDESC]] : !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
+    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+    xegpu.prefetch_nd %tdesc: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+    gpu.return
+  }
 }

>From 6a8647fa764e710f5aaeb51b46ae2ea398a959a3 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Sun, 11 May 2025 22:06:24 +0000
Subject: [PATCH 03/55] Remove braces for single statement for and if

---
 mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSg.cpp | 8 +++-----
 1 file changed, 3 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSg.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSg.cpp
index 5eabb04e3b858..836f307ece9e1 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSg.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSg.cpp
@@ -83,12 +83,11 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
     auto dynamicOffsets = op.getOffsets();
 
     for (size_t i = 0, j = 0; i != staticOffsets.size(); i++) {
-      if (ShapedType::isDynamic(staticOffsets[i])) {
+      if (ShapedType::isDynamic(staticOffsets[i]))
         offsets.push_back(dynamicOffsets[j++]);
-      } else {
+      else
         offsets.push_back(rewriter.create<arith::ConstantIndexOp>(
             op.getLoc(), staticOffsets[i]));
-      }
     }
     return offsets;
   }
@@ -314,10 +313,9 @@ struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
   LogicalResult
   matchAndRewrite(xegpu::PrefetchNdOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    for (auto src : adaptor.getTensorDesc()) {
+    for (auto src : adaptor.getTensorDesc())
       rewriter.create<xegpu::PrefetchNdOp>(op.getLoc(), TypeRange(), src,
                                            op->getAttrs());
-    }
     rewriter.eraseOp(op);
     return success();
   }

>From c6589299e4e1375e91b61fbf0edb9f3d1f7a89c4 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Mon, 12 May 2025 14:30:54 +0000
Subject: [PATCH 04/55] Clean up

---
 mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h | 5 +----
 mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSg.cpp       | 3 +++
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h b/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
index 388ba32e1eebb..5c12973edbed8 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
@@ -9,8 +9,6 @@
 #ifndef MLIR_DIALECT_XEGPU_TRANSFORMS_TRANSFORMS_H
 #define MLIR_DIALECT_XEGPU_TRANSFORMS_TRANSFORMS_H
 
-#include "mlir/Transforms/DialectConversion.h"
-
 namespace mlir {
 class RewritePatternSet;
 
@@ -20,8 +18,7 @@ namespace xegpu {
 void populateXeGPUFoldAliasOpsPatterns(RewritePatternSet &patterns);
 /// Appends patterns for XeGPU SIMT distribution into `patterns`.
 void populateXeGPUSubgroupDistributePatterns(RewritePatternSet &patterns);
-void populateXeGPUWgToSgPatterns(RewritePatternSet &patterns,
-                                 ConversionTarget &target);
+void populateXeGPUWgToSgPatterns(RewritePatternSet &patterns);
 
 } // namespace xegpu
 } // namespace mlir
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSg.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSg.cpp
index 836f307ece9e1..512cdca251c42 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSg.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSg.cpp
@@ -147,6 +147,8 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
     MLIRContext *ctx = op.getContext();
     xegpu::TensorDescType tdescTy = op.getType();
     auto layout = dyn_cast<xegpu::LayoutAttr>(tdescTy.getLayout());
+    if (!layout)
+      return failure();
     Type elemTy = tdescTy.getElementType();
     ArrayRef<int64_t> wgShape = tdescTy.getShape();
     ArrayRef<int64_t> sgShape =
@@ -154,6 +156,7 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
     ArrayRef<int64_t> sgLayout =
         llvm::to_vector_of<int64_t>(layout.getSgLayout().asArrayRef());
 
+    // TODO : Handle order attribute
     // Get the subgroup ID
     auto linearSgId = rewriter.create<gpu::SubgroupIdOp>(loc, nullptr);
 

>From 777a403f896d811dbe36a7aed6ccacf6adf9c833 Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Mon, 12 May 2025 19:36:58 +0000
Subject: [PATCH 05/55] add utils

---
 .../mlir/Dialect/XeGPU/Utils/XeGPUUtils.h     | 15 +++++++
 .../Transforms/XeGPUSubgroupDistribute.cpp    | 27 +++++--------
 mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp   | 40 +++++++++++++++++++
 3 files changed, 64 insertions(+), 18 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
index 3616fa614e7f9..5c2a308887040 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
@@ -13,6 +13,9 @@
 namespace mlir {
 
 class VectorType;
+class OpOperand;
+class OpResult;
+
 namespace xegpu {
 class LayoutAttr;
 class TensorDescType;
@@ -50,6 +53,18 @@ FailureOr<VectorType> getDistributedVectorType(xegpu::TensorDescType tdescTy);
 FailureOr<VectorType> getDistributedVectorType(VectorType originalType,
                                                LayoutAttr layout);
 
+/// Retrieves the LayoutAttr associated with a given Value. For TensorDescType
+/// values, the LayoutAttr is extracted from the TensorDescType itself. For
+/// other values, it is obtained from the attributes of the defining operation.
+/// Returns nullptr if no LayoutAttr is found.
+LayoutAttr getLayoutAttr(Value value);
+
+/// Retrieves the name for the LayoutAttr associated with a given OpOperand.
+std::string getLayoutName(OpOperand &opr);
+
+/// Retrieves the name for the LayoutAttr associated with a given OpResult.
+std::string getLayoutName(OpResult res);
+
 } // namespace xegpu
 
 } // namespace mlir
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 2300d9e3bd43f..ca887bd0fb7b5 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -62,8 +62,6 @@ constexpr unsigned packedSizeInBitsForDefault =
     16; // Minimum packing size per register for DPAS A.
 constexpr unsigned packedSizeInBitsForDpasB =
     32; // Minimum packing size per register for DPAS B.
-static const char *const operandLayoutNamePrefix = "layout_operand_";
-static const char *const resultLayoutNamePrefix = "layout_result_";
 
 namespace {
 
@@ -728,10 +726,7 @@ class LayoutAttrAssignment {
 void LayoutAttrAssignment::assignToUsers(Value v, xegpu::LayoutAttr layout) {
   for (OpOperand &user : v.getUses()) {
     Operation *owner = user.getOwner();
-    unsigned operandNumber = user.getOperandNumber();
-    // Use a generic name for ease of querying the layout attribute later.
-    std::string attrName =
-        operandLayoutNamePrefix + std::to_string(operandNumber);
+    std::string attrName = xegpu::getLayoutName(user);
     owner->setAttr(attrName, layout);
   }
 }
@@ -805,10 +800,10 @@ LogicalResult LayoutAttrAssignment::assign(Operation *op) {
     return success();
   }
   // Otherwise simply attach the layout to the op itself.
-  for (auto [i, r] : llvm::enumerate(op->getResults())) {
+  for (auto r : op->getOpResults()) {
     xegpu::LayoutAttr layoutInfo = getLayoutAttrForValue(r);
     if (layoutInfo) {
-      std::string attrName = resultLayoutNamePrefix + std::to_string(i);
+      std::string attrName = xegpu::getLayoutName(r);
       op->setAttr(attrName, layoutInfo);
       // Attach the layout attribute to the users of the result.
       assignToUsers(r, layoutInfo);
@@ -928,11 +923,8 @@ static SmallVector<NamedAttribute>
 removeTemporaryLayoutAttributes(ArrayRef<NamedAttribute> attrs) {
   SmallVector<NamedAttribute> newAttrs;
   for (NamedAttribute attr : attrs) {
-    if (attr.getName().strref().contains(operandLayoutNamePrefix) ||
-        attr.getName().strref().contains(resultLayoutNamePrefix)) {
-      continue;
-    }
-    newAttrs.push_back(attr);
+    if (!isa<xegpu::LayoutAttr>(attr.getValue()))
+      newAttrs.push_back(attr);
   }
   return newAttrs;
 }
@@ -1335,11 +1327,10 @@ struct DpasDistribution final : public gpu::WarpDistributionPattern {
 
     auto dpasOp = operand->get().getDefiningOp<xegpu::DpasOp>();
     unsigned operandIdx = operand->getOperandNumber();
-    std::string layoutAName =
-        llvm::formatv("{0}{1}", operandLayoutNamePrefix, 0).str();
-    std::string layoutBName =
-        llvm::formatv("{0}{1}", operandLayoutNamePrefix, 1).str();
-    auto layoutCName = llvm::formatv("{0}{1}", resultLayoutNamePrefix, 0).str();
+    std::string layoutAName = xegpu::getLayoutName(dpasOp->getOpOperand(0));
+    std::string layoutBName = xegpu::getLayoutName(dpasOp->getOpOperand(1));
+    std::string layoutCName = xegpu::getLayoutName(dpasOp->getOpResult(0));
+
     xegpu::LayoutAttr layoutA =
         dpasOp->getAttrOfType<xegpu::LayoutAttr>(layoutAName);
     xegpu::LayoutAttr layoutB =
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index 6b45ed0ae4ced..d101ce07043ec 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -12,6 +12,8 @@
 
 #include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/IR/Operation.h"
+#include "llvm/Support/FormatVariadic.h"
 #include <cstdint>
 #include <numeric>
 
@@ -83,3 +85,41 @@ mlir::xegpu::getDistributedVectorType(VectorType originalType,
       /*memory_space=*/xegpu::MemorySpace::Global, layout);
   return xegpu::getDistributedVectorType(helperTdescTy);
 }
+
+xegpu::LayoutAttr xegpu::getLayoutAttr(Value value) {
+  if (!value)
+    return LayoutAttr();
+
+  if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(value.getType()))
+    return tdescTy.getLayoutAttr();
+
+  if (auto result = dyn_cast<OpResult>(value)) {
+    Operation *defOp = result.getDefiningOp();
+    assert(defOp && "result must have a defining op");
+    std::string layoutName = getLayoutName(result);
+    if (defOp->hasAttr(layoutName))
+      return defOp->getAttrOfType<xegpu::LayoutAttr>(layoutName);
+  }
+
+  if (auto arg = dyn_cast<BlockArgument>(value)) {
+    auto parentOp = arg.getOwner()->getParentOp();
+    if (auto funcOp = dyn_cast<FuncOp>(parentOp)) {
+      std::string layoutName = getLayoutName(arg);
+      if (funcOp->hasAttr(layoutName))
+        return funcOp->getAttrOfType<xegpu::LayoutAttr>(layoutName);
+    }
+  }
+
+  return nullptr;
+}
+
+std::string xegpu::getLayoutName(OpOperand &opr) {
+  const StringRef prefix("layout_operand_");
+  return llvm::formatv("{0}{1}", prefix, opr.getOperandNumber()).str();
+}
+
+std::string xegpu::getLayoutName(OpResult res) {
+  const StringRef prefix = "layout_result_";
+  return llvm::formatv("{0}{1}", prefix, res.getResultNumber()).str();
+}
+

>From af01c99481e1a88fef78b2517cf9b2f531acbd9f Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Mon, 12 May 2025 19:37:07 +0000
Subject: [PATCH 06/55] add skeleton

---
 mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td | 12 ++++++++++++
 mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt     |  1 +
 2 files changed, 13 insertions(+)

diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
index 3e81f2d0ed786..54782933fe5f8 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
@@ -38,4 +38,16 @@ def XeGPUSubgroupDistribute : Pass<"xegpu-subgroup-distribute"> {
   ];
 }
 
+def XeGPUInstructionlize: Pass<"xegpu-instructionlize"> {
+  let summary = "Instructionlize XeGPU ops";
+  let description = [{
+    The pass unrolls XeGPU ops working on large shapes into ops working on small shapes
+    (given by the inst_data in the layout attr), such that each of them can be dispatched
+    as a single hardware instruction.
+  }];
+  let dependentDialects = [
+      "memref::MemRefDialect", "xegpu::XeGPUDialect", "vector::VectorDialect"
+  ];
+}
+
 #endif // MLIR_DIALECT_XEGPU_TRANSFORMS_PASSES_TD
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
index 892eb791c46e7..1d94b4c4c03ac 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_dialect_library(MLIRXeGPUTransforms
   XeGPUFoldAliasOps.cpp
+  XeGPUInstructionlize.cpp
   XeGPUSubgroupDistribute.cpp
   XeGPUUnroll.cpp
 

>From e8b43fbfe2b3764dc804b13975154b0f584c7d9b Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Tue, 13 May 2025 00:44:02 +0000
Subject: [PATCH 07/55] add filter

---
 mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td |  4 ++++
 mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp      | 16 ++++++++++------
 2 files changed, 14 insertions(+), 6 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index 032ce5bc18334..3f5fe2cce4636 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -295,11 +295,15 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout"> {
     }
 
     LayoutAttr dropSgLayoutAndData() {
+      if (!getInstData() && !getLaneLayout())
+        return nullptr;
       return LayoutAttr::get(getContext(), nullptr, nullptr, getInstData(),
                              getLaneLayout(), getLaneData(), getOrder());
     }
 
     LayoutAttr dropInstData() {
+      if (!getSgLayout() && !getLaneLayout())
+        return nullptr;
       return LayoutAttr::get(getContext(), getSgLayout(), getSgData(), nullptr,
                              getLaneLayout(), getLaneData(), getOrder());
     }
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index d101ce07043ec..285a15062e402 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -13,6 +13,7 @@
 #include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
 #include "mlir/IR/Operation.h"
+#include "mlir/Interfaces/LoopLikeInterface.h"
 #include "llvm/Support/FormatVariadic.h"
 #include <cstdint>
 #include <numeric>
@@ -88,7 +89,7 @@ mlir::xegpu::getDistributedVectorType(VectorType originalType,
 
 xegpu::LayoutAttr xegpu::getLayoutAttr(Value value) {
   if (!value)
-    return LayoutAttr();
+    return nullptr;
 
   if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(value.getType()))
     return tdescTy.getLayoutAttr();
@@ -96,6 +97,11 @@ xegpu::LayoutAttr xegpu::getLayoutAttr(Value value) {
   if (auto result = dyn_cast<OpResult>(value)) {
     Operation *defOp = result.getDefiningOp();
     assert(defOp && "result must have a defining op");
+
+    // for LoadNdOp, the layout is stored in the tensor descriptor
+    if (auto loadNd = dyn_cast<xegpu::LoadNdOp>(defOp))
+      return getLayoutAttr(loadNd.getTensorDesc());
+
     std::string layoutName = getLayoutName(result);
     if (defOp->hasAttr(layoutName))
       return defOp->getAttrOfType<xegpu::LayoutAttr>(layoutName);
@@ -103,10 +109,9 @@ xegpu::LayoutAttr xegpu::getLayoutAttr(Value value) {
 
   if (auto arg = dyn_cast<BlockArgument>(value)) {
     auto parentOp = arg.getOwner()->getParentOp();
-    if (auto funcOp = dyn_cast<FuncOp>(parentOp)) {
-      std::string layoutName = getLayoutName(arg);
-      if (funcOp->hasAttr(layoutName))
-        return funcOp->getAttrOfType<xegpu::LayoutAttr>(layoutName);
+    auto loop = dyn_cast<LoopLikeOpInterface>(parentOp);
+    // Only region iter_args have a tied init; e.g. induction vars do not.
+    if (OpOperand *tiedInit = loop ? loop.getTiedLoopInit(arg) : nullptr)
+      return getLayoutAttr(tiedInit->get());
   }
 
@@ -122,4 +127,3 @@ std::string xegpu::getLayoutName(OpResult res) {
   const StringRef prefix = "layout_result_";
   return llvm::formatv("{0}{1}", prefix, res.getResultNumber()).str();
 }
-

>From 3f73fda71e833ef844eec19bd2eda0f3b6b31020 Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Tue, 13 May 2025 01:06:29 +0000
Subject: [PATCH 08/55] clean up

---
 .../XeGPU/Transforms/XeGPUInstructionlize.cpp | 143 ++++++++++++++++++
 1 file changed, 143 insertions(+)
 create mode 100644 mlir/lib/Dialect/XeGPU/Transforms/XeGPUInstructionlize.cpp

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUInstructionlize.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUInstructionlize.cpp
new file mode 100644
index 0000000000000..b83ce86a357f0
--- /dev/null
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUInstructionlize.cpp
@@ -0,0 +1,143 @@
+//===---- XeGPUInstructionlize.cpp -- XeGPU Instructionlize Pass ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
+
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
+#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
+#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+namespace xegpu {
+#define GEN_PASS_DEF_XEGPUINSTRUCTIONLIZE
+#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
+} // namespace xegpu
+} // namespace mlir
+
+#define DEBUG_TYPE "xegpu-instructionlize"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+
+using namespace mlir;
+
+namespace {
+
+/// Unroll XeGPU ops to their instruction-level representation.
+class XeGPUInstructionlizePass final
+    : public xegpu::impl::XeGPUInstructionlizeBase<XeGPUInstructionlizePass> {
+public:
+  void runOnOperation() override;
+
+private:
+  SmallVector<int64_t> getTileShape(TypedValue<ShapedType> value) const;
+  std::optional<SmallVector<int64_t>> getTileShape(Operation *op) const;
+  bool needsUnroll(Operation *op) const;
+};
+} // namespace
+
+SmallVector<int64_t>
+XeGPUInstructionlizePass::getTileShape(TypedValue<ShapedType> value) const {
+  assert(value && "value must be non-null");
+  xegpu::LayoutAttr layout = xegpu::getLayoutAttr(value);
+  if (layout && layout.isSgLayout()) {
+    if (auto inst_data = layout.getInstData())
+      return llvm::to_vector_of<int64_t>(inst_data.asArrayRef());
+  }
+  return llvm::to_vector(value.getType().getShape());
+}
+
+std::optional<SmallVector<int64_t>>
+XeGPUInstructionlizePass::getTileShape(Operation *op) const {
+  if (isa<xegpu::CreateNdDescOp, xegpu::UpdateNdOffsetOp>(op))
+    return getTileShape(cast<TypedValue<ShapedType>>(op->getResult(0)));
+  if (isa<xegpu::PrefetchNdOp, xegpu::LoadNdOp>(op))
+    return getTileShape(cast<TypedValue<ShapedType>>(op->getOperand(0)));
+  if (isa<xegpu::StoreNdOp>(op))
+    return getTileShape(cast<TypedValue<ShapedType>>(op->getOperand(1)));
+
+  if (isa<xegpu::DpasOp>(op)) {
+    auto a = cast<TypedValue<ShapedType>>(op->getOperand(0));
+    auto b = cast<TypedValue<ShapedType>>(op->getOperand(1));
+    SmallVector<int64_t> aTileShape = getTileShape(a);
+    SmallVector<int64_t> bTileShape = getTileShape(b);
+
+    if (aTileShape.size() != 2 || bTileShape.size() != 2)
+      return std::nullopt;
+
+    // semantic check for A and B
+    if (aTileShape[1] != bTileShape[0])
+      return std::nullopt;
+
+    // semantic check for C
+    if (op->getNumOperands() == 3) {
+      auto c = cast<TypedValue<ShapedType>>(op->getOperand(2));
+      SmallVector<int64_t> cTileShape = getTileShape(c);
+      int64_t expectedShape[2] = {aTileShape[0], bTileShape[1]};
+      if (!llvm::equal(cTileShape, expectedShape))
+        return std::nullopt;
+    }
+
+    return SmallVector<int64_t>({aTileShape[0], aTileShape[1], bTileShape[1]});
+  }
+  return std::nullopt;
+}
+
+bool XeGPUInstructionlizePass::needsUnroll(Operation *op) const {
+  for (Value opr : op->getOperands()) {
+    if (auto value = dyn_cast<TypedValue<ShapedType>>(opr)) {
+      auto tileShape = getTileShape(value);
+      // the tile should have the same rank as the origial type
+      if (tileShape.size() != static_cast<size_t>(value.getType().getRank()))
+        return false;
+      if (!llvm::equal(tileShape, value.getType().getShape()))
+        return true;
+    }
+  }
+  return false;
+}
+
+void XeGPUInstructionlizePass::runOnOperation() {
+  MLIRContext *ctx = &getContext();
+  xegpu::UnrollOptions options;
+  options.setFilterConstraint([&](Operation *op) -> LogicalResult {
+    return needsUnroll(op) ? success() : failure();
+  });
+
+  options.setNativeShapeFn(
+      [&](Operation *op) -> std::optional<SmallVector<int64_t>> {
+        return getTileShape(op);
+      });
+
+  options.setUnrolledTypesFn(
+      [&](ShapedType type, ArrayRef<int64_t> tileShape) -> SmallVector<Type> {
+        Type elemTy = type.getElementType();
+        Type newTy;
+
+        if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(type))
+          newTy = xegpu::TensorDescType::get(
+              ctx, tileShape, elemTy, tdescTy.getEncoding(),
+              tdescTy.getLayoutAttr().dropInstData());
+        else
+          newTy = type.clone(tileShape, elemTy);
+
+        std::optional<SmallVector<int64_t>> ratio =
+            computeShapeRatio(type.getShape(), tileShape);
+        assert(ratio &&
+               "The shape of the type must be a multiple of tileShape.");
+        return SmallVector<Type>(computeProduct(*ratio), newTy);
+      });
+
+  RewritePatternSet patterns(ctx);
+
+  populateXeGPUUnrollPatterns(patterns, options);
+  (void)applyPatternsGreedily(getOperation(), std::move(patterns));
+}

>From ab448a34294bf2333af8ed52e6d4db540706d20f Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Tue, 13 May 2025 18:45:16 +0000
Subject: [PATCH 09/55] add scf type conversion util

---
 .../mlir/Dialect/XeGPU/Utils/XeGPUUtils.h     |   5 +
 .../XeGPU/Transforms/XeGPUInstructionlize.cpp |  41 ++--
 mlir/lib/Dialect/XeGPU/Utils/CMakeLists.txt   |   1 +
 mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp   | 182 ++++++++++++++++++
 4 files changed, 215 insertions(+), 14 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
index 5c2a308887040..4bcda3e3ac95f 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
@@ -65,6 +65,11 @@ std::string getLayoutName(OpOperand &opr);
 /// Retrieves the name for the LayoutAttr associated with a given OpResult.
 std::string getLayoutName(OpResult res);
 
+/// Do type conversion for SCF structural ops, e.g., scf.for. Since VectorType
+/// cannot carry the layout attribute, they are converted into RankedTensorType
+/// first, which will convert back to VectorType in the second round.
+void doSCFStructuralTypeConversionWithTensorType(Operation *op);
+
 } // namespace xegpu
 
 } // namespace mlir
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUInstructionlize.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUInstructionlize.cpp
index b83ce86a357f0..efc44aadb14e6 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUInstructionlize.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUInstructionlize.cpp
@@ -38,21 +38,33 @@ class XeGPUInstructionlizePass final
   void runOnOperation() override;
 
 private:
-  SmallVector<int64_t> getTileShape(TypedValue<ShapedType> value) const;
+  // Get the tile shape for a given value. If the value has a layout
+  // attribute and it is an SG layout, return the inst_data as the tile shape
+  // if inst_data is available; otherwise, return the original shape of the
+  // value. If the value does not have an SG layout, return std::nullopt.
+  std::optional<SmallVector<int64_t>>
+  getTileShape(TypedValue<ShapedType> value) const;
+
+  // Get the tile shape for a given operation.
   std::optional<SmallVector<int64_t>> getTileShape(Operation *op) const;
+
+  // Determine if the operation requires unrolling. Return false if all operands
+  // and results have tile shapes identical to their original types. Otherwise,
+  // return true.
   bool needsUnroll(Operation *op) const;
 };
 } // namespace
 
-SmallVector<int64_t>
+std::optional<SmallVector<int64_t>>
 XeGPUInstructionlizePass::getTileShape(TypedValue<ShapedType> value) const {
   assert(value && "value must be non-null");
   xegpu::LayoutAttr layout = xegpu::getLayoutAttr(value);
   if (layout && layout.isSgLayout()) {
     if (auto inst_data = layout.getInstData())
       return llvm::to_vector_of<int64_t>(inst_data.asArrayRef());
+    return llvm::to_vector(value.getType().getShape());
   }
-  return llvm::to_vector(value.getType().getShape());
+  return std::nullopt;
 }
 
 std::optional<SmallVector<int64_t>>
@@ -67,26 +79,26 @@ XeGPUInstructionlizePass::getTileShape(Operation *op) const {
   if (isa<xegpu::DpasOp>(op)) {
     auto a = cast<TypedValue<ShapedType>>(op->getOperand(0));
     auto b = cast<TypedValue<ShapedType>>(op->getOperand(1));
-    SmallVector<int64_t> aTileShape = getTileShape(a);
-    SmallVector<int64_t> bTileShape = getTileShape(b);
+    std::optional<SmallVector<int64_t>> aTile = getTileShape(a);
+    std::optional<SmallVector<int64_t>> bTile = getTileShape(b);
 
-    if (aTileShape.size() != 2 || bTileShape.size() != 2)
+    if (!aTile || aTile->size() != 2 || !bTile || bTile->size() != 2)
       return std::nullopt;
 
     // semantic check for A and B
-    if (aTileShape[1] != bTileShape[0])
+    if ((*aTile)[1] != (*bTile)[0])
       return std::nullopt;
 
     // semantic check for C
     if (op->getNumOperands() == 3) {
       auto c = cast<TypedValue<ShapedType>>(op->getOperand(2));
-      SmallVector<int64_t> cTileShape = getTileShape(c);
-      int64_t expectedShape[2] = {aTileShape[0], bTileShape[1]};
-      if (!llvm::equal(cTileShape, expectedShape))
+      std::optional<SmallVector<int64_t>> cTile = getTileShape(c);
+      int64_t expectedCTile[2] = {(*aTile)[0], (*bTile)[1]};
+      if (!cTile || !llvm::equal(*cTile, expectedCTile))
         return std::nullopt;
     }
 
-    return SmallVector<int64_t>({aTileShape[0], aTileShape[1], bTileShape[1]});
+    return SmallVector<int64_t>({(*aTile)[0], (*aTile)[1], (*bTile)[1]});
   }
   return std::nullopt;
 }
@@ -94,11 +106,12 @@ XeGPUInstructionlizePass::getTileShape(Operation *op) const {
 bool XeGPUInstructionlizePass::needsUnroll(Operation *op) const {
   for (Value opr : op->getOperands()) {
     if (auto value = dyn_cast<TypedValue<ShapedType>>(opr)) {
-      auto tileShape = getTileShape(value);
+      std::optional<SmallVector<int64_t>> tileShape = getTileShape(value);
       // the tile should have the same rank as the origial type
-      if (tileShape.size() != static_cast<size_t>(value.getType().getRank()))
+      if (!tileShape ||
+          tileShape->size() != static_cast<size_t>(value.getType().getRank()))
         return false;
-      if (!llvm::equal(tileShape, value.getType().getShape()))
+      if (!llvm::equal(*tileShape, value.getType().getShape()))
         return true;
     }
   }
diff --git a/mlir/lib/Dialect/XeGPU/Utils/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/Utils/CMakeLists.txt
index afd8e2d5c4df3..98e84a4420722 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/CMakeLists.txt
+++ b/mlir/lib/Dialect/XeGPU/Utils/CMakeLists.txt
@@ -6,5 +6,6 @@ add_mlir_dialect_library(MLIRXeGPUUtils
 
   LINK_LIBS PUBLIC
   MLIRIR
+  MLIRSCFTransforms
   MLIRXeGPUDialect
   )
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index 285a15062e402..e43aac4ce8dc0 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -11,9 +11,12 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
+#include "mlir/Dialect/SCF/Transforms/Patterns.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/Interfaces/LoopLikeInterface.h"
+#include "mlir/Transforms/DialectConversion.h"
 #include "llvm/Support/FormatVariadic.h"
 #include <cstdint>
 #include <numeric>
@@ -127,3 +130,182 @@ std::string xegpu::getLayoutName(OpResult res) {
   const StringRef prefix = "layout_result_";
   return llvm::formatv("{0}{1}", prefix, res.getResultNumber()).str();
 }
+
+void xegpu::doSCFStructuralTypeConversionWithTensorType(Operation *op) {
+  MLIRContext *context = op->getContext();
+
+  auto materializeCast = [&](OpBuilder &builder, Type type, ValueRange inputs,
+                             Location loc) -> Value {
+    return builder.create<UnrealizedConversionCastOp>(loc, type, inputs)
+        .getResult(0);
+  };
+
+  { // convert VectorType to RankedTensorType for SCF Structural ops
+    TypeConverter converter;
+    converter.addConversion([&](Type type) -> Type { return type; });
+    converter.addConversion([&](VectorType type) -> Type {
+      return RankedTensorType::get(type.getShape(), type.getElementType());
+    });
+    converter.addSourceMaterialization(materializeCast);
+    converter.addTargetMaterialization(materializeCast);
+
+    mlir::ConversionTarget target(*context);
+    target.addLegalOp<UnrealizedConversionCastOp>();
+
+    mlir::RewritePatternSet patterns(context);
+    scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
+                                                         target);
+    (void)mlir::applyPartialConversion(op, target, std::move(patterns));
+  }
+
+  { // propagate the layout attribute to RankedTensorType by checking
+    // BuiltInUnrealizedCastOps
+    // for VectorType to RankedTensorType cast.
+    op->walk([&](UnrealizedConversionCastOp castOp) {
+      if (castOp.getNumOperands() != 1 || castOp.getNumResults() != 1)
+        return WalkResult::skip();
+
+      Value input = castOp.getInputs()[0];
+      Value result = castOp.getResults()[0];
+      auto inputTy = dyn_cast<VectorType>(input.getType());
+      auto resultTy = dyn_cast<RankedTensorType>(result.getType());
+
+      // Only look at ops casting from VectorType to RankedTensorType
+      if (!isa<VectorType>(inputTy) || !isa<RankedTensorType>(resultTy))
+        return WalkResult::skip();
+
+      xegpu::LayoutAttr layout = xegpu::getLayoutAttr(input);
+      if (!layout)
+        return WalkResult::skip();
+
+      RankedTensorType newTy = resultTy.cloneWithEncoding(layout);
+      result.setType(newTy);
+
+      // update the arguments if user is a LoopLike op.
+      for (OpOperand &use : result.getUses()) {
+        if (auto loop = dyn_cast<LoopLikeOpInterface>(use.getOwner())) {
+          BlockArgument arg = loop.getTiedLoopRegionIterArg(&use);
+          arg.setType(newTy);
+        }
+        // whileOp has two regions, the BlockArgument of the after region
+        // is not exposed by LoopLikeOpInterface
+        if (auto whileOp = dyn_cast<scf::WhileOp>(use.getOwner())) {
+          unsigned idx = use.getOperandNumber();
+          BlockArgument arg = whileOp.getAfterArguments()[idx];
+          arg.setType(newTy);
+        }
+      }
+      return WalkResult::advance();
+    });
+
+    // using yieldOp as anchor to update the result type of its ParentOp
+    op->walk([&](scf::YieldOp yieldOp) {
+      Operation *parentOp = yieldOp->getParentOp();
+      for (OpResult r : parentOp->getOpResults()) {
+        unsigned idx = r.getResultNumber();
+        Type resultTy = r.getType();
+        Type yieldTy = yieldOp.getResults()[idx].getType();
+        if (isa<RankedTensorType>(resultTy) && yieldTy != resultTy)
+          r.setType(yieldTy);
+      }
+    });
+  }
+
+  { // perform the conversion from RankedTensorType to VectorType based on the
+    // LayoutAttr
+
+    auto computeTileShapeAndCount = [&](ArrayRef<int64_t> shape,
+                                        DenseI32ArrayAttr sgDataAttr,
+                                        DenseI32ArrayAttr sgLayoutAttr) {
+      SmallVector<int64_t> tileShape;
+      auto sgLayout = llvm::to_vector_of<int64_t>(sgLayoutAttr.asArrayRef());
+      if (sgDataAttr)
+        tileShape = llvm::to_vector_of<int64_t>(sgDataAttr.asArrayRef());
+      else
+        tileShape = computeShapeRatio(shape, sgLayout).value_or(tileShape);
+      assert(tileShape.size() && "failed to compute tileShape");
+      SmallVector<int64_t> distUnit =
+          computeElementwiseMul(sgLayout, tileShape);
+      int count = computeProduct(shape) / computeProduct(distUnit);
+      return std::make_pair(tileShape, count);
+    };
+
+    TypeConverter converter;
+    converter.addConversion([&](Type type) -> Type { return type; });
+    converter.addConversion(
+        [&](RankedTensorType type,
+            SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
+          ArrayRef<int64_t> shape = type.getShape();
+          auto encoding = type.getEncoding();
+          Type elemTy = type.getElementType();
+
+          // init count and subShape to the default value. If the LayoutAttr
+          // is not present, it will return a VectorType with original shape.
+          int count = 1;
+          SmallVector<int64_t> subShape(shape);
+
+          if (auto layout =
+                  llvm::dyn_cast_if_present<xegpu::LayoutAttr>(encoding)) {
+            if (layout.isWgLayout()) {
+              // for WgToSg, the subShape is either from sgData or computed as
+              // shape/sgLayout
+              std::tie(subShape, count) = computeTileShapeAndCount(
+                  shape, layout.getSgData(), layout.getSgLayout());
+            } else if (DenseI32ArrayAttr instData = layout.getInstData()) {
+              // for unrolling, the subShape is determined by inst_data
+              subShape = llvm::to_vector_of<int64_t>(instData.asArrayRef());
+              count = computeProduct(shape) / computeProduct(subShape);
+            }
+          }
+          auto newTy = VectorType::get(subShape, elemTy);
+          result.append(count, newTy);
+          return success();
+        });
+
+    converter.addConversion(
+        [&](xegpu::TensorDescType type,
+            SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
+          MLIRContext *ctx = type.getContext();
+          Type elemTy = type.getElementType();
+          Attribute encoding = type.getEncoding();
+          ArrayRef<int64_t> shape = type.getShape();
+
+          // init count and newTy to the default value. If the layout attribute
+          // is not present, it will return the original type.
+          int count = 1;
+          Type newTy = type;
+
+          if (xegpu::LayoutAttr layout = type.getLayoutAttr()) {
+            SmallVector<int64_t> subShape, distUnit;
+            if (layout.isWgLayout()) {
+              // for WgToSg, the subShape is either from sgData or computed as
+              // shape/sgLayout
+              std::tie(subShape, count) = computeTileShapeAndCount(
+                  shape, layout.getSgData(), layout.getSgLayout());
+              layout = layout.dropSgLayoutAndData();
+            } else if (DenseI32ArrayAttr instData = layout.getInstData()) {
+              // for unrolling, the subShape is determined by inst_data
+              subShape = llvm::to_vector_of<int64_t>(instData.asArrayRef());
+              count = computeProduct(shape) / computeProduct(subShape);
+              layout = layout.dropInstData();
+            }
+            newTy = xegpu::TensorDescType::get(ctx, subShape, elemTy, encoding,
+                                               layout);
+          }
+
+          result.append(count, newTy);
+          return success();
+        });
+
+    converter.addSourceMaterialization(materializeCast);
+    converter.addTargetMaterialization(materializeCast);
+
+    mlir::ConversionTarget target(*context);
+    target.addLegalOp<UnrealizedConversionCastOp>();
+
+    mlir::RewritePatternSet patterns(context);
+    scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
+                                                         target);
+    (void)mlir::applyPartialConversion(op, target, std::move(patterns));
+  }
+}

>From 7b5e8f1193006591062592f5e8858c33113448fe Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Tue, 13 May 2025 20:02:45 +0000
Subject: [PATCH 10/55] partial working

---
 .../XeGPU/Transforms/XeGPUInstructionlize.cpp | 16 +++++++++++-----
 mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp   | 19 ++++++++++---------
 2 files changed, 21 insertions(+), 14 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUInstructionlize.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUInstructionlize.cpp
index efc44aadb14e6..737600fe909fa 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUInstructionlize.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUInstructionlize.cpp
@@ -120,18 +120,22 @@ bool XeGPUInstructionlizePass::needsUnroll(Operation *op) const {
 
 void XeGPUInstructionlizePass::runOnOperation() {
   MLIRContext *ctx = &getContext();
+  Operation *op = getOperation();
+
+  // first perform type conversion for SCF control folow ops
+  xegpu::doSCFStructuralTypeConversionWithTensorType(op);
+
   xegpu::UnrollOptions options;
   options.setFilterConstraint([&](Operation *op) -> LogicalResult {
     return needsUnroll(op) ? success() : failure();
   });
 
-  options.setNativeShapeFn(
-      [&](Operation *op) -> std::optional<SmallVector<int64_t>> {
+  options.setNativeShapeFn([&](Operation *op) {
         return getTileShape(op);
       });
 
   options.setUnrolledTypesFn(
-      [&](ShapedType type, ArrayRef<int64_t> tileShape) -> SmallVector<Type> {
+      [&](ShapedType type, ArrayRef<int64_t> tileShape) {
         Type elemTy = type.getElementType();
         Type newTy;
 
@@ -149,8 +153,10 @@ void XeGPUInstructionlizePass::runOnOperation() {
         return SmallVector<Type>(computeProduct(*ratio), newTy);
       });
 
-  RewritePatternSet patterns(ctx);
+  GreedyRewriteConfig config;
+  config.setStrictness(GreedyRewriteStrictness::ExistingOps);
 
+  RewritePatternSet patterns(ctx);
   populateXeGPUUnrollPatterns(patterns, options);
-  (void)applyPatternsGreedily(getOperation(), std::move(patterns));
+  (void)applyPatternsGreedily(getOperation(), std::move(patterns), config);
 }
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index e43aac4ce8dc0..cb2c4d40f8a6d 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -215,8 +215,8 @@ void xegpu::doSCFStructuralTypeConversionWithTensorType(Operation *op) {
     // LayoutAttr
 
     auto computeTileShapeAndCount = [&](ArrayRef<int64_t> shape,
-                                        DenseI32ArrayAttr sgDataAttr,
-                                        DenseI32ArrayAttr sgLayoutAttr) {
+                                          DenseI32ArrayAttr sgDataAttr,
+                                          DenseI32ArrayAttr sgLayoutAttr) {
       SmallVector<int64_t> tileShape;
       auto sgLayout = llvm::to_vector_of<int64_t>(sgLayoutAttr.asArrayRef());
       if (sgDataAttr)
@@ -224,8 +224,7 @@ void xegpu::doSCFStructuralTypeConversionWithTensorType(Operation *op) {
       else
         tileShape = computeShapeRatio(shape, sgLayout).value_or(tileShape);
       assert(tileShape.size() && "failed to compute tileShape");
-      SmallVector<int64_t> distUnit =
-          computeElementwiseMul(sgLayout, tileShape);
+      SmallVector<int64_t> distUnit = computeElementwiseMul(sgLayout, tileShape);
       int count = computeProduct(shape) / computeProduct(distUnit);
       return std::make_pair(tileShape, count);
     };
@@ -249,8 +248,7 @@ void xegpu::doSCFStructuralTypeConversionWithTensorType(Operation *op) {
             if (layout.isWgLayout()) {
               // for WgToSg, the subShape is either from sgData or computed as
               // shape/sgLayout
-              std::tie(subShape, count) = computeTileShapeAndCount(
-                  shape, layout.getSgData(), layout.getSgLayout());
+              std::tie(subShape, count) = computeTileShapeAndCount(shape, layout.getSgData(), layout.getSgLayout());
             } else if (DenseI32ArrayAttr instData = layout.getInstData()) {
               // for unrolling, the subShape is determined by inst_data
               subShape = llvm::to_vector_of<int64_t>(instData.asArrayRef());
@@ -280,8 +278,7 @@ void xegpu::doSCFStructuralTypeConversionWithTensorType(Operation *op) {
             if (layout.isWgLayout()) {
               // for WgToSg, the subShape is either from sgData or computed as
               // shape/sgLayout
-              std::tie(subShape, count) = computeTileShapeAndCount(
-                  shape, layout.getSgData(), layout.getSgLayout());
+              std::tie(subShape, count) = computeTileShapeAndCount(shape, layout.getSgData(), layout.getSgLayout());
               layout = layout.dropSgLayoutAndData();
             } else if (DenseI32ArrayAttr instData = layout.getInstData()) {
               // for unrolling, the subShape is determined by inst_data
@@ -298,7 +295,11 @@ void xegpu::doSCFStructuralTypeConversionWithTensorType(Operation *op) {
         });
 
     converter.addSourceMaterialization(materializeCast);
-    converter.addTargetMaterialization(materializeCast);
+    converter.addTargetMaterialization([&](OpBuilder &builder, TypeRange type,
+                                        ValueRange inputs, Location loc) {
+      return builder.create<UnrealizedConversionCastOp>(loc, type, inputs)
+          .getResults();
+    });
 
     mlir::ConversionTarget target(*context);
     target.addLegalOp<UnrealizedConversionCastOp>();

>From 7f4e202ef2dbca83f19fe69eb486b315bf2d1853 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Tue, 13 May 2025 20:59:54 +0000
Subject: [PATCH 11/55] Fix CI

---
 mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSg.cpp | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSg.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSg.cpp
index 512cdca251c42..68a5f7faa2fbd 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSg.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSg.cpp
@@ -151,9 +151,9 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
       return failure();
     Type elemTy = tdescTy.getElementType();
     ArrayRef<int64_t> wgShape = tdescTy.getShape();
-    ArrayRef<int64_t> sgShape =
+    SmallVector<int64_t> sgShape =
         llvm::to_vector_of<int64_t>(layout.getSgData().asArrayRef());
-    ArrayRef<int64_t> sgLayout =
+    SmallVector<int64_t> sgLayout =
         llvm::to_vector_of<int64_t>(layout.getSgLayout().asArrayRef());
 
     // TODO : Handle order attribute
@@ -188,7 +188,7 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
         xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(),
                                    layout.dropSgLayoutAndData());
     SmallVector<Value> newCreateNdOps;
-    for (const SmallVector<int64_t> &distUnitBaseAddr :
+    for (SmallVector<int64_t> distUnitBaseAddr :
          StaticTileOffsetRange(wgShape, distUnitShape)) {
       SmallVector<OpFoldResult> globalOffsets = calculateGlobalOffsets(
           rewriter, loc, originalOffsets, localOffset, distUnitBaseAddr);

>From 2153a8a281726cf31b60f973907842e790bddc64 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Wed, 14 May 2025 18:26:48 +0000
Subject: [PATCH 12/55] Address feedback

---
 .../Dialect/XeGPU/Transforms/XeGPUWgToSg.cpp  | 21 +++--
 .../test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir | 66 +++++++++-------
 mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir   | 76 ++++++++++++-------
 3 files changed, 94 insertions(+), 69 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSg.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSg.cpp
index 68a5f7faa2fbd..f8478289d3b91 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSg.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSg.cpp
@@ -1,4 +1,4 @@
-//===- XeGPUWgToSg.cpp - XeGPU WorkGroup to Subgroup Pass -------===//
+//===- XeGPUWgToSg.cpp - XeGPU Workgroup to Subgroup Pass -----------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -25,15 +25,10 @@ namespace xegpu {
 } // namespace xegpu
 } // namespace mlir
 
-#define DEBUG_TYPE "xegpu-wg-to-sg"
-#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
-#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
-
 using namespace mlir;
 
 namespace {
 
-// clang-format off
 /// This pattern transforms the CreateNdDescOp to create a subgroup descriptor
 /// from a workgroup descriptor. It replaces the offsets and sizes with
 /// appropriate values for the subgroup.
@@ -42,11 +37,14 @@ namespace {
 ///    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x24xf32>
 ///       -> !xegpu.tensor_desc<24x24xf32, #xegpu.layout<sg_layout = [4, 4],
 ///           sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
-/// is converted to 9 subgroup level operations based on the sg_layout & sg_data:
+/// is converted to 9 subgroup level operations based on the sg_layout &
+/// sg_data:
 ///    %tdesc = xegpu.create_nd_tdesc %src[off1, off2] : memref<24x24xf32> ->
-///           !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
+///           !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2],
+///           lane_data = [1, 1]>>
 ///
-/// The sg_layout and sg_data attributes are dropped after the pass as they are no longer needed.
+/// The sg_layout and sg_data attributes are dropped after the pass as they are
+/// no longer needed.
 ///
 /// 24x24 matrix distribution example:
 /// sg_layout = [4, 4], sg_data = [2, 2]
@@ -69,9 +67,8 @@ namespace {
 /// | 2x2 2x2 2x2 2x2 |
 /// +------------------------+
 ///
-/// Since the 24x24 matrix is divided into 8x8 distribution units, there will be 9
-/// distribution units (3x3) in total. Hence the 9 subgroup level operations.
-// clang-format on
+/// Since the 24x24 matrix is divided into 8x8 distribution units, there will be
+/// 9 distribution units (3x3) in total. Hence the 9 subgroup level operations.
 struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
   using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
 
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
index de2c548ec7ebb..3096759e3ac8c 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
@@ -1,62 +1,60 @@
 // RUN: mlir-opt --xegpu-wg-to-sg -split-input-file %s | FileCheck %s
 
 gpu.module @test_round_robin_assignment {
-  // CHECK: test_create_nd_tdesc
-  // CHECK: %[[ARG_0:.*]]: memref<24x32xf32>
+  // CHECK-LABEL: test_create_nd_tdesc
+  // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
   gpu.func @test_create_nd_tdesc(%src: memref<24x32xf32>) {
-      // CHECK-COUNT-12: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][%{{.*}}, %{{.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
+      // CHECK-COUNT-12: xegpu.create_nd_tdesc %[[ARG_0]][%{{.*}}, %{{.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
       %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
       gpu.return
     }
 
-  // CHECK: test_load_nd_tdesc
-  // CHECK: %[[ARG_0:.*]]: memref<24x32xf32>
+  // CHECK-LABEL: test_load_nd_tdesc
+  // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
   gpu.func @test_load_nd_tdesc(%src: memref<24x32xf32>) {
       %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
-      // CHECK-COUNT-12: %[[LOAD:.*]] = xegpu.load_nd %{{.*}} : !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>> -> vector<2x2xf32>
+      // CHECK-COUNT-12: xegpu.load_nd %{{.*}} : !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>> -> vector<2x2xf32>
       %load =  xegpu.load_nd %tdesc: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>> -> vector<24x32xf32>
       gpu.return
     }
 
-  // CHECK: test_store_nd
-  // CHECK: %[[ARG_0:.*]]: memref<24x32xf32>
+  // CHECK-LABEL: test_store_nd
+  // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
   gpu.func @test_store_nd(%src: memref<24x32xf32>) {
       %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
       // CHECK-COUNT-12: xegpu.store_nd %{{.*}}, %{{.*}} : vector<2x2xf32>, !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
       %load = xegpu.load_nd %tdesc: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>> -> vector<24x32xf32>
       xegpu.store_nd %load, %tdesc: vector<24x32xf32>, !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
       gpu.return
   }
 
-  // CHECK: test_update_nd
-  // CHECK: %[[ARG_0:.*]]: memref<24x32xf32>
+  // CHECK-LABEL: test_update_nd
+  // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
   gpu.func @test_update_nd(%src: memref<24x32xf32>){
     %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->  !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
-    // CHECK-COUNT-12: %[[UPDATE:.*]] = xegpu.update_nd_offset %{{.*}}, [0, 16] : !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
+    // CHECK-COUNT-12: xegpu.update_nd_offset %{{.*}}, [0, 16] : !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
     %update = xegpu.update_nd_offset %tdesc, [0, 16] :  !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
     gpu.return
   }
 
-  // CHECK: test_dpas
-  // CHECK: %[[ARG_0:.*]]: memref<24x32xf32>
-  // CHECK: %[[ARG_1:.*]]: memref<32x24xf32>
-  // CHECK: %[[ARG_2:.*]]: memref<24x24xf32>
-  gpu.func @test_dpas(%a: memref<24x32xf32>, %b: memref<32x24xf32>, %c: memref<24x24xf32>) {
-    // CHECK-COUNT-12: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][%{{.*}}, %{{.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
-    // CHECK-COUNT-12: %[[TDESC1:.*]] = xegpu.create_nd_tdesc %[[ARG_1]][%{{.*}}, %{{.*}}] : memref<32x24xf32> -> !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
-    // CHECK-COUNT-9: %[[TDESC2:.*]] = xegpu.create_nd_tdesc %{{.*}}[%{{.*}}, %{{.*}}] : memref<24x24xf32> -> !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
-    // CHECK-COUNT-144: %[[DPAS:.*]] = xegpu.dpas %{{.*}}, %{{.*}} {layout = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>, vector<2x2xf32> -> vector<2x2xf32>
-    %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
-    %load_a =  xegpu.load_nd %tdesc_a: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>> -> vector<24x32xf32>
-    %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<32x24xf32> -> !xegpu.tensor_desc<32x24xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
-    %load_b =  xegpu.load_nd %tdesc_b:  !xegpu.tensor_desc<32x24xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>> -> vector<32x24xf32>
-    %tdesc_c = xegpu.create_nd_tdesc %c[0, 0] : memref<24x24xf32> -> !xegpu.tensor_desc<24x24xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
-    %dpas = xegpu.dpas %load_a, %load_b {layout =  #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>} : vector<24x32xf32>, vector<32x24xf32> -> vector<24x24xf32>
+  // CHECK-LABEL: test_dpas
+  // CHECK-SAME: (%[[ARG_0:.*]]: memref<8x8xf32>, %[[ARG_1:.*]]: memref<8x8xf32>, %[[ARG_2:.*]]: memref<8x8xf32>)
+  gpu.func @test_dpas(%a: memref<8x8xf32>, %b: memref<8x8xf32>, %c: memref<8x8xf32>) {
+    // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_0]][%{{.*}}, %{{.*}}] : memref<8x8xf32> -> !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
+    // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_1]][%{{.*}}, %{{.*}}] : memref<8x8xf32> -> !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
+    // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_2]][%{{.*}}, %{{.*}}] : memref<8x8xf32> -> !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
+    // CHECK-COUNT-16: xegpu.dpas %{{.*}}, %{{.*}} {layout = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>, vector<2x2xf32> -> vector<2x2xf32>
+    %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<8x8xf32> -> !xegpu.tensor_desc<8x8xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+    %load_a =  xegpu.load_nd %tdesc_a: !xegpu.tensor_desc<8x8xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>> -> vector<8x8xf32>
+    %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<8x8xf32> -> !xegpu.tensor_desc<8x8xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+    %load_b =  xegpu.load_nd %tdesc_b:  !xegpu.tensor_desc<8x8xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>> -> vector<8x8xf32>
+    %tdesc_c = xegpu.create_nd_tdesc %c[0, 0] : memref<8x8xf32> -> !xegpu.tensor_desc<8x8xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+    %dpas = xegpu.dpas %load_a, %load_b {layout =  #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>} : vector<8x8xf32>, vector<8x8xf32> -> vector<8x8xf32>
     gpu.return
   }
 
-  // CHECK: test_prefetch_nd_tdesc
-  // CHECK: %[[ARG_0:.*]]: memref<24x32xf32>
+  // CHECK-LABEL: test_prefetch_nd_tdesc
+  // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
   gpu.func @test_prefetch_nd_tdesc(%src: memref<24x32xf32>) {
     // CHECK-COUNT-12: xegpu.prefetch_nd %{{.*}} : !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
     %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
index 1cae2c822d826..fdc10289b44f0 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
@@ -1,8 +1,8 @@
 // RUN: mlir-opt --xegpu-wg-to-sg -split-input-file %s | FileCheck %s
 
 gpu.module @test_1_1_assignment {
-  // CHECK: test_create_nd_tdesc
-  // CHECK: %[[ARG_0:.*]]: memref<24x32xf32>
+  // CHECK-LABEL: test_create_nd_tdesc
+  // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
   gpu.func @test_create_nd_tdesc(%src: memref<24x32xf32>) {  
   // CHECK: %[[SGID:.*]] = gpu.subgroup_id
   // CHECK: %[[C12:.*]] = arith.constant 12 : index
@@ -15,53 +15,71 @@ gpu.module @test_1_1_assignment {
   // CHECK: %[[C0:.*]] = arith.constant 0 : index
   // CHECK: %[[ADD1:.*]] = index.add %[[MUL1]], %[[C0]]
   // CHECK: %[[ADD2:.*]] = index.add %[[MUL2]], %[[C0]]
-  // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][%[[ADD1]], %[[ADD2]]] : memref<24x32xf32> -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
+  // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][%[[ADD1]], %[[ADD2]]] : memref<24x32xf32>
+  // CHECK-SAME: -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
   // CHECK: gpu.return
   %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
   gpu.return
   }
 
-  // CHECK: test_load_nd_tdesc
-  // CHECK: %[[ARG_0:.*]]: memref<24x32xf32>
+  // CHECK-LABEL: test_load_nd_tdesc
+  // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
   gpu.func @test_load_nd_tdesc(%src: memref<24x32xf32>) {
-    // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
-    // CHECK: %[[LOAD:.*]] = xegpu.load_nd %[[TDESC]] : !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>> -> vector<12x8xf32>
+    // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<24x32xf32>
+    // CHECK-SAME: -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
+    // CHECK: %[[LOAD:.*]] = xegpu.load_nd %[[TDESC]]
+    // CHECK-SAME: : !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
+    // CHECK-SAME: -> vector<12x8xf32>
     %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
     %load =  xegpu.load_nd %tdesc: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>> -> vector<24x32xf32>
     gpu.return
   }
 
-  // CHECK: test_store_nd
-  // CHECK: %[[ARG_0:.*]]: memref<24x32xf32>
+  // CHECK-LABEL: test_store_nd
+  // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
   gpu.func @test_store_nd(%src: memref<24x32xf32>) {
-    // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
-    // CHECK: %[[LOAD:.*]] = xegpu.load_nd %[[TDESC]] : !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>> -> vector<12x8xf32>
-    // CHECK: xegpu.store_nd %[[LOAD]], %[[TDESC]] : vector<12x8xf32>, !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
+    // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<24x32xf32>
+    // CHECK-SAME: -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
+    // CHECK: %[[LOAD:.*]] = xegpu.load_nd %[[TDESC]]
+    // CHECK-SAME: : !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
+    // CHECK-SAME: -> vector<12x8xf32>
+    // CHECK: xegpu.store_nd %[[LOAD]], %[[TDESC]]
+    // CHECK-SAME: : vector<12x8xf32>, !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
     %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
     %load = xegpu.load_nd %tdesc: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>> -> vector<24x32xf32>
     xegpu.store_nd %load, %tdesc: vector<24x32xf32>, !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
     gpu.return
 }
 
-// CHECK: test_update_nd
-// CHECK: %[[ARG_0:.*]]: memref<24x32xf32>
+// CHECK-LABEL: test_update_nd
+// CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
 gpu.func @test_update_nd(%src: memref<24x32xf32>){
-  // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
-  // CHECK: %[[UPDATE:.*]] = xegpu.update_nd_offset %[[TDESC]], [0, 16] : !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
+  // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<24x32xf32>
+  // CHECK-SAME: -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
+  // CHECK: %[[UPDATE:.*]] = xegpu.update_nd_offset %[[TDESC]], [0, 16]
+  // CHECK-SAME: : !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
   %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
   %update = xegpu.update_nd_offset %tdesc, [0, 16] : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
   gpu.return
 }
 
-// CHECK: test_dpas
-// CHECK: %[[ARG_0:.*]]: memref<24x32xf32>
-// CHECK: %[[ARG_1:.*]]: memref<32x24xf32>
+// CHECK-LABEL: test_dpas
+// CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
+// CHECK-SAME: %[[ARG_1:.*]]: memref<32x24xf32>
 gpu.func @test_dpas(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
-    // CHECK: %[[TDESC_A:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>> 
-    // CHECK: %[[LOAD_A:.*]] = xegpu.load_nd %[[TDESC_A]] : !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>> -> vector<12x8xf32>
-    // CHECK: %[[TDESC_B:.*]] = xegpu.create_nd_tdesc %[[ARG_1]][{{%.*}}, {{%.*}}] : memref<32x24xf32> -> !xegpu.tensor_desc<8x12xf32, #xegpu.layout<lane_layout = [8, 2], lane_data = [1, 1]>> 
-    // CHECK: %[[LOAD_B:.*]] = xegpu.load_nd %[[TDESC_B]] : !xegpu.tensor_desc<8x12xf32, #xegpu.layout<lane_layout = [8, 2], lane_data = [1, 1]>> -> vector<8x12xf32>
-    // CHECK: %[[DPAS:.*]] = xegpu.dpas %[[LOAD_A]], %[[LOAD_B]] {layout =  #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<12x8xf32>, vector<8x12xf32> -> vector<12x12xf32>
+    // CHECK: %[[TDESC_A:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<24x32xf32>
+    // CHECK-SAME: -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
+    // CHECK: %[[LOAD_A:.*]] = xegpu.load_nd %[[TDESC_A]]
+    // CHECK-SAME: : !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
+    // CHECK-SAME: -> vector<12x8xf32>
+    // CHECK: %[[TDESC_B:.*]] = xegpu.create_nd_tdesc %[[ARG_1]][{{%.*}}, {{%.*}}] : memref<32x24xf32>
+    // CHECK-SAME: -> !xegpu.tensor_desc<8x12xf32, #xegpu.layout<lane_layout = [8, 2], lane_data = [1, 1]>>
+    // CHECK: %[[LOAD_B:.*]] = xegpu.load_nd %[[TDESC_B]]
+    // CHECK-SAME: : !xegpu.tensor_desc<8x12xf32, #xegpu.layout<lane_layout = [8, 2], lane_data = [1, 1]>>
+    // CHECK-SAME: -> vector<8x12xf32>
+    // CHECK: %[[DPAS:.*]] = xegpu.dpas %[[LOAD_A]], %[[LOAD_B]]
+    // CHECK-SAME: {layout =  #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>}
+    // CHECK-SAME: : vector<12x8xf32>, vector<8x12xf32> -> vector<12x12xf32>
     %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
     %load_a =  xegpu.load_nd %tdesc_a: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>> -> vector<24x32xf32>
     %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<32x24xf32> -> !xegpu.tensor_desc<32x24xf32, #xegpu.layout<sg_layout = [4, 2], sg_data = [8, 12], lane_layout = [8, 2], lane_data = [1, 1]>>
@@ -70,11 +88,13 @@ gpu.func @test_dpas(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
     gpu.return
   }
 
-  // CHECK: test_prefetch_nd_tdesc
-  // CHECK: %[[ARG_0:.*]]: memref<24x32xf32>
+  // CHECK-LABEL: test_prefetch_nd_tdesc
+  // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
   gpu.func @test_prefetch_nd_tdesc(%src: memref<24x32xf32>) {
-    // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
-    // CHECK: xegpu.prefetch_nd %[[TDESC]] : !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
+    // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<24x32xf32>
+    // CHECK-SAME: -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
+    // CHECK: xegpu.prefetch_nd %[[TDESC]]
+    // CHECK-SAME: : !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
     %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
     xegpu.prefetch_nd %tdesc: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
     gpu.return

>From e2eb9e63df30e9e84d3d09060ec493bc2b805f3d Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Thu, 15 May 2025 21:22:16 +0000
Subject: [PATCH 13/55] refactor pack and unpack

---
 .../mlir/Dialect/XeGPU/Utils/XeGPUUtils.h     |  39 ++++-
 .../XeGPU/Transforms/XeGPUInstructionlize.cpp | 163 +++++++++++++-----
 .../Dialect/XeGPU/Transforms/XeGPUUnroll.cpp  |  25 +--
 mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp   | 152 +++++++++++++++-
 4 files changed, 301 insertions(+), 78 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
index 4bcda3e3ac95f..b41da0ea6a276 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
@@ -15,6 +15,8 @@ namespace mlir {
 class VectorType;
 class OpOperand;
 class OpResult;
+class OpBuilder;
+class ValueRange;
 
 namespace xegpu {
 class LayoutAttr;
@@ -53,17 +55,46 @@ FailureOr<VectorType> getDistributedVectorType(xegpu::TensorDescType tdescTy);
 FailureOr<VectorType> getDistributedVectorType(VectorType originalType,
                                                LayoutAttr layout);
 
+/// Return the attribute name under which the OpOperand's LayoutAttr is stored.
+std::string getLayoutName(OpOperand &opr);
+
+/// Return the attribute name under which the OpResult's LayoutAttr is stored.
+std::string getLayoutName(OpResult res);
+
 /// Retrieves the LayoutAttr associated with a given Value. For TensorDescType
 /// values, the LayoutAttr is extracted from the TensorDescType itself. For
 /// other values, it is obtained from the attributes of the defining operation.
 /// Returns nullptr if no LayoutAttr is found.
 LayoutAttr getLayoutAttr(Value value);
 
-/// Retrieves the name for the LayoutAttr associated with a given OpOperand.
-std::string getLayoutName(OpOperand &opr);
+/// Retrieves the LayoutAttr associated with a given OpOperand. It will
+/// first check the operand_layout_{id} of the owner operation. If not found,
+/// it will check the operand itself and its defining op.
+LayoutAttr getLayoutAttr(OpOperand &opr);
 
-/// Retrieves the name for the LayoutAttr associated with a given OpResult.
-std::string getLayoutName(OpResult res);
+/// Sets the LayoutAttr for a given OpOperand by attaching it to the owner op.
+void setLayoutAttr(OpOperand &opr, LayoutAttr layout);
+
+/// Sets the LayoutAttr for a given OpResult by attaching it to the defining op.
+void setLayoutAttr(OpResult result, LayoutAttr layout);
+
+/// Set the LayoutAttr for each OpOperand and OpResult of the given operation.
+/// If the operation contains regions, it is also applied recursively to the
+/// contained operations.
+void setLayoutAttrs(Operation *op,
+                    function_ref<LayoutAttr(Value)> getLayoutImpl);
+
+/// Extract a set of small vectors of the given shape from a value using
+/// vector.extract_strided_slice.
+SmallVector<Value> extractVectorsWithShapeFromValue(OpBuilder &builder,
+                                                    Location loc, Value value,
+                                                    ArrayRef<int64_t> shape);
+
+/// Create a vector of the given shape from a set of values using
+/// vector.insert_strided_slice.
+Value createVectorWithShapeFromValues(OpBuilder &builder, Location loc,
+                                      ValueRange values,
+                                      ArrayRef<int64_t> shape);
 
 /// Do type conversion for SCF structural ops, e.g., scf.for. Since VectorType
 /// cannot carry the layout attribute, they are converted into RankedTensorType
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUInstructionlize.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUInstructionlize.cpp
index 737600fe909fa..0e01c7e4d9763 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUInstructionlize.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUInstructionlize.cpp
@@ -13,6 +13,7 @@
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
 #include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
 #include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
+#include "mlir/Interfaces/LoopLikeInterface.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassManager.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -45,6 +46,10 @@ class XeGPUInstructionlizePass final
   std::optional<SmallVector<int64_t>>
   getTileShape(TypedValue<ShapedType> value) const;
 
+  // Get the tile shape for an OpOperand / OpResult from its LayoutAttr.
+  std::optional<SmallVector<int64_t>> getTileShape(OpOperand &operand) const;
+  std::optional<SmallVector<int64_t>> getTileShape(OpResult result) const;
+
   // Get the tile shape for a given operation.
   std::optional<SmallVector<int64_t>> getTileShape(Operation *op) const;
 
@@ -67,20 +72,46 @@ XeGPUInstructionlizePass::getTileShape(TypedValue<ShapedType> value) const {
   return std::nullopt;
 }
 
+std::optional<SmallVector<int64_t>>
+XeGPUInstructionlizePass::getTileShape(OpOperand &operand) const {
+  xegpu::LayoutAttr layout = xegpu::getLayoutAttr(operand);
+  if (layout && layout.isSgLayout()) {
+    if (auto instData = layout.getInstData())
+      return llvm::to_vector_of<int64_t>(instData.asArrayRef());
+
+    if (auto type = dyn_cast<ShapedType>(operand.get().getType()))
+      return llvm::to_vector(type.getShape());
+  }
+  return std::nullopt;
+}
+
+std::optional<SmallVector<int64_t>>
+XeGPUInstructionlizePass::getTileShape(OpResult result) const {
+  xegpu::LayoutAttr layout = xegpu::getLayoutAttr(result);
+  if (layout && layout.isSgLayout()) {
+    if (auto instData = layout.getInstData())
+      return llvm::to_vector_of<int64_t>(instData.asArrayRef());
+
+    if (auto type = dyn_cast<ShapedType>(result.getType()))
+      return llvm::to_vector(type.getShape());
+  }
+  return std::nullopt;
+}
+
 std::optional<SmallVector<int64_t>>
 XeGPUInstructionlizePass::getTileShape(Operation *op) const {
   if (isa<xegpu::CreateNdDescOp, xegpu::UpdateNdOffsetOp>(op))
-    return getTileShape(cast<TypedValue<ShapedType>>(op->getResult(0)));
+    return getTileShape(op->getOpResult(0));
   if (isa<xegpu::PrefetchNdOp, xegpu::LoadNdOp>(op))
-    return getTileShape(cast<TypedValue<ShapedType>>(op->getOperand(0)));
+    return getTileShape(op->getOpOperand(0));
   if (isa<xegpu::StoreNdOp>(op))
-    return getTileShape(cast<TypedValue<ShapedType>>(op->getOperand(1)));
+    return getTileShape(op->getOpOperand(1));
 
   if (isa<xegpu::DpasOp>(op)) {
-    auto a = cast<TypedValue<ShapedType>>(op->getOperand(0));
-    auto b = cast<TypedValue<ShapedType>>(op->getOperand(1));
-    std::optional<SmallVector<int64_t>> aTile = getTileShape(a);
-    std::optional<SmallVector<int64_t>> bTile = getTileShape(b);
+    std::optional<SmallVector<int64_t>> aTile =
+        getTileShape(op->getOpOperand(0));
+    std::optional<SmallVector<int64_t>> bTile =
+        getTileShape(op->getOpOperand(1));
 
     if (!aTile || aTile->size() != 2 || !bTile || bTile->size() != 2)
       return std::nullopt;
@@ -91,8 +122,8 @@ XeGPUInstructionlizePass::getTileShape(Operation *op) const {
 
     // semantic check for C
     if (op->getNumOperands() == 3) {
-      auto c = cast<TypedValue<ShapedType>>(op->getOperand(2));
-      std::optional<SmallVector<int64_t>> cTile = getTileShape(c);
+      std::optional<SmallVector<int64_t>> cTile =
+          getTileShape(op->getOpOperand(2));
       int64_t expectedCTile[2] = {(*aTile)[0], (*bTile)[1]};
       if (!cTile || !llvm::equal(*cTile, expectedCTile))
         return std::nullopt;
@@ -104,59 +135,107 @@ XeGPUInstructionlizePass::getTileShape(Operation *op) const {
 }
 
 bool XeGPUInstructionlizePass::needsUnroll(Operation *op) const {
-  for (Value opr : op->getOperands()) {
-    if (auto value = dyn_cast<TypedValue<ShapedType>>(opr)) {
-      std::optional<SmallVector<int64_t>> tileShape = getTileShape(value);
-      // the tile should have the same rank as the origial type
-      if (!tileShape ||
-          tileShape->size() != static_cast<size_t>(value.getType().getRank()))
-        return false;
-      if (!llvm::equal(*tileShape, value.getType().getShape()))
-        return true;
-    }
+  // Loop-like ops (e.g. scf.for) are never unrolled themselves.
+  if (isa<LoopLikeOpInterface>(op))
+    return false;
+
+  // Unroll as soon as any shaped operand or result has a tile shape that
+  // differs from its full shape.
+  for (auto &opr : op->getOpOperands()) {
+    auto shapedType = dyn_cast<ShapedType>(opr.get().getType());
+    if (!shapedType)
+      continue;
+
+    std::optional<SmallVector<int64_t>> tileShape = getTileShape(opr);
+    if (tileShape && !llvm::equal(*tileShape, shapedType.getShape()))
+      return true;
+  }
+
+  for (auto result : op->getOpResults()) {
+    auto shapedType = dyn_cast<ShapedType>(result.getType());
+    if (!shapedType)
+      continue;
+
+    std::optional<SmallVector<int64_t>> tileShape = getTileShape(result);
+    if (tileShape && !llvm::equal(*tileShape, shapedType.getShape()))
+      return true;
   }
   return false;
 }
 
 void XeGPUInstructionlizePass::runOnOperation() {
   MLIRContext *ctx = &getContext();
-  Operation *op = getOperation();
+  Operation *mod = getOperation();
+
+  // Record the LayoutAttr of every operand (and result) as an attribute of
+  // its owner op, so it remains accessible even if the defining operation is
+  // replaced.
+  xegpu::setLayoutAttrs(mod, [&](Value v) { return xegpu::getLayoutAttr(v); });
 
-  // first perform type conversion for SCF control folow ops
-  xegpu::doSCFStructuralTypeConversionWithTensorType(op);
+  // Perform type conversion for SCF control flow ops.
+  xegpu::doSCFStructuralTypeConversionWithTensorType(mod);
 
   xegpu::UnrollOptions options;
   options.setFilterConstraint([&](Operation *op) -> LogicalResult {
     return needsUnroll(op) ? success() : failure();
   });
 
-  options.setNativeShapeFn([&](Operation *op) {
-        return getTileShape(op);
-      });
+  options.setNativeShapeFn([&](Operation *op) { return getTileShape(op); });
 
-  options.setUnrolledTypesFn(
-      [&](ShapedType type, ArrayRef<int64_t> tileShape) {
-        Type elemTy = type.getElementType();
-        Type newTy;
+  options.setUnrolledTypesFn([&](ShapedType type, ArrayRef<int64_t> tileShape) {
+    Type elemTy = type.getElementType();
+    Type newTy;
 
-        if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(type))
-          newTy = xegpu::TensorDescType::get(
-              ctx, tileShape, elemTy, tdescTy.getEncoding(),
-              tdescTy.getLayoutAttr().dropInstData());
-        else
-          newTy = type.clone(tileShape, elemTy);
+    if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(type))
+      newTy = xegpu::TensorDescType::get(
+          ctx, tileShape, elemTy, tdescTy.getEncoding(),
+          tdescTy.getLayoutAttr().dropInstData());
+    else
+      newTy = type.clone(tileShape, elemTy);
 
-        std::optional<SmallVector<int64_t>> ratio =
-            computeShapeRatio(type.getShape(), tileShape);
-        assert(ratio &&
-               "The shape of the type must be a multiple of tileShape.");
-        return SmallVector<Type>(computeProduct(*ratio), newTy);
-      });
-
-  GreedyRewriteConfig config;
-  config.setStrictness(GreedyRewriteStrictness::ExistingOps);
+    std::optional<SmallVector<int64_t>> ratio =
+        computeShapeRatio(type.getShape(), tileShape);
+    assert(ratio && "The shape of the type must be a multiple of tileShape.");
+    return SmallVector<Type>(computeProduct(*ratio), newTy);
+  });
 
   RewritePatternSet patterns(ctx);
   populateXeGPUUnrollPatterns(patterns, options);
-  (void)applyPatternsGreedily(getOperation(), std::move(patterns), config);
+  (void)applyPatternsGreedily(mod, std::move(patterns));
+
+  // Resolve the remaining UnrealizedConversionCastOps. castOp is erased in
+  // every handled case, and `inputs`/`outputs` view its operand and result
+  // lists, so nothing may touch them once castOp->erase() has run.
+  mod->walk([&](UnrealizedConversionCastOp castOp) {
+    ValueRange inputs = castOp.getInputs();
+    ValueRange outputs = castOp.getOutputs();
+    if (inputs.empty() || outputs.empty())
+      return;
+
+    if (inputs.size() == 1 && outputs.size() == 1) {
+      castOp->replaceAllUsesWith(inputs);
+      castOp->erase();
+      return;
+    }
+
+    VectorType inputTy = dyn_cast<VectorType>(inputs[0].getType());
+    VectorType outputTy = dyn_cast<VectorType>(outputs[0].getType());
+    if (!inputTy || !outputTy)
+      return;
+
+    OpBuilder builder(castOp);
+    if (inputs.size() > 1 && outputs.size() == 1) {
+      // unpack: combine the tiles into one vector of the output shape.
+      Value result = xegpu::createVectorWithShapeFromValues(
+          builder, castOp.getLoc(), inputs, outputTy.getShape());
+      castOp->replaceAllUsesWith(ValueRange(result));
+      castOp->erase();
+    } else if (inputs.size() == 1 && outputs.size() > 1) {
+      // pack: split the single input into tiles of the output shape.
+      SmallVector<Value> results = xegpu::extractVectorsWithShapeFromValue(
+          builder, castOp.getLoc(), inputs[0], outputTy.getShape());
+      castOp->replaceAllUsesWith(results);
+      castOp->erase();
+    }
+  });
 }
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index 44d45dd2eaec0..d9f69158f95eb 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
 #include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
+#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/Support/Debug.h"
@@ -74,17 +75,7 @@ struct UnrollPattern : public OpRewritePattern<SourceOp> {
       assert(vecTy.getRank() == static_cast<int64_t>(blockSize.size()) &&
              "Expecting blockSize size to match the rank of destTy.");
       auto shape = vecTy.getShape();
-      auto zeroAttr = rewriter.getZeroAttr(vecTy.getElementType());
-
-      Value result = rewriter.create<arith::ConstantOp>(
-          loc, vecTy, DenseElementsAttr::get(vecTy, zeroAttr));
-      for (auto [src, offsets] :
-           llvm::zip_equal(srcs, StaticTileOffsetRange(shape, blockSize))) {
-        SmallVector<int64_t> staticStrides(offsets.size(), 1);
-        result = rewriter.create<vector::InsertStridedSliceOp>(
-            loc, src, result, offsets, staticStrides);
-      }
-      return result;
+      return xegpu::createVectorWithShapeFromValues(rewriter, loc, srcs, shape);
     }
 
     if (isa<xegpu::TensorDescType>(destTy)) {
@@ -109,16 +100,8 @@ struct UnrollPattern : public OpRewritePattern<SourceOp> {
     if (auto vecTy = dyn_cast<VectorType>(src.getType())) {
       assert(vecTy.getRank() == static_cast<int64_t>(blockSize.size()) &&
              "Expecting blockSize size to match the rank of src.");
-      auto shape = vecTy.getShape();
-      SmallVector<Value> results;
-      for (SmallVector<int64_t> offsets :
-           StaticTileOffsetRange(shape, blockSize)) {
-        SmallVector<int64_t> staticStrides(offsets.size(), 1);
-        auto slice = rewriter.create<vector::ExtractStridedSliceOp>(
-            loc, src, offsets, blockSize, staticStrides);
-        results.push_back(slice);
-      }
-      return results;
+      return xegpu::extractVectorsWithShapeFromValue(rewriter, loc, src,
+                                                     blockSize);
     }
 
     if (isa<xegpu::TensorDescType>(src.getType())) {
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index cb2c4d40f8a6d..60c8493f552d8 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -14,15 +14,26 @@
 #include "mlir/Dialect/SCF/Transforms/Patterns.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/IR/Builders.h"
 #include "mlir/IR/Operation.h"
+#include "mlir/IR/ValueRange.h"
 #include "mlir/Interfaces/LoopLikeInterface.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "llvm/Support/Debug.h"
 #include "llvm/Support/FormatVariadic.h"
 #include <cstdint>
 #include <numeric>
 
 using namespace mlir;
 
+/// Concatenates every range in `values`, in order, into one flat list.
+static SmallVector<Value> flattenValues(ArrayRef<ValueRange> values) {
+  SmallVector<Value> flat;
+  for (ValueRange range : values)
+    flat.append(range.begin(), range.end());
+  return flat;
+}
+
 FailureOr<VectorType>
 mlir::xegpu::getDistributedVectorType(xegpu::TensorDescType tdescTy) {
   auto layout = llvm::dyn_cast_if_present<LayoutAttr>(tdescTy.getLayout());
@@ -90,6 +101,16 @@ mlir::xegpu::getDistributedVectorType(VectorType originalType,
   return xegpu::getDistributedVectorType(helperTdescTy);
 }
 
+std::string xegpu::getLayoutName(OpOperand &opr) {
+  unsigned idx = opr.getOperandNumber();
+  return llvm::formatv("{0}{1}", StringRef("layout_operand_"), idx).str();
+}
+
+std::string xegpu::getLayoutName(OpResult res) {
+  unsigned idx = res.getResultNumber();
+  return llvm::formatv("{0}{1}", StringRef("layout_result_"), idx).str();
+}
+
 xegpu::LayoutAttr xegpu::getLayoutAttr(Value value) {
   if (!value)
     return nullptr;
@@ -121,14 +142,86 @@ xegpu::LayoutAttr xegpu::getLayoutAttr(Value value) {
   return nullptr;
 }
 
-std::string xegpu::getLayoutName(OpOperand &opr) {
-  const StringRef prefix("layout_operand_");
-  return llvm::formatv("{0}{1}", prefix, opr.getOperandNumber()).str();
+xegpu::LayoutAttr xegpu::getLayoutAttr(OpOperand &opr) {
+  // An attribute named after the operand wins; otherwise query the value.
+  std::string name = xegpu::getLayoutName(opr);
+  if (!opr.getOwner()->hasAttr(name))
+    return getLayoutAttr(opr.get());
+  return opr.getOwner()->getAttrOfType<xegpu::LayoutAttr>(name);
 }
 
-std::string xegpu::getLayoutName(OpResult res) {
-  const StringRef prefix = "layout_result_";
-  return llvm::formatv("{0}{1}", prefix, res.getResultNumber()).str();
+/// Attaches `layout` to `owner` under `name` unless `layout` is null or a
+/// LayoutAttr is already there; shared so both overloads follow one rule.
+static void setLayoutIfAbsent(Operation *owner, StringRef name,
+                              xegpu::LayoutAttr layout) {
+  if (layout && !owner->hasAttrOfType<xegpu::LayoutAttr>(name))
+    owner->setAttr(name, layout);
+}
+void xegpu::setLayoutAttr(OpOperand &opr, LayoutAttr layout) {
+  setLayoutIfAbsent(opr.getOwner(), xegpu::getLayoutName(opr), layout);
+}
+void xegpu::setLayoutAttr(OpResult result, LayoutAttr layout) {
+  setLayoutIfAbsent(result.getOwner(), xegpu::getLayoutName(result), layout);
+}
+
+/// Walks `mod` and, via setLayoutAttr, records the layout `getLayoutImpl`
+/// computes for every op result and every operand value it visits.
+void xegpu::setLayoutAttrs(Operation *mod,
+                           function_ref<LayoutAttr(Value)> getLayoutImpl) {
+  mod->walk([&](Operation *op) {
+    // Results first, then operands.
+    for (OpResult res : op->getOpResults())
+      setLayoutAttr(res, getLayoutImpl(res));
+
+    for (OpOperand &operand : op->getOpOperands())
+      setLayoutAttr(operand, getLayoutImpl(operand.get()));
+  });
+}
+
+SmallVector<Value>
+xegpu::extractVectorsWithShapeFromValue(OpBuilder &builder, Location loc,
+                                        Value value, ArrayRef<int64_t> shape) {
+  // Non-vector values, and vectors for which computeShapeRatio finds no ratio
+  // against `shape`, are returned unchanged as a single-element list.
+  auto vecTy = dyn_cast<VectorType>(value.getType());
+  if (!vecTy || !computeShapeRatio(vecTy.getShape(), shape))
+    return {value};
+
+  // One unit-stride extract_strided_slice per StaticTileOffsetRange offset.
+  SmallVector<Value> slices;
+  for (SmallVector<int64_t> offsets :
+       StaticTileOffsetRange(vecTy.getShape(), shape)) {
+    SmallVector<int64_t> unitStrides(offsets.size(), 1);
+    Value slice = builder.create<vector::ExtractStridedSliceOp>(
+        loc, value, offsets, shape, unitStrides);
+    slices.push_back(slice);
+  }
+  return slices;
+}
+
+Value xegpu::createVectorWithShapeFromValues(OpBuilder &builder, Location loc,
+                                             ValueRange values,
+                                             ArrayRef<int64_t> shape) {
+  // `values` are equally-typed vector tiles; each is inserted (unit strides)
+  // at its StaticTileOffsetRange offset into a zero vector of `shape`.
+  assert(!values.empty() && "expected at least one value");
+  auto inputTy = dyn_cast<VectorType>(values[0].getType());
+  assert(inputTy && llvm::all_equal(values.getTypes()) &&
+         "values must be of the same VectorType");
+  Type elemTy = inputTy.getElementType();
+  ArrayRef<int64_t> tileShape = inputTy.getShape();
+  VectorType resultTy = VectorType::get(shape, elemTy);
+  auto zeroAttr = builder.getZeroAttr(elemTy);
+  Value result = builder.create<arith::ConstantOp>(
+      loc, resultTy, DenseElementsAttr::get(resultTy, zeroAttr));
+
+  for (auto [src, offsets] :
+       llvm::zip_equal(values, StaticTileOffsetRange(shape, tileShape))) {
+    SmallVector<int64_t> staticStrides(offsets.size(), 1);
+    result = builder.create<vector::InsertStridedSliceOp>(
+        loc, src, result, offsets, staticStrides);
+  }
+  return result;
 }
 
 void xegpu::doSCFStructuralTypeConversionWithTensorType(Operation *op) {
@@ -213,7 +306,6 @@ void xegpu::doSCFStructuralTypeConversionWithTensorType(Operation *op) {
 
   { // perform the conversion from RankedTensorType to VectorType based on the
     // LayoutAttr
-
     auto computeTileShapeAndCount = [&](ArrayRef<int64_t> shape,
                                           DenseI32ArrayAttr sgDataAttr,
                                           DenseI32ArrayAttr sgLayoutAttr) {
@@ -302,9 +394,53 @@ void xegpu::doSCFStructuralTypeConversionWithTensorType(Operation *op) {
     });
 
     mlir::ConversionTarget target(*context);
-    target.addLegalOp<UnrealizedConversionCastOp>();
+    target.addDynamicallyLegalOp<UnrealizedConversionCastOp>(
+        [&](UnrealizedConversionCastOp op) {
+          auto isTensorTy = [&](Type type) {
+            return isa<RankedTensorType>(type);
+          };
+          if (llvm::any_of(op->getOperandTypes(), isTensorTy) ||
+              llvm::any_of(op->getResultTypes(), isTensorTy))
+            return false;
+          return true;
+        });
+
+    /// Folds 1:1 casts between VectorType and RankedTensorType: vector->tensor
+    /// casts forward converted inputs, tensor->vector casts are rebuilt.
+    class UnrealizedConversionCastOpPattern
+        : public OpConversionPattern<mlir::UnrealizedConversionCastOp> {
+      using OpConversionPattern<
+          mlir::UnrealizedConversionCastOp>::OpConversionPattern;
+
+      mlir::LogicalResult
+      matchAndRewrite(mlir::UnrealizedConversionCastOp op,
+                      OneToNOpAdaptor adaptor,
+                      ConversionPatternRewriter &rewriter) const override {
+        if (op.getInputs().size() != 1 || op.getOutputs().size() != 1)
+          return failure();
+
+        Type srcTy = op.getInputs().front().getType();
+        Type dstTy = op.getOutputs().front().getType();
+        bool vectorToTensor =
+            isa<VectorType>(srcTy) && isa<RankedTensorType>(dstTy);
+        bool tensorToVector =
+            isa<RankedTensorType>(srcTy) && isa<VectorType>(dstTy);
+        if (!vectorToTensor && !tensorToVector)
+          return failure();
+        if (vectorToTensor) {
+          rewriter.replaceOpWithMultiple(op, adaptor.getInputs());
+          return success();
+        }
+        SmallVector<Value> values = flattenValues(adaptor.getInputs());
+        auto newOp = rewriter.create<UnrealizedConversionCastOp>(
+            op.getLoc(), dstTy, values);
+        rewriter.replaceOp(op, newOp);
+        return success();
+      }
+    };
 
     mlir::RewritePatternSet patterns(context);
+    patterns.insert<UnrealizedConversionCastOpPattern>(context);
     scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
                                                          target);
     (void)mlir::applyPartialConversion(op, target, std::move(patterns));

>From 46686f5e36744c2639e4eb4bfe30f84b7580f306 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Thu, 15 May 2025 21:53:42 +0000
Subject: [PATCH 14/55] change name to XeGPUWgToSgDistribute

---
 .../mlir/Dialect/XeGPU/Transforms/Passes.td   |  2 +-
 .../Dialect/XeGPU/Transforms/Transforms.h     |  2 +-
 .../Dialect/XeGPU/Transforms/CMakeLists.txt   |  2 +-
 ...PUWgToSg.cpp => XeGPUWgToSgDistribute.cpp} | 32 ++++++++-----------
 .../test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir |  2 +-
 mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir   |  2 +-
 6 files changed, 18 insertions(+), 24 deletions(-)
 rename mlir/lib/Dialect/XeGPU/Transforms/{XeGPUWgToSg.cpp => XeGPUWgToSgDistribute.cpp} (94%)

diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
index bdea88cfd7022..0be9fceb25ef1 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
@@ -33,7 +33,7 @@ def XeGPUSubgroupDistribute : Pass<"xegpu-subgroup-distribute"> {
       "Print the result of the subgroup map propagation analysis and exit.">];
 }
 
-def XeGPUWgToSg : Pass<"xegpu-wg-to-sg", "::mlir::gpu::GPUModuleOp"> {
+def XeGPUWgToSgDistribute : Pass<"xegpu-wg-to-sg-distribute", "::mlir::gpu::GPUModuleOp"> {
   let summary = "Transform WorkGroup level XeGPU code to SubGroup level";
   let description = [{
     This transform pass distributes the workgroup level computation to
diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h b/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
index 1029c66f97461..44b81796b1313 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
@@ -62,7 +62,7 @@ void populateXeGPUFoldAliasOpsPatterns(RewritePatternSet &patterns);
 
 /// Appends patterns for XeGPU SIMT distribution into `patterns`.
 void populateXeGPUSubgroupDistributePatterns(RewritePatternSet &patterns);
-void populateXeGPUWgToSgPatterns(RewritePatternSet &patterns);
+void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns);
 
 /// Collect a set of patterns to unroll xegpu operations to a smaller shapes.
 /// Users can control whether an operation to be unrolled or not, as well as
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
index 81938ba1d5ba3..837303b04e9d7 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
@@ -2,7 +2,7 @@ add_mlir_dialect_library(MLIRXeGPUTransforms
   XeGPUFoldAliasOps.cpp
   XeGPUSubgroupDistribute.cpp
   XeGPUUnroll.cpp
-  XeGPUWgToSg.cpp
+  XeGPUWgToSgDistribute.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/XeGPU
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSg.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
similarity index 94%
rename from mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSg.cpp
rename to mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index f8478289d3b91..6406809b8b9c7 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSg.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -1,4 +1,4 @@
-//===- XeGPUWgToSg.cpp - XeGPU Workgroup to Subgroup Pass -----------------===//
+//===- XeGPUWgToSgDistribute.cpp - XeGPU Workgroup to Subgroup Pass -------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -13,14 +13,12 @@
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
 #include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
 #include "mlir/Transforms/DialectConversion.h"
-#include "llvm/Support/Debug.h"
 #include <mlir/Dialect/GPU/IR/GPUDialect.h>
 #include <mlir/Dialect/Index/IR/IndexOps.h>
-#include <numeric>
 
 namespace mlir {
 namespace xegpu {
-#define GEN_PASS_DEF_XEGPUWGTOSG
+#define GEN_PASS_DEF_XEGPUWGTOSGDISTRIBUTE
 #include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
 } // namespace xegpu
 } // namespace mlir
@@ -98,12 +96,6 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
             rewriter.create<index::RemUOp>(loc, sgID, sgDimY)};
   }
 
-  // Create a constant index value
-  Value createConstantIndex(ConversionPatternRewriter &rewriter, Location loc,
-                            int64_t value) const {
-    return rewriter.create<arith::ConstantIndexOp>(loc, value);
-  }
-
   // Calculate offset for each subgroup
   SmallVector<OpFoldResult>
   calculateGlobalOffsets(ConversionPatternRewriter &rewriter, Location loc,
@@ -112,9 +104,9 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
                          const SmallVector<int64_t> &distUnitBaseAddr) const {
 
     Value constOffsetX =
-        createConstantIndex(rewriter, loc, distUnitBaseAddr[0]);
+        rewriter.create<arith::ConstantIndexOp>(loc, distUnitBaseAddr[0]);
     Value constOffsetY =
-        createConstantIndex(rewriter, loc, distUnitBaseAddr[1]);
+        rewriter.create<arith::ConstantIndexOp>(loc, distUnitBaseAddr[1]);
 
     Value offsetX =
         rewriter.createOrFold<index::AddOp>(loc, localOffset[0], constOffsetX);
@@ -162,8 +154,9 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
     SmallVector<Value> sgDataDim(sgShape.size());
 
     for (size_t i = 0; i < sgLayout.size(); i++) {
-      sgLayoutDim[i] = createConstantIndex(rewriter, loc, sgLayout[i]);
-      sgDataDim[i] = createConstantIndex(rewriter, loc, sgShape[i]);
+      sgLayoutDim[i] =
+          rewriter.create<arith::ConstantIndexOp>(loc, sgLayout[i]);
+      sgDataDim[i] = rewriter.create<arith::ConstantIndexOp>(loc, sgShape[i]);
     }
 
     // Delinearize the 1D subgroup id into 2d
@@ -278,8 +271,8 @@ struct WgToSgDpasOp : public OpConversionPattern<xegpu::DpasOp> {
       return failure();
 
     SmallVector<Value> newDpasOps;
-    size_t i = 0;
     for (auto aVec : adaptor.getLhs()) {
+      size_t i = 0;
       for (auto bVec : adaptor.getRhs()) {
 
         llvm::SmallVector<Value> operands({aVec, bVec});
@@ -325,7 +318,7 @@ struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
 
 namespace mlir {
 namespace xegpu {
-void populateXeGPUWgToSgPatterns(RewritePatternSet &patterns) {
+void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
   patterns.add<WgToSgCreateNdOp, WgToSgLoadNdOp, WgToSgStoreNdOp,
                WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp>(
       patterns.getContext());
@@ -334,12 +327,13 @@ void populateXeGPUWgToSgPatterns(RewritePatternSet &patterns) {
 } // namespace mlir
 
 namespace {
-struct XeGPUWgToSgPass : public xegpu::impl::XeGPUWgToSgBase<XeGPUWgToSgPass> {
+struct XeGPUWgToSgDistributePass
+    : public xegpu::impl::XeGPUWgToSgDistributeBase<XeGPUWgToSgDistributePass> {
   void runOnOperation() override;
 };
 } // namespace
 
-void XeGPUWgToSgPass::runOnOperation() {
+void XeGPUWgToSgDistributePass::runOnOperation() {
   MLIRContext *ctx = &getContext();
   RewritePatternSet patterns(ctx);
   ConversionTarget target(*ctx);
@@ -377,7 +371,7 @@ void XeGPUWgToSgPass::runOnOperation() {
 
   target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
 
-  xegpu::populateXeGPUWgToSgPatterns(patterns);
+  xegpu::populateXeGPUWgToSgDistributePatterns(patterns);
   if (failed(
           applyPartialConversion(getOperation(), target, std::move(patterns))))
     return signalPassFailure();
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
index 3096759e3ac8c..321cc0510a24c 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt --xegpu-wg-to-sg -split-input-file %s | FileCheck %s
+// RUN: mlir-opt --xegpu-wg-to-sg-distribute -split-input-file %s | FileCheck %s
 
 gpu.module @test_round_robin_assignment {
   // CHECK-LABEL: test_create_nd_tdesc
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
index fdc10289b44f0..3bd95ee775db3 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt --xegpu-wg-to-sg -split-input-file %s | FileCheck %s
+// RUN: mlir-opt --xegpu-wg-to-sg-distribute -split-input-file %s | FileCheck %s
 
 gpu.module @test_1_1_assignment {
   // CHECK-LABEL: test_create_nd_tdesc

>From 6ec3604310f3abf10d576162b14e0820839056e5 Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Thu, 15 May 2025 23:42:54 +0000
Subject: [PATCH 15/55] cleanup layout attr

---
 .../XeGPU/Transforms/XeGPUInstructionlize.cpp | 72 ++++++++++++-------
 mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp   |  6 +-
 2 files changed, 50 insertions(+), 28 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUInstructionlize.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUInstructionlize.cpp
index 0e01c7e4d9763..fba0f882ef632 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUInstructionlize.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUInstructionlize.cpp
@@ -32,6 +32,39 @@ using namespace mlir;
 
 namespace {
 
+/// Resolves an UnrealizedConversionCastOp left over by the unroll patterns:
+/// 1:1 casts forward their input, N:1 (unpack) and 1:N (pack) vector casts
+/// are materialized with strided-slice ops, other casts are left untouched.
+/// `castOp` may be erased; callers must not use it after this returns.
+void resolveUnrealizedConversionCastOp(UnrealizedConversionCastOp castOp) {
+  ValueRange inputs = castOp.getInputs();
+  ValueRange outputs = castOp.getOutputs();
+  if (inputs.empty() || outputs.empty())
+    return;
+  if (inputs.size() == 1 && outputs.size() == 1) {
+    castOp->replaceAllUsesWith(inputs);
+    castOp->erase(); // `inputs`/`outputs` view the erased op; stop here.
+    return;
+  }
+
+  auto inputTy = dyn_cast<VectorType>(inputs[0].getType());
+  auto outputTy = dyn_cast<VectorType>(outputs[0].getType());
+  if (!inputTy || !outputTy)
+    return;
+  OpBuilder builder(castOp);
+  SmallVector<Value> replacements;
+  if (inputs.size() > 1 && outputs.size() == 1) // unpack: N tiles -> 1
+    replacements.push_back(xegpu::createVectorWithShapeFromValues(
+        builder, castOp.getLoc(), inputs, outputTy.getShape()));
+  else if (inputs.size() == 1 && outputs.size() > 1) // pack: 1 -> N tiles
+    replacements = xegpu::extractVectorsWithShapeFromValue(
+        builder, castOp.getLoc(), inputs[0], outputTy.getShape());
+  if (replacements.size() != outputs.size())
+    return; // Not a pack/unpack, or the tiles do not cover the outputs.
+  castOp->replaceAllUsesWith(replacements);
+  castOp->erase();
+}
+
 /// Unroll XeGPU ops to their instruction-level representation.
 class XeGPUInstructionlizePass final
     : public xegpu::impl::XeGPUInstructionlizeBase<XeGPUInstructionlizePass> {
@@ -200,35 +233,22 @@ void XeGPUInstructionlizePass::runOnOperation() {
   populateXeGPUUnrollPatterns(patterns, options);
   (void)applyPatternsGreedily(mod, std::move(patterns));
 
-  mod->walk([&](UnrealizedConversionCastOp castOp) {
-    ValueRange inputs = castOp.getInputs();
-    ValueRange outputs = castOp.getOutputs();
+  mod->walk([&](Operation *op) {
+    if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op))
+      resolveUnrealizedConversionCastOp(castOp);
 
-    if (inputs.size() == 1 && outputs.size() == 1) {
-      castOp->replaceAllUsesWith(inputs);
-      castOp->erase();
+    for (OpOperand &opr : op->getOpOperands()) {
+      std::string name = xegpu::getLayoutName(opr);
+      if (auto layout = op->getAttrOfType<xegpu::LayoutAttr>(name))
+        op->removeAttr(name);
     }
 
-    VectorType inputTy = dyn_cast<VectorType>(inputs[0].getType());
-    VectorType outputTy = dyn_cast<VectorType>(outputs[0].getType());
-    if (inputTy && outputTy) {
-      OpBuilder builder(castOp);
-      // unpack
-      if (inputs.size() > 1 && outputs.size() == 1) {
-        ArrayRef<int64_t> shape = outputTy.getShape();
-        Value result = xegpu::createVectorWithShapeFromValues(
-            builder, castOp.getLoc(), inputs, shape);
-        castOp->replaceAllUsesWith(ValueRange(result));
-        castOp->erase();
-      }
-
-      // pack
-      if (castOp.getNumResults() > 1 && castOp.getNumOperands() == 1) {
-        ArrayRef<int64_t> tileShape = outputTy.getShape();
-        SmallVector<Value> results = xegpu::extractVectorsWithShapeFromValue(
-            builder, castOp.getLoc(), inputs[0], tileShape);
-        castOp->replaceAllUsesWith(results);
-        castOp->erase();
+    for (OpResult result : op->getOpResults()) {
+      std::string name = xegpu::getLayoutName(result);
+      if (auto layout = op->getAttrOfType<xegpu::LayoutAttr>(name)) {
+        op->removeAttr(name);
+        if (!isa<LoopLikeOpInterface>(op))
+          xegpu::setLayoutAttr(result, layout.dropInstData());
       }
     }
   });
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index 60c8493f552d8..023e445206440 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -115,7 +115,8 @@ xegpu::LayoutAttr xegpu::getLayoutAttr(Value value) {
   if (!value)
     return nullptr;
 
-  if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(value.getType()))
+  if (auto tdescTy =
+          dyn_cast_if_present<xegpu::TensorDescType>(value.getType()))
     return tdescTy.getLayoutAttr();
 
   if (auto result = dyn_cast<OpResult>(value)) {
@@ -366,7 +367,7 @@ void xegpu::doSCFStructuralTypeConversionWithTensorType(Operation *op) {
           Type newTy = type;
 
           if (xegpu::LayoutAttr layout = type.getLayoutAttr()) {
-            SmallVector<int64_t> subShape, distUnit;
+            SmallVector<int64_t> subShape(shape);
             if (layout.isWgLayout()) {
               // for WgToSg, the subShape is either from sgData or computed as
               // shape/sgLayout
@@ -378,6 +379,7 @@ void xegpu::doSCFStructuralTypeConversionWithTensorType(Operation *op) {
               count = computeProduct(shape) / computeProduct(subShape);
               layout = layout.dropInstData();
             }
+
             newTy = xegpu::TensorDescType::get(ctx, subShape, elemTy, encoding,
                                                layout);
           }

>From b8da87e3d9f85c12a89cccf1092dcbbac22732b2 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Fri, 16 May 2025 05:14:56 +0000
Subject: [PATCH 16/55] Use getMixedOffsets

---
 .../Transforms/XeGPUWgToSgDistribute.cpp      | 43 +++++++++----------
 1 file changed, 20 insertions(+), 23 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 6406809b8b9c7..7c5a6d362c3d1 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -70,23 +70,6 @@ namespace {
 struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
   using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
 
-  // Helper to extract mixed offsets into a Value array
-  SmallVector<Value> extractOffsets(ConversionPatternRewriter &rewriter,
-                                    xegpu::CreateNdDescOp op) const {
-    llvm::SmallVector<Value> offsets;
-    auto staticOffsets = op.getStaticOffsets();
-    auto dynamicOffsets = op.getOffsets();
-
-    for (size_t i = 0, j = 0; i != staticOffsets.size(); i++) {
-      if (ShapedType::isDynamic(staticOffsets[i]))
-        offsets.push_back(dynamicOffsets[j++]);
-      else
-        offsets.push_back(rewriter.create<arith::ConstantIndexOp>(
-            op.getLoc(), staticOffsets[i]));
-    }
-    return offsets;
-  }
-
   // Convert linear subgroup ID to 2D coordinates
   // TODO: Delinearize for nD
   SmallVector<Value> delinearizeSubgroupId(ConversionPatternRewriter &rewriter,
@@ -99,7 +82,7 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
   // Calculate offset for each subgroup
   SmallVector<OpFoldResult>
   calculateGlobalOffsets(ConversionPatternRewriter &rewriter, Location loc,
-                         const SmallVector<Value> &originalOffsets,
+                         const SmallVector<OpFoldResult> &originalOffsets,
                          const SmallVector<Value> &localOffset,
                          const SmallVector<int64_t> &distUnitBaseAddr) const {
 
@@ -116,10 +99,24 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
     size_t lastDimIndex = originalOffsets.size() - 1;
     size_t secondLastDimIndex = lastDimIndex - 1;
 
-    Value globalOffsetX = rewriter.createOrFold<index::AddOp>(
-        loc, originalOffsets[secondLastDimIndex], offsetX);
-    Value globalOffsetY = rewriter.createOrFold<index::AddOp>(
-        loc, originalOffsets[lastDimIndex], offsetY);
+    // Convert originalOffsets to Value
+    auto getValueFromOpFoldResult = [&](OpFoldResult ofr) -> Value {
+      if (auto val = ofr.dyn_cast<Value>())
+        return val;
+      if (auto attr = ofr.dyn_cast<Attribute>()) {
+        int64_t staticOffset = cast<IntegerAttr>(attr).getInt();
+        return rewriter.create<arith::ConstantIndexOp>(loc, staticOffset);
+      }
+      llvm_unreachable("Unsupported OpFoldResult kind");
+    };
+
+    Value origOffsetX =
+        getValueFromOpFoldResult(originalOffsets[secondLastDimIndex]);
+    Value origOffsetY = getValueFromOpFoldResult(originalOffsets[lastDimIndex]);
+    Value globalOffsetX =
+        rewriter.createOrFold<index::AddOp>(loc, origOffsetX, offsetX);
+    Value globalOffsetY =
+        rewriter.createOrFold<index::AddOp>(loc, origOffsetY, offsetY);
 
     SmallVector<OpFoldResult> globalOffsets(originalOffsets.begin(),
                                             originalOffsets.end());
@@ -172,7 +169,7 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
           rewriter.createOrFold<index::MulOp>(loc, sgIds[i], sgDataDim[i]);
     }
 
-    SmallVector<Value> originalOffsets = extractOffsets(rewriter, op);
+    SmallVector<OpFoldResult> originalOffsets = op.getMixedOffsets();
 
     xegpu::TensorDescType newTdescTy =
         xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(),

>From bc69a8de7e0d436a7718fc2b30ee4bbd7861e5a4 Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Fri, 16 May 2025 14:10:26 +0000
Subject: [PATCH 17/55] check in elemwise support

---
 .../Dialect/XeGPU/Transforms/XeGPUInstructionlize.cpp | 11 +++++++++++
 1 file changed, 11 insertions(+)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUInstructionlize.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUInstructionlize.cpp
index fba0f882ef632..078b674de8d4f 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUInstructionlize.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUInstructionlize.cpp
@@ -164,6 +164,10 @@ XeGPUInstructionlizePass::getTileShape(Operation *op) const {
 
     return SmallVector<int64_t>({(*aTile)[0], (*aTile)[1], (*bTile)[1]});
   }
+
+  if (OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() == 1)
+    return getTileShape(op->getOpResult(0));
+
   return std::nullopt;
 }
 
@@ -230,7 +234,14 @@ void XeGPUInstructionlizePass::runOnOperation() {
   });
 
   RewritePatternSet patterns(ctx);
+
+  vector::UnrollVectorOptions vectorOptions;
+  // vectorOptions.setNativeShapeFn([&](Operation *op) { return getTileShape(op); });
+  vectorOptions.setNativeShapeFn(options.nativeShape);
+
   populateXeGPUUnrollPatterns(patterns, options);
+  vector::populateVectorUnrollPatterns(patterns, vectorOptions);
+
   (void)applyPatternsGreedily(mod, std::move(patterns));
 
   mod->walk([&](Operation *op) {

>From 4fc75402332a5062eaa20b51f20ef54b4e5281ac Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Fri, 16 May 2025 14:43:59 +0000
Subject: [PATCH 18/55] check in unit test

---
 .../Dialect/XeGPU/xegpu-instructionlize.mlir  | 123 ++++++++++++++++++
 1 file changed, 123 insertions(+)
 create mode 100644 mlir/test/Dialect/XeGPU/xegpu-instructionlize.mlir

diff --git a/mlir/test/Dialect/XeGPU/xegpu-instructionlize.mlir b/mlir/test/Dialect/XeGPU/xegpu-instructionlize.mlir
new file mode 100644
index 0000000000000..888684789cc8c
--- /dev/null
+++ b/mlir/test/Dialect/XeGPU/xegpu-instructionlize.mlir
@@ -0,0 +1,123 @@
+// RUN: mlir-opt --xegpu-instructionlize -split-input-file %s | FileCheck %s
+
+
+#a = #xegpu.layout<inst_data = [8, 16], lane_layout = [1, 16], lane_data = [8, 1]>
+#b = #xegpu.layout<inst_data = [16, 16], lane_layout = [1, 16], lane_data = [16, 1]>
+#c = #xegpu.layout<inst_data = [8, 16], lane_layout = [1, 16], lane_data = [8, 1]>
+
+#l1 = #xegpu.layout<inst_data = [8, 16]>
+#l2 = #xegpu.layout<inst_data = [16, 16]>
+
+gpu.module @test_kernel {
+  gpu.func @test_gemm(%A: memref<1024x1024xf16>, %B: memref<1024x1024xf16>, %C: memref<1024x1024xf32>) {
+    %c0 = arith.constant 0 : index
+    %c16 = arith.constant 16 : index
+    %c32 = arith.constant 32 : index
+    %c1024 = arith.constant 1024 : index
+    %block_id_x = gpu.block_id x
+    %block_id_y = gpu.block_id y
+    %m = arith.muli %block_id_x, %c16 : index
+    %n = arith.muli %block_id_y, %c32 : index
+
+    %c_tdesc = xegpu.create_nd_tdesc %C[%m, %n] : memref<1024x1024xf32> -> !xegpu.tensor_desc<16x32xf32, #c>
+    %c_init = xegpu.load_nd %c_tdesc : !xegpu.tensor_desc<16x32xf32, #c> -> vector<16x32xf32>
+
+    %a_tdesc = xegpu.create_nd_tdesc %A[%m, %c0] : memref<1024x1024xf16> -> !xegpu.tensor_desc<16x32xf16, #a>
+    %b_tdesc = xegpu.create_nd_tdesc %B[%c0, %n] : memref<1024x1024xf16> -> !xegpu.tensor_desc<32x32xf16, #b>
+    %out:3 = scf.for %k = %c0 to %c1024 step %c32
+      iter_args(%arg0 = %a_tdesc, %arg1 = %b_tdesc, %arg2 = %c_init)
+      -> (!xegpu.tensor_desc<16x32xf16, #a>, !xegpu.tensor_desc<32x32xf16, #b>, vector<16x32xf32>) {
+      //CHECK-COUNT-4: xegpu.load_nd {{.*}} -> vector<8x16xf16>
+      %a = xegpu.load_nd %arg0 : !xegpu.tensor_desc<16x32xf16, #a> -> vector<16x32xf16>
+      //CHECK-COUNT-4: xegpu.load_nd {{.*}} -> vector<16x16xf16>
+      %b = xegpu.load_nd %arg1 : !xegpu.tensor_desc<32x32xf16, #b> -> vector<32x32xf16>
+      //CHECK-COUNT-8: xegpu.dpas {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [8, 1]>} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+      %c = xegpu.dpas %a, %b, %arg2 {layout_result_0 = #c}: vector<16x32xf16>, vector<32x32xf16>, vector<16x32xf32> -> vector<16x32xf32>
+      //CHECK-COUNT-4: xegpu.update_nd_offset {{.*}} : !xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [8, 1]>>
+      %a_next_tdesc = xegpu.update_nd_offset %arg0, [%c0, %c32] : !xegpu.tensor_desc<16x32xf16, #a>
+      //CHECK-COUNT-4: xegpu.update_nd_offset {{.*}} : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [16, 1]>>
+      %b_next_tdesc = xegpu.update_nd_offset %arg1, [%c32, %c0] : !xegpu.tensor_desc<32x32xf16, #b>
+      scf.yield %a_next_tdesc, %b_next_tdesc, %c
+        : !xegpu.tensor_desc<16x32xf16, #a>, !xegpu.tensor_desc<32x32xf16, #b>, vector<16x32xf32>
+    }
+    //CHECK-COUNT-4: xegpu.store_nd {{.*}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [8, 1]>>
+    xegpu.store_nd %out#2, %c_tdesc: vector<16x32xf32>, !xegpu.tensor_desc<16x32xf32, #c>
+    gpu.return
+  }
+
+  //-----
+  gpu.func @test_gemm_simple(%A: memref<1024x1024xf16>, %B: memref<1024x1024xf16>, %C: memref<1024x1024xf32>) {
+    %c0 = arith.constant 0 : index
+    %c16 = arith.constant 16 : index
+    %c32 = arith.constant 32 : index
+    %c1024 = arith.constant 1024 : index
+    %block_id_x = gpu.block_id x
+    %block_id_y = gpu.block_id y
+    %m = arith.muli %block_id_x, %c16 : index
+    %n = arith.muli %block_id_y, %c32 : index
+
+    %c_tdesc = xegpu.create_nd_tdesc %C[%m, %n] : memref<1024x1024xf32> -> !xegpu.tensor_desc<16x32xf32, #l1>
+    %c_init = xegpu.load_nd %c_tdesc : !xegpu.tensor_desc<16x32xf32, #l1> -> vector<16x32xf32>
+
+    %a_tdesc = xegpu.create_nd_tdesc %A[%m, %c0] : memref<1024x1024xf16> -> !xegpu.tensor_desc<16x32xf16, #l1>
+    %b_tdesc = xegpu.create_nd_tdesc %B[%c0, %n] : memref<1024x1024xf16> -> !xegpu.tensor_desc<32x32xf16, #l2>
+    %out:3 = scf.for %k = %c0 to %c1024 step %c32
+      iter_args(%arg0 = %a_tdesc, %arg1 = %b_tdesc, %arg2 = %c_init)
+      -> (!xegpu.tensor_desc<16x32xf16, #l1>, !xegpu.tensor_desc<32x32xf16, #l2>, vector<16x32xf32>) {
+      //CHECK-COUNT-4: xegpu.load_nd {{.*}} -> vector<8x16xf16>
+      %a = xegpu.load_nd %arg0 : !xegpu.tensor_desc<16x32xf16, #l1> -> vector<16x32xf16>
+      //CHECK-COUNT-4: xegpu.load_nd {{.*}} -> vector<16x16xf16>
+      %b = xegpu.load_nd %arg1 : !xegpu.tensor_desc<32x32xf16, #l2> -> vector<32x32xf16>
+      //CHECK-COUNT-8: xegpu.dpas {{.*}} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+      %c = xegpu.dpas %a, %b, %arg2 {layout_result_0 = #l1}: vector<16x32xf16>, vector<32x32xf16>, vector<16x32xf32> -> vector<16x32xf32>
+      //CHECK-COUNT-4: xegpu.update_nd_offset {{.*}} : !xegpu.tensor_desc<8x16xf16>
+      %a_next_tdesc = xegpu.update_nd_offset %arg0, [%c0, %c32] : !xegpu.tensor_desc<16x32xf16, #l1>
+      //CHECK-COUNT-4: xegpu.update_nd_offset {{.*}} : !xegpu.tensor_desc<16x16xf16>
+      %b_next_tdesc = xegpu.update_nd_offset %arg1, [%c32, %c0] : !xegpu.tensor_desc<32x32xf16, #l2>
+      scf.yield %a_next_tdesc, %b_next_tdesc, %c
+        : !xegpu.tensor_desc<16x32xf16, #l1>, !xegpu.tensor_desc<32x32xf16, #l2>, vector<16x32xf32>
+    }
+    //CHECK-COUNT-4: xegpu.store_nd {{.*}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+    xegpu.store_nd %out#2, %c_tdesc: vector<16x32xf32>, !xegpu.tensor_desc<16x32xf32, #l1>
+    gpu.return
+  }
+
+  //-----
+
+  gpu.func @test_gemm_a_preop(%A: memref<1024x1024xf16>, %B: memref<1024x1024xf16>, %C: memref<1024x1024xf32>) {
+    %c0 = arith.constant 0 : index
+    %c16 = arith.constant 16 : index
+    %c32 = arith.constant 32 : index
+    %c1024 = arith.constant 1024 : index
+    %block_id_x = gpu.block_id x
+    %block_id_y = gpu.block_id y
+    %m = arith.muli %block_id_x, %c16 : index
+    %n = arith.muli %block_id_y, %c32 : index
+
+    %c_tdesc = xegpu.create_nd_tdesc %C[%m, %n] : memref<1024x1024xf32> -> !xegpu.tensor_desc<16x32xf32, #c>
+    %c_init = xegpu.load_nd %c_tdesc : !xegpu.tensor_desc<16x32xf32, #c> -> vector<16x32xf32>
+
+    %a_tdesc = xegpu.create_nd_tdesc %A[%m, %c0] : memref<1024x1024xf16> -> !xegpu.tensor_desc<16x32xf16, #a>
+    %b_tdesc = xegpu.create_nd_tdesc %B[%c0, %n] : memref<1024x1024xf16> -> !xegpu.tensor_desc<32x32xf16, #b>
+    %out:3 = scf.for %k = %c0 to %c1024 step %c32
+      iter_args(%arg0 = %a_tdesc, %arg1 = %b_tdesc, %arg2 = %c_init)
+      -> (!xegpu.tensor_desc<16x32xf16, #a>, !xegpu.tensor_desc<32x32xf16, #b>, vector<16x32xf32>) {
+      //CHECK-COUNT-4: xegpu.load_nd {{.*}} -> vector<8x16xf16>
+      %a = xegpu.load_nd %arg0 : !xegpu.tensor_desc<16x32xf16, #a> -> vector<16x32xf16>
+      //CHECK-COUNT-4: xegpu.load_nd {{.*}} -> vector<16x16xf16>
+      %b = xegpu.load_nd %arg1 : !xegpu.tensor_desc<32x32xf16, #b> -> vector<32x32xf16>
+      //CHECK-COUNT-4: math.exp {{.*}} : vector<8x16xf16>
+      %e = math.exp %a {layout_result_0 = #a} : vector<16x32xf16>
+      //CHECK-COUNT-8: xegpu.dpas {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [8, 1]>} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
+      %c = xegpu.dpas %e, %b, %arg2 {layout_result_0 = #c}: vector<16x32xf16>, vector<32x32xf16>, vector<16x32xf32> -> vector<16x32xf32>
+      //CHECK-COUNT-4: xegpu.update_nd_offset {{.*}} : !xegpu.tensor_desc<8x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [8, 1]>>
+      %a_next_tdesc = xegpu.update_nd_offset %arg0, [%c0, %c32] : !xegpu.tensor_desc<16x32xf16, #a>
+      //CHECK-COUNT-4: xegpu.update_nd_offset {{.*}} : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [16, 1]>>
+      %b_next_tdesc = xegpu.update_nd_offset %arg1, [%c32, %c0] : !xegpu.tensor_desc<32x32xf16, #b>
+      scf.yield %a_next_tdesc, %b_next_tdesc, %c
+        : !xegpu.tensor_desc<16x32xf16, #a>, !xegpu.tensor_desc<32x32xf16, #b>, vector<16x32xf32>
+    }
+    //CHECK-COUNT-4: xegpu.store_nd {{.*}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [8, 1]>>
+    xegpu.store_nd %out#2, %c_tdesc: vector<16x32xf32>, !xegpu.tensor_desc<16x32xf32, #c>
+    gpu.return
+  }}

>From 132f15e7400b92b61801ca0bf013be66a95c54d1 Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Fri, 16 May 2025 15:06:25 +0000
Subject: [PATCH 19/55] fix format

---
 .../XeGPU/Transforms/XeGPUInstructionlize.cpp     |  1 -
 mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp       | 15 +++++++++------
 2 files changed, 9 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUInstructionlize.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUInstructionlize.cpp
index 078b674de8d4f..f0ebe2321f8f1 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUInstructionlize.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUInstructionlize.cpp
@@ -236,7 +236,6 @@ void XeGPUInstructionlizePass::runOnOperation() {
   RewritePatternSet patterns(ctx);
 
   vector::UnrollVectorOptions vectorOptions;
-  // vectorOptions.setNativeShapeFn([&](Operation *op) { return getTileShape(op); });
   vectorOptions.setNativeShapeFn(options.nativeShape);
 
   populateXeGPUUnrollPatterns(patterns, options);
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index 023e445206440..14b2b909e143a 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -308,8 +308,8 @@ void xegpu::doSCFStructuralTypeConversionWithTensorType(Operation *op) {
   { // perform the conversion from RankedTensorType to VectorType based on the
     // LayoutAttr
     auto computeTileShapeAndCount = [&](ArrayRef<int64_t> shape,
-                                          DenseI32ArrayAttr sgDataAttr,
-                                          DenseI32ArrayAttr sgLayoutAttr) {
+                                        DenseI32ArrayAttr sgDataAttr,
+                                        DenseI32ArrayAttr sgLayoutAttr) {
       SmallVector<int64_t> tileShape;
       auto sgLayout = llvm::to_vector_of<int64_t>(sgLayoutAttr.asArrayRef());
       if (sgDataAttr)
@@ -317,7 +317,8 @@ void xegpu::doSCFStructuralTypeConversionWithTensorType(Operation *op) {
       else
         tileShape = computeShapeRatio(shape, sgLayout).value_or(tileShape);
       assert(tileShape.size() && "failed to compute tileShape");
-      SmallVector<int64_t> distUnit = computeElementwiseMul(sgLayout, tileShape);
+      SmallVector<int64_t> distUnit =
+          computeElementwiseMul(sgLayout, tileShape);
       int count = computeProduct(shape) / computeProduct(distUnit);
       return std::make_pair(tileShape, count);
     };
@@ -341,7 +342,8 @@ void xegpu::doSCFStructuralTypeConversionWithTensorType(Operation *op) {
             if (layout.isWgLayout()) {
               // for WgToSg, the subShape is either from sgData or computed as
               // shape/sgLayout
-              std::tie(subShape, count) = computeTileShapeAndCount(shape, layout.getSgData(), layout.getSgLayout());
+              std::tie(subShape, count) = computeTileShapeAndCount(
+                  shape, layout.getSgData(), layout.getSgLayout());
             } else if (DenseI32ArrayAttr instData = layout.getInstData()) {
               // for unrolling, the subShape is determined by inst_data
               subShape = llvm::to_vector_of<int64_t>(instData.asArrayRef());
@@ -371,7 +373,8 @@ void xegpu::doSCFStructuralTypeConversionWithTensorType(Operation *op) {
             if (layout.isWgLayout()) {
               // for WgToSg, the subShape is either from sgData or computed as
               // shape/sgLayout
-              std::tie(subShape, count) = computeTileShapeAndCount(shape, layout.getSgData(), layout.getSgLayout());
+              std::tie(subShape, count) = computeTileShapeAndCount(
+                  shape, layout.getSgData(), layout.getSgLayout());
               layout = layout.dropSgLayoutAndData();
             } else if (DenseI32ArrayAttr instData = layout.getInstData()) {
               // for unrolling, the subShape is determined by inst_data
@@ -390,7 +393,7 @@ void xegpu::doSCFStructuralTypeConversionWithTensorType(Operation *op) {
 
     converter.addSourceMaterialization(materializeCast);
     converter.addTargetMaterialization([&](OpBuilder &builder, TypeRange type,
-                                        ValueRange inputs, Location loc) {
+                                           ValueRange inputs, Location loc) {
       return builder.create<UnrealizedConversionCastOp>(loc, type, inputs)
           .getResults();
     });

>From aa4ba9c32d9ca14daec16bc98b27e4bb9d1f5282 Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Fri, 16 May 2025 15:21:18 +0000
Subject: [PATCH 20/55] roll back pass name

---
 .../mlir/Dialect/XeGPU/Transforms/Passes.td   |  2 +-
 .../Dialect/XeGPU/Transforms/CMakeLists.txt   |  2 +-
 ...UInstructionlize.cpp => XeGPUBlocking.cpp} | 22 +++++++++----------
 ...structionlize.mlir => xegpu-blocking.mlir} |  2 +-
 4 files changed, 14 insertions(+), 14 deletions(-)
 rename mlir/lib/Dialect/XeGPU/Transforms/{XeGPUInstructionlize.cpp => XeGPUBlocking.cpp} (92%)
 rename mlir/test/Dialect/XeGPU/{xegpu-instructionlize.mlir => xegpu-blocking.mlir} (99%)

diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
index 54782933fe5f8..b3883605b74f2 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
@@ -38,7 +38,7 @@ def XeGPUSubgroupDistribute : Pass<"xegpu-subgroup-distribute"> {
   ];
 }
 
-def XeGPUInstructionlize: Pass<"xegpu-instructionlize"> {
+def XeGPUBlocking: Pass<"xegpu-blocking"> {
   let summary = "Instructionlize XeGPU ops";
   let description = [{
     The pass unrolls XeGPU ops working on large shapes into ops working on small shapes
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
index 1d94b4c4c03ac..adbbdaac8fc06 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
@@ -1,6 +1,6 @@
 add_mlir_dialect_library(MLIRXeGPUTransforms
+  XeGPUBlocking.cpp
   XeGPUFoldAliasOps.cpp
-  XeGPUInstructionlize.cpp
   XeGPUSubgroupDistribute.cpp
   XeGPUUnroll.cpp
 
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUInstructionlize.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
similarity index 92%
rename from mlir/lib/Dialect/XeGPU/Transforms/XeGPUInstructionlize.cpp
rename to mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
index f0ebe2321f8f1..1587cbdfed2cc 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUInstructionlize.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
@@ -1,4 +1,4 @@
-//===---- XeGPUInstructionlize.cpp -- XeGPU Instructionlize Pass ----------===//
+//===---- XeGPUBlocking.cpp ---- XeGPU Blocking Pass ----------------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -20,12 +20,12 @@
 
 namespace mlir {
 namespace xegpu {
-#define GEN_PASS_DEF_XEGPUINSTRUCTIONLIZE
+#define GEN_PASS_DEF_XEGPUBLOCKING
 #include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
 } // namespace xegpu
 } // namespace mlir
 
-#define DEBUG_TYPE "xegpu-instructionlize"
+#define DEBUG_TYPE "xegpu-blocking"
 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
 
 using namespace mlir;
@@ -66,8 +66,8 @@ void resolveUnrealizedConversionCastOp(UnrealizedConversionCastOp castOp) {
 }
 
 /// Unroll XeGPU ops to their instruction-level representation.
-class XeGPUInstructionlizePass final
-    : public xegpu::impl::XeGPUInstructionlizeBase<XeGPUInstructionlizePass> {
+class XeGPUBlockingPass final
+    : public xegpu::impl::XeGPUBlockingBase<XeGPUBlockingPass> {
 public:
   void runOnOperation() override;
 
@@ -94,7 +94,7 @@ class XeGPUInstructionlizePass final
 } // namespace
 
 std::optional<SmallVector<int64_t>>
-XeGPUInstructionlizePass::getTileShape(TypedValue<ShapedType> value) const {
+XeGPUBlockingPass::getTileShape(TypedValue<ShapedType> value) const {
   assert(value && "value must be non-null");
   xegpu::LayoutAttr layout = xegpu::getLayoutAttr(value);
   if (layout && layout.isSgLayout()) {
@@ -106,7 +106,7 @@ XeGPUInstructionlizePass::getTileShape(TypedValue<ShapedType> value) const {
 }
 
 std::optional<SmallVector<int64_t>>
-XeGPUInstructionlizePass::getTileShape(OpOperand &operand) const {
+XeGPUBlockingPass::getTileShape(OpOperand &operand) const {
   xegpu::LayoutAttr layout = xegpu::getLayoutAttr(operand);
   if (layout && layout.isSgLayout()) {
     if (auto inst_data = layout.getInstData())
@@ -119,7 +119,7 @@ XeGPUInstructionlizePass::getTileShape(OpOperand &operand) const {
 }
 
 std::optional<SmallVector<int64_t>>
-XeGPUInstructionlizePass::getTileShape(OpResult result) const {
+XeGPUBlockingPass::getTileShape(OpResult result) const {
   xegpu::LayoutAttr layout = xegpu::getLayoutAttr(result);
   if (layout && layout.isSgLayout()) {
     if (auto inst_data = layout.getInstData())
@@ -132,7 +132,7 @@ XeGPUInstructionlizePass::getTileShape(OpResult result) const {
 }
 
 std::optional<SmallVector<int64_t>>
-XeGPUInstructionlizePass::getTileShape(Operation *op) const {
+XeGPUBlockingPass::getTileShape(Operation *op) const {
   if (isa<xegpu::CreateNdDescOp, xegpu::UpdateNdOffsetOp>(op))
     return getTileShape(op->getOpResult(0));
   if (isa<xegpu::PrefetchNdOp, xegpu::LoadNdOp>(op))
@@ -171,7 +171,7 @@ XeGPUInstructionlizePass::getTileShape(Operation *op) const {
   return std::nullopt;
 }
 
-bool XeGPUInstructionlizePass::needsUnroll(Operation *op) const {
+bool XeGPUBlockingPass::needsUnroll(Operation *op) const {
   if (isa<LoopLikeOpInterface>(op))
     return false;
 
@@ -197,7 +197,7 @@ bool XeGPUInstructionlizePass::needsUnroll(Operation *op) const {
   return false;
 }
 
-void XeGPUInstructionlizePass::runOnOperation() {
+void XeGPUBlockingPass::runOnOperation() {
   MLIRContext *ctx = &getContext();
   Operation *mod = getOperation();
 
diff --git a/mlir/test/Dialect/XeGPU/xegpu-instructionlize.mlir b/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
similarity index 99%
rename from mlir/test/Dialect/XeGPU/xegpu-instructionlize.mlir
rename to mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
index 888684789cc8c..c3db6b2abb7bd 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-instructionlize.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt --xegpu-instructionlize -split-input-file %s | FileCheck %s
+// RUN: mlir-opt --xegpu-blocking -split-input-file %s | FileCheck %s
 
 
 #a = #xegpu.layout<inst_data = [8, 16], lane_layout = [1, 16], lane_data = [8, 1]>

>From 061b6e00f3f0036a15790fea4e3ffd9b1def5bf4 Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Fri, 16 May 2025 16:37:25 +0000
Subject: [PATCH 21/55] add 1d and 2d elemwise test

---
 mlir/test/Dialect/XeGPU/xegpu-blocking.mlir | 104 +++++++++++++++++---
 1 file changed, 93 insertions(+), 11 deletions(-)

diff --git a/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir b/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
index c3db6b2abb7bd..d8a5dfe7d4b13 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
@@ -1,13 +1,8 @@
 // RUN: mlir-opt --xegpu-blocking -split-input-file %s | FileCheck %s
 
-
 #a = #xegpu.layout<inst_data = [8, 16], lane_layout = [1, 16], lane_data = [8, 1]>
 #b = #xegpu.layout<inst_data = [16, 16], lane_layout = [1, 16], lane_data = [16, 1]>
 #c = #xegpu.layout<inst_data = [8, 16], lane_layout = [1, 16], lane_data = [8, 1]>
-
-#l1 = #xegpu.layout<inst_data = [8, 16]>
-#l2 = #xegpu.layout<inst_data = [16, 16]>
-
 gpu.module @test_kernel {
   gpu.func @test_gemm(%A: memref<1024x1024xf16>, %B: memref<1024x1024xf16>, %C: memref<1024x1024xf32>) {
     %c0 = arith.constant 0 : index
@@ -44,9 +39,13 @@ gpu.module @test_kernel {
     xegpu.store_nd %out#2, %c_tdesc: vector<16x32xf32>, !xegpu.tensor_desc<16x32xf32, #c>
     gpu.return
   }
+}
 
-  //-----
-  gpu.func @test_gemm_simple(%A: memref<1024x1024xf16>, %B: memref<1024x1024xf16>, %C: memref<1024x1024xf32>) {
+// -----
+#l1 = #xegpu.layout<inst_data = [8, 16]>
+#l2 = #xegpu.layout<inst_data = [16, 16]>
+gpu.module @test_kernel {
+  gpu.func @test_gemm(%A: memref<1024x1024xf16>, %B: memref<1024x1024xf16>, %C: memref<1024x1024xf32>) {
     %c0 = arith.constant 0 : index
     %c16 = arith.constant 16 : index
     %c32 = arith.constant 32 : index
@@ -81,10 +80,14 @@ gpu.module @test_kernel {
     xegpu.store_nd %out#2, %c_tdesc: vector<16x32xf32>, !xegpu.tensor_desc<16x32xf32, #l1>
     gpu.return
   }
+}
 
-  //-----
-
-  gpu.func @test_gemm_a_preop(%A: memref<1024x1024xf16>, %B: memref<1024x1024xf16>, %C: memref<1024x1024xf32>) {
+// -----
+#a = #xegpu.layout<inst_data = [8, 16], lane_layout = [1, 16], lane_data = [8, 1]>
+#b = #xegpu.layout<inst_data = [16, 16], lane_layout = [1, 16], lane_data = [16, 1]>
+#c = #xegpu.layout<inst_data = [8, 16], lane_layout = [1, 16], lane_data = [8, 1]>
+gpu.module @test_kernel {
+  gpu.func @test_gemm(%A: memref<1024x1024xf16>, %B: memref<1024x1024xf16>, %C: memref<1024x1024xf32>) {
     %c0 = arith.constant 0 : index
     %c16 = arith.constant 16 : index
     %c32 = arith.constant 32 : index
@@ -120,4 +123,83 @@ gpu.module @test_kernel {
     //CHECK-COUNT-4: xegpu.store_nd {{.*}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [8, 1]>>
     xegpu.store_nd %out#2, %c_tdesc: vector<16x32xf32>, !xegpu.tensor_desc<16x32xf32, #c>
     gpu.return
-  }}
+  }
+}
+
+// -----
+#l = #xegpu.layout<inst_data = [8, 16]>
+gpu.module @test_kernel {
+  gpu.func @test_elementwise(%A: memref<1024x1024xf16>, %B: memref<1024x1024xf16>, %C: memref<1024x1024xf16>) {
+    %c0 = arith.constant 0 : index
+    %c32 = arith.constant 32 : index
+    %c1024 = arith.constant 1024 : index
+    %block_id_x = gpu.block_id x
+    %block_id_y = gpu.block_id y
+    %m = arith.muli %block_id_x, %c32 : index
+
+    %a_tdesc = xegpu.create_nd_tdesc %A[%m, %c0] : memref<1024x1024xf16> -> !xegpu.tensor_desc<16x32xf16, #l>
+    %b_tdesc = xegpu.create_nd_tdesc %B[%m, %c0] : memref<1024x1024xf16> -> !xegpu.tensor_desc<16x32xf16, #l>
+    %c_tdesc = xegpu.create_nd_tdesc %C[%m, %c0] : memref<1024x1024xf16> -> !xegpu.tensor_desc<16x32xf16, #l>
+
+    %out:3 = scf.for %k = %c0 to %c1024 step %c32
+      iter_args(%arg0 = %a_tdesc, %arg1 = %b_tdesc, %arg2 = %c_tdesc)
+      -> (!xegpu.tensor_desc<16x32xf16, #l>, !xegpu.tensor_desc<16x32xf16, #l>, !xegpu.tensor_desc<16x32xf16, #l>) {
+      //CHECK-COUNT-8: xegpu.load_nd {{.*}}  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+      %a = xegpu.load_nd %arg0 : !xegpu.tensor_desc<16x32xf16, #l> -> vector<16x32xf16>
+      %b = xegpu.load_nd %arg1 : !xegpu.tensor_desc<16x32xf16, #l> -> vector<16x32xf16>
+
+      //CHECK-COUNT-4: arith.addf {{.*}} : vector<8x16xf16>
+      %c = arith.addf %a, %b {layout_result_0 = #l} : vector<16x32xf16>
+
+      //CHECK-COUNT-4: xegpu.store_nd {{.*}} : vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16>
+      xegpu.store_nd %c, %arg2: vector<16x32xf16>, !xegpu.tensor_desc<16x32xf16, #l>
+
+      //CHECK-COUNT-12: xegpu.update_nd_offset {{.*}} : !xegpu.tensor_desc<8x16xf16>
+      %a_next_tdesc = xegpu.update_nd_offset %arg0, [%c0, %c32] : !xegpu.tensor_desc<16x32xf16, #l>
+      %b_next_tdesc = xegpu.update_nd_offset %arg1, [%c0, %c32] : !xegpu.tensor_desc<16x32xf16, #l>
+      %c_next_tdesc = xegpu.update_nd_offset %arg2, [%c0, %c32] : !xegpu.tensor_desc<16x32xf16, #l>
+      scf.yield %a_next_tdesc, %b_next_tdesc, %c_next_tdesc
+        : !xegpu.tensor_desc<16x32xf16, #l>, !xegpu.tensor_desc<16x32xf16, #l>, !xegpu.tensor_desc<16x32xf16, #l>
+    }
+    gpu.return
+  }
+}
+
+// -----
+#l = #xegpu.layout<inst_data = [8]>
+gpu.module @test_kernel {
+  gpu.func @test_elementwise(%A: memref<1024x1024xf16>, %B: memref<1024x1024xf16>, %C: memref<1024x1024xf16>) {
+    %c0 = arith.constant 0 : index
+    %c32 = arith.constant 32 : index
+    %c1024 = arith.constant 1024 : index
+    %block_id_x = gpu.block_id x
+    %block_id_y = gpu.block_id y
+    %m = arith.muli %block_id_x, %c32 : index
+
+    %a_tdesc = xegpu.create_nd_tdesc %A[%m, %c0] : memref<1024x1024xf16> -> !xegpu.tensor_desc<32xf16, #l>
+    %b_tdesc = xegpu.create_nd_tdesc %B[%m, %c0] : memref<1024x1024xf16> -> !xegpu.tensor_desc<32xf16, #l>
+    %c_tdesc = xegpu.create_nd_tdesc %C[%m, %c0] : memref<1024x1024xf16> -> !xegpu.tensor_desc<32xf16, #l>
+
+    %out:3 = scf.for %k = %c0 to %c1024 step %c32
+      iter_args(%arg0 = %a_tdesc, %arg1 = %b_tdesc, %arg2 = %c_tdesc)
+      -> (!xegpu.tensor_desc<32xf16, #l>, !xegpu.tensor_desc<32xf16, #l>, !xegpu.tensor_desc<32xf16, #l>) {
+      //CHECK-COUNT-8: xegpu.load_nd {{.*}}  : !xegpu.tensor_desc<8xf16> -> vector<8xf16>
+      %a = xegpu.load_nd %arg0 : !xegpu.tensor_desc<32xf16, #l> -> vector<32xf16>
+      %b = xegpu.load_nd %arg1 : !xegpu.tensor_desc<32xf16, #l> -> vector<32xf16>
+
+      //CHECK-COUNT-4: arith.addf {{.*}} : vector<8xf16>
+      %c = arith.addf %a, %b {layout_result_0 = #l} : vector<32xf16>
+
+      //CHECK-COUNT-4: xegpu.store_nd {{.*}} : vector<8xf16>, !xegpu.tensor_desc<8xf16>
+      xegpu.store_nd %c, %arg2: vector<32xf16>, !xegpu.tensor_desc<32xf16, #l>
+
+      //CHECK-COUNT-12: xegpu.update_nd_offset {{.*}} : !xegpu.tensor_desc<8xf16>
+      %a_next_tdesc = xegpu.update_nd_offset %arg0, [%c32] : !xegpu.tensor_desc<32xf16, #l>
+      %b_next_tdesc = xegpu.update_nd_offset %arg1, [%c32] : !xegpu.tensor_desc<32xf16, #l>
+      %c_next_tdesc = xegpu.update_nd_offset %arg2, [%c32] : !xegpu.tensor_desc<32xf16, #l>
+      scf.yield %a_next_tdesc, %b_next_tdesc, %c_next_tdesc
+        : !xegpu.tensor_desc<32xf16, #l>, !xegpu.tensor_desc<32xf16, #l>, !xegpu.tensor_desc<32xf16, #l>
+    }
+    gpu.return
+  }
+}

>From b3ba670e96a0cdd0afcb74953764779cdcc6fb66 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Fri, 16 May 2025 17:58:57 +0000
Subject: [PATCH 22/55] Address feedback

---
 .../mlir/Dialect/XeGPU/Transforms/Passes.td   |  2 +-
 .../Transforms/XeGPUWgToSgDistribute.cpp      | 27 ++++-----
 .../test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir | 59 ++++++++++++++-----
 mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir   | 56 ++++++++++++------
 4 files changed, 95 insertions(+), 49 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
index 0be9fceb25ef1..6f585f9ceb29b 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
@@ -33,7 +33,7 @@ def XeGPUSubgroupDistribute : Pass<"xegpu-subgroup-distribute"> {
       "Print the result of the subgroup map propagation analysis and exit.">];
 }
 
-def XeGPUWgToSgDistribute : Pass<"xegpu-wg-to-sg-distribute", "::mlir::gpu::GPUModuleOp"> {
+def XeGPUWgToSgDistribute : Pass<"xegpu-wg-to-sg-distribute"> {
   let summary = "Transform WorkGroup level XeGPU code to SubGroup level";
   let description = [{
     This transform pass distributes the workgroup level computation to
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 7c5a6d362c3d1..20fc6951c9481 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -7,14 +7,15 @@
 //===----------------------------------------------------------------------===//
 #include "mlir/Dialect/XeGPU/Transforms/Passes.h"
 
+#include "mlir/Dialect/Affine/Utils.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/Index/IR/IndexDialect.h"
+#include "mlir/Dialect/Index/IR/IndexOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
 #include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
 #include "mlir/Transforms/DialectConversion.h"
-#include <mlir/Dialect/GPU/IR/GPUDialect.h>
-#include <mlir/Dialect/Index/IR/IndexOps.h>
 
 namespace mlir {
 namespace xegpu {
@@ -70,15 +71,6 @@ namespace {
 struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
   using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
 
-  // Convert linear subgroup ID to 2D coordinates
-  // TODO: Delinearize for nD
-  SmallVector<Value> delinearizeSubgroupId(ConversionPatternRewriter &rewriter,
-                                           Location loc, Value sgID,
-                                           Value sgDimX, Value sgDimY) const {
-    return {rewriter.create<index::DivUOp>(loc, sgID, sgDimY),
-            rewriter.create<index::RemUOp>(loc, sgID, sgDimY)};
-  }
-
   // Calculate offset for each subgroup
   SmallVector<OpFoldResult>
   calculateGlobalOffsets(ConversionPatternRewriter &rewriter, Location loc,
@@ -144,7 +136,8 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
 
     // TODO : Handle order attribute
     // Get the subgroup ID
-    auto linearSgId = rewriter.create<gpu::SubgroupIdOp>(loc, nullptr);
+    auto linearSgId =
+        rewriter.create<gpu::SubgroupIdOp>(loc, /*upper_bound=*/nullptr);
 
     // Create constants for layout dimensions
     SmallVector<Value> sgLayoutDim(sgLayout.size());
@@ -156,9 +149,11 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
       sgDataDim[i] = rewriter.create<arith::ConstantIndexOp>(loc, sgShape[i]);
     }
 
-    // Delinearize the 1D subgroup id into 2d
-    SmallVector<Value> sgIds = delinearizeSubgroupId(
-        rewriter, loc, linearSgId, sgLayoutDim[0], sgLayoutDim[1]);
+    auto deLinearizeSgId =
+        affine::delinearizeIndex(rewriter, loc, linearSgId, sgLayoutDim);
+    if (failed(deLinearizeSgId))
+      return failure();
+    SmallVector<Value> sgIds = *deLinearizeSgId;
 
     // Calculate distribution unit shape and local offsets for subgroup
     SmallVector<int64_t> distUnitShape(sgLayout.size());
@@ -267,9 +262,9 @@ struct WgToSgDpasOp : public OpConversionPattern<xegpu::DpasOp> {
     if (!originalLayout)
       return failure();
 
+    size_t i = 0;
     SmallVector<Value> newDpasOps;
     for (auto aVec : adaptor.getLhs()) {
-      size_t i = 0;
       for (auto bVec : adaptor.getRhs()) {
 
         llvm::SmallVector<Value> operands({aVec, bVec});
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
index 321cc0510a24c..23fdffc220ecb 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
@@ -6,7 +6,9 @@ gpu.module @test_round_robin_assignment {
   gpu.func @test_create_nd_tdesc(%src: memref<24x32xf32>) {
       // CHECK-COUNT-12: xegpu.create_nd_tdesc %[[ARG_0]][%{{.*}}, %{{.*}}] : memref<24x32xf32>
       // CHECK-SAME: -> !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
-      %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+      // CHECK-NOT: xegpu.create_nd_tdesc
+      %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32>
+        -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
       gpu.return
     }
 
@@ -17,18 +19,26 @@ gpu.module @test_round_robin_assignment {
       // CHECK-COUNT-12: xegpu.load_nd %{{.*}}
       // CHECK-SAME-COUNT-12: : !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
       // CHECK-SAME-COUNT-12: -> vector<2x2xf32>
-      %load =  xegpu.load_nd %tdesc: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>> -> vector<24x32xf32>
+      // CHECK-NOT: xegpu.load_nd
+      %load =  xegpu.load_nd %tdesc
+        : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+        -> vector<24x32xf32>
       gpu.return
     }
 
   // CHECK-LABEL: test_store_nd
   // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
   gpu.func @test_store_nd(%src: memref<24x32xf32>) {
-      %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+      %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32>
+        -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
       // CHECK-COUNT-12: xegpu.store_nd %{{.*}}, %{{.*}}
       // CHECK-SAME-COUNT-12: : vector<2x2xf32>, !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
-      %load = xegpu.load_nd %tdesc: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>> -> vector<24x32xf32>
-      xegpu.store_nd %load, %tdesc: vector<24x32xf32>, !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+      // CHECK-NOT: xegpu.store_nd
+      %load = xegpu.load_nd %tdesc
+        : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+        -> vector<24x32xf32>
+      xegpu.store_nd %load, %tdesc
+        : vector<24x32xf32>, !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
       gpu.return
   }
 
@@ -38,7 +48,9 @@ gpu.module @test_round_robin_assignment {
     %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->  !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
     // CHECK-COUNT-12: xegpu.update_nd_offset %{{.*}}, [0, 16]
     // CHECK-SAME-COUNT-12: : !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
-    %update = xegpu.update_nd_offset %tdesc, [0, 16] :  !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+    // CHECK-NOT: xegpu.update_nd_offset
+    %update = xegpu.update_nd_offset %tdesc, [0, 16]
+      : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
     gpu.return
   }
 
@@ -47,28 +59,45 @@ gpu.module @test_round_robin_assignment {
   gpu.func @test_dpas(%a: memref<8x8xf32>, %b: memref<8x8xf32>, %c: memref<8x8xf32>) {
     // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_0]][%{{.*}}, %{{.*}}] : memref<8x8xf32>
     // CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
+    // CHECK-NOT: xegpu.create_nd_tdesc
     // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_1]][%{{.*}}, %{{.*}}] : memref<8x8xf32>
     // CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
+    // CHECK-NOT: xegpu.create_nd_tdesc
     // CHECK-COUNT-4:  xegpu.create_nd_tdesc %{{.*}}[%{{.*}}, %{{.*}}] : memref<8x8xf32>
     // CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
+    // CHECK-NOT: xegpu.create_nd_tdesc
     // CHECK-COUNT-16: xegpu.dpas %{{.*}}, %{{.*}}
     // CHECK-SAME-COUNT-16: {layout = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>}
     // CHECK-SAME-COUNT-16: : vector<2x2xf32>, vector<2x2xf32> -> vector<2x2xf32>
-    %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<8x8xf32> -> !xegpu.tensor_desc<8x8xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
-    %load_a =  xegpu.load_nd %tdesc_a: !xegpu.tensor_desc<8x8xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>> -> vector<8x8xf32>
-    %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<8x8xf32> -> !xegpu.tensor_desc<8x8xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
-    %load_b =  xegpu.load_nd %tdesc_b:  !xegpu.tensor_desc<8x8xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>> -> vector<8x8xf32>
-    %tdesc_c = xegpu.create_nd_tdesc %c[0, 0] : memref<8x8xf32> -> !xegpu.tensor_desc<8x8xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
-    %dpas = xegpu.dpas %load_a, %load_b {layout =  #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>} : vector<8x8xf32>, vector<8x8xf32> -> vector<8x8xf32>
+    // CHECK-NOT: xegpu.dpas
+    %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<8x8xf32>
+      -> !xegpu.tensor_desc<8x8xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+    %load_a =  xegpu.load_nd %tdesc_a
+      : !xegpu.tensor_desc<8x8xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+      -> vector<8x8xf32>
+    %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<8x8xf32>
+      -> !xegpu.tensor_desc<8x8xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+    %load_b =  xegpu.load_nd %tdesc_b
+      : !xegpu.tensor_desc<8x8xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+      -> vector<8x8xf32>
+    %tdesc_c = xegpu.create_nd_tdesc %c[0, 0] : memref<8x8xf32>
+      -> !xegpu.tensor_desc<8x8xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+    %dpas = xegpu.dpas %load_a, %load_b
+      {layout =  #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
+      : vector<8x8xf32>, vector<8x8xf32> -> vector<8x8xf32>
     gpu.return
   }
 
   // CHECK-LABEL: test_prefetch_nd_tdesc
   // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
   gpu.func @test_prefetch_nd_tdesc(%src: memref<24x32xf32>) {
-    // CHECK-COUNT-12: xegpu.prefetch_nd %{{.*}} : !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
-    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
-    xegpu.prefetch_nd %tdesc: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+    // CHECK-COUNT-12: xegpu.prefetch_nd %{{.*}}
+    // CHECK-SAME-COUNT-12 : !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
+    // CHECK-NOT: xegpu.prefetch_nd
+    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32>
+      -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+    xegpu.prefetch_nd %tdesc
+      : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
     gpu.return
   }
 }
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
index 3bd95ee775db3..5feb0da1ddfae 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
@@ -1,5 +1,7 @@
 // RUN: mlir-opt --xegpu-wg-to-sg-distribute -split-input-file %s | FileCheck %s
 
+//CHECK: #map = affine_map<()[s0] -> (s0 floordiv 4)>
+//CHECK: #map1 = affine_map<()[s0] -> (s0 mod 4)>
 gpu.module @test_1_1_assignment {
   // CHECK-LABEL: test_create_nd_tdesc
   // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
@@ -8,8 +10,8 @@ gpu.module @test_1_1_assignment {
   // CHECK: %[[C12:.*]] = arith.constant 12 : index
   // CHECK: %[[C4:.*]] = arith.constant 4 : index
   // CHECK: %[[C8:.*]] = arith.constant 8 : index
-  // CHECK: %[[DIV:.*]] = index.divu %[[SGID]], %[[C4]]
-  // CHECK: %[[REM:.*]] = index.remu %[[SGID]], %[[C4]]
+  // CHECK: %[[DIV:.*]] = affine.apply #map()[%[[SGID]]]
+  // CHECK: %[[REM:.*]] = affine.apply #map1()[%[[SGID]]]
   // CHECK: %[[MUL1:.*]] = index.mul %[[DIV]], %[[C12]]
   // CHECK: %[[MUL2:.*]] = index.mul %[[REM]], %[[C8]]
   // CHECK: %[[C0:.*]] = arith.constant 0 : index
@@ -18,7 +20,8 @@ gpu.module @test_1_1_assignment {
   // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][%[[ADD1]], %[[ADD2]]] : memref<24x32xf32>
   // CHECK-SAME: -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
   // CHECK: gpu.return
-  %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+  %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32>
+    -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
   gpu.return
   }
 
@@ -30,8 +33,11 @@ gpu.module @test_1_1_assignment {
     // CHECK: %[[LOAD:.*]] = xegpu.load_nd %[[TDESC]]
     // CHECK-SAME: : !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
     // CHECK-SAME: -> vector<12x8xf32>
-    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
-    %load =  xegpu.load_nd %tdesc: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>> -> vector<24x32xf32>
+    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32>
+      -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+    %load =  xegpu.load_nd %tdesc
+      : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+      -> vector<24x32xf32>
     gpu.return
   }
 
@@ -45,9 +51,13 @@ gpu.module @test_1_1_assignment {
     // CHECK-SAME: -> vector<12x8xf32>
     // CHECK: xegpu.store_nd %[[LOAD]], %[[TDESC]]
     // CHECK-SAME: : vector<12x8xf32>, !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
-    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
-    %load = xegpu.load_nd %tdesc: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>> -> vector<24x32xf32>
-    xegpu.store_nd %load, %tdesc: vector<24x32xf32>, !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32>
+      -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+    %load = xegpu.load_nd %tdesc
+      : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+      -> vector<24x32xf32>
+    xegpu.store_nd %load, %tdesc
+      : vector<24x32xf32>, !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
     gpu.return
 }
 
@@ -58,8 +68,10 @@ gpu.func @test_update_nd(%src: memref<24x32xf32>){
   // CHECK-SAME: -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
   // CHECK: %[[UPDATE:.*]] = xegpu.update_nd_offset %[[TDESC]], [0, 16]
   // CHECK-SAME: : !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
-  %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
-  %update = xegpu.update_nd_offset %tdesc, [0, 16] : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+  %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32>
+    -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+  %update = xegpu.update_nd_offset %tdesc, [0, 16]
+    : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
   gpu.return
 }
 
@@ -80,11 +92,19 @@ gpu.func @test_dpas(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
     // CHECK: %[[DPAS:.*]] = xegpu.dpas %[[LOAD_A]], %[[LOAD_B]]
     // CHECK-SAME: {layout =  #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>}
     // CHECK-SAME: : vector<12x8xf32>, vector<8x12xf32> -> vector<12x12xf32>
-    %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
-    %load_a =  xegpu.load_nd %tdesc_a: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>> -> vector<24x32xf32>
-    %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<32x24xf32> -> !xegpu.tensor_desc<32x24xf32, #xegpu.layout<sg_layout = [4, 2], sg_data = [8, 12], lane_layout = [8, 2], lane_data = [1, 1]>>
-    %load_b =  xegpu.load_nd %tdesc_b: !xegpu.tensor_desc<32x24xf32, #xegpu.layout<sg_layout = [4, 2], sg_data = [8, 12], lane_layout = [8, 2], lane_data = [1, 1]>> -> vector<32x24xf32>
-    %dpas = xegpu.dpas %load_a, %load_b {layout =  #xegpu.layout<sg_layout = [2, 2], sg_data = [12, 12], lane_layout = [2, 2], lane_data = [1, 1]>} : vector<24x32xf32>, vector<32x24xf32> -> vector<24x24xf32>
+    %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32>
+      -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+    %load_a =  xegpu.load_nd %tdesc_a
+      : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+      -> vector<24x32xf32>
+    %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<32x24xf32>
+      -> !xegpu.tensor_desc<32x24xf32, #xegpu.layout<sg_layout = [4, 2], sg_data = [8, 12], lane_layout = [8, 2], lane_data = [1, 1]>>
+    %load_b =  xegpu.load_nd %tdesc_b
+      : !xegpu.tensor_desc<32x24xf32, #xegpu.layout<sg_layout = [4, 2], sg_data = [8, 12], lane_layout = [8, 2], lane_data = [1, 1]>>
+      -> vector<32x24xf32>
+    %dpas = xegpu.dpas %load_a, %load_b
+      {layout =  #xegpu.layout<sg_layout = [2, 2], sg_data = [12, 12], lane_layout = [2, 2], lane_data = [1, 1]>}
+      : vector<24x32xf32>, vector<32x24xf32> -> vector<24x24xf32>
     gpu.return
   }
 
@@ -95,8 +115,10 @@ gpu.func @test_dpas(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
     // CHECK-SAME: -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
     // CHECK: xegpu.prefetch_nd %[[TDESC]]
     // CHECK-SAME: : !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
-    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
-    xegpu.prefetch_nd %tdesc: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32>
+      -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+    xegpu.prefetch_nd %tdesc
+      : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
     gpu.return
   }
 }

>From 61b003c618e12634aa2046097d627c40065b5cf7 Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Fri, 16 May 2025 20:44:50 +0000
Subject: [PATCH 23/55] stage

---
 .../Transforms/XeGPUWgToSgDistribute.cpp      | 41 ++++++++++++++++---
 mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp   | 28 +++++++------
 2 files changed, 52 insertions(+), 17 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 7c5a6d362c3d1..1db479cb7af66 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -11,6 +11,7 @@
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
 #include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include <mlir/Dialect/GPU/IR/GPUDialect.h>
@@ -163,8 +164,11 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
     // Calculate distribution unit shape and local offsets for subgroup
     SmallVector<int64_t> distUnitShape(sgLayout.size());
     SmallVector<Value> localOffset(sgLayout.size());
+    ArrayRef<int64_t> shape = tdescTy.getShape();
     for (size_t i = 0; i < sgLayout.size(); i++) {
       distUnitShape[i] = sgLayout[i] * sgShape[i];
+      if (distUnitShape[i] > shape[i])
+        distUnitShape[i] = shape[i];
       localOffset[i] =
           rewriter.createOrFold<index::MulOp>(loc, sgIds[i], sgDataDim[i]);
     }
@@ -263,7 +267,7 @@ struct WgToSgDpasOp : public OpConversionPattern<xegpu::DpasOp> {
       return failure();
 
     auto originalLayout =
-        llvm::dyn_cast_or_null<xegpu::LayoutAttr>(op->getAttr("layout"));
+        llvm::dyn_cast_or_null<xegpu::LayoutAttr>(op->getAttr("layout_result_0"));
     if (!originalLayout)
       return failure();
 
@@ -311,14 +315,38 @@ struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
   }
 };
 
+static SmallVector<Value> flattenValues(ArrayRef<ValueRange> values) {
+  SmallVector<Value> result;
+  for (const auto &vals : values)
+    llvm::append_range(result, vals);
+  return result;
+}
+
+struct UnrealizedConversionCastOpPattern
+    : public OpConversionPattern<mlir::UnrealizedConversionCastOp> {
+  using OpConversionPattern<
+      mlir::UnrealizedConversionCastOp>::OpConversionPattern;
+
+  mlir::LogicalResult
+  matchAndRewrite(mlir::UnrealizedConversionCastOp op,
+                  OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (op.getNumOperands() == 1 && op.getNumResults() == 1) {
+      rewriter.replaceOpWithMultiple(op, flattenValues(adaptor.getInputs()));
+      return mlir::success();
+    }
+    return mlir::failure();
+  }
+};
+
 } // namespace
 
 namespace mlir {
 namespace xegpu {
 void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
   patterns.add<WgToSgCreateNdOp, WgToSgLoadNdOp, WgToSgStoreNdOp,
-               WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp>(
-      patterns.getContext());
+               WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
+               UnrealizedConversionCastOpPattern>(patterns.getContext());
 }
 } // namespace xegpu
 } // namespace mlir
@@ -350,7 +378,7 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
   };
 
   auto isLegal = [&](xegpu::LayoutAttr layout) -> bool {
-    return !layout || layout.getSgLayout() == nullptr;
+    return !layout || !layout.isWgLayout();
   };
 
   target.addDynamicallyLegalOp<xegpu::CreateNdDescOp, xegpu::LoadNdOp,
@@ -362,12 +390,15 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
   });
 
   target.addDynamicallyLegalOp<xegpu::DpasOp>([=](xegpu::DpasOp op) -> bool {
-    auto layout = dyn_cast_or_null<xegpu::LayoutAttr>(op->getAttr("layout"));
+    auto layout = dyn_cast_or_null<xegpu::LayoutAttr>(op->getAttr("layout_result_0"));
     return isLegal(layout);
   });
 
+  target.addIllegalOp<UnrealizedConversionCastOp>();
   target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
 
+  xegpu::doSCFStructuralTypeConversionWithTensorType(getOperation());
+
   xegpu::populateXeGPUWgToSgDistributePatterns(patterns);
   if (failed(
           applyPartialConversion(getOperation(), target, std::move(patterns))))
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index 14b2b909e143a..40a122f145761 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -307,6 +307,9 @@ void xegpu::doSCFStructuralTypeConversionWithTensorType(Operation *op) {
 
   { // perform the conversion from RankedTensorType to VectorType based on the
     // LayoutAttr
+    llvm::dbgs() << "\n\nDumpBefore: \n";
+    op->dump();
+    llvm::dbgs() << "\n";
     auto computeTileShapeAndCount = [&](ArrayRef<int64_t> shape,
                                         DenseI32ArrayAttr sgDataAttr,
                                         DenseI32ArrayAttr sgLayoutAttr) {
@@ -319,6 +322,10 @@ void xegpu::doSCFStructuralTypeConversionWithTensorType(Operation *op) {
       assert(tileShape.size() && "failed to compute tileShape");
       SmallVector<int64_t> distUnit =
           computeElementwiseMul(sgLayout, tileShape);
+      for(size_t i = 0; i < distUnit.size(); i++) {
+        if (distUnit[i] > shape[i])
+          distUnit[i] = shape[i];
+      }
       int count = computeProduct(shape) / computeProduct(distUnit);
       return std::make_pair(tileShape, count);
     };
@@ -328,6 +335,7 @@ void xegpu::doSCFStructuralTypeConversionWithTensorType(Operation *op) {
     converter.addConversion(
         [&](RankedTensorType type,
             SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
+          llvm::dbgs() << "\n\nConverting Type: " << type;
           ArrayRef<int64_t> shape = type.getShape();
           auto encoding = type.getEncoding();
           Type elemTy = type.getElementType();
@@ -344,13 +352,10 @@ void xegpu::doSCFStructuralTypeConversionWithTensorType(Operation *op) {
               // shape/sgLayout
               std::tie(subShape, count) = computeTileShapeAndCount(
                   shape, layout.getSgData(), layout.getSgLayout());
-            } else if (DenseI32ArrayAttr instData = layout.getInstData()) {
-              // for unrolling, the subShape is determined by inst_data
-              subShape = llvm::to_vector_of<int64_t>(instData.asArrayRef());
-              count = computeProduct(shape) / computeProduct(subShape);
             }
           }
           auto newTy = VectorType::get(subShape, elemTy);
+          llvm::dbgs() << "\n   result: " << count << ", " << newTy << "\n";
           result.append(count, newTy);
           return success();
         });
@@ -358,6 +363,7 @@ void xegpu::doSCFStructuralTypeConversionWithTensorType(Operation *op) {
     converter.addConversion(
         [&](xegpu::TensorDescType type,
             SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
+          llvm::dbgs() << "\n\nConverting Type: " << type;
           MLIRContext *ctx = type.getContext();
           Type elemTy = type.getElementType();
           Attribute encoding = type.getEncoding();
@@ -376,17 +382,12 @@ void xegpu::doSCFStructuralTypeConversionWithTensorType(Operation *op) {
               std::tie(subShape, count) = computeTileShapeAndCount(
                   shape, layout.getSgData(), layout.getSgLayout());
               layout = layout.dropSgLayoutAndData();
-            } else if (DenseI32ArrayAttr instData = layout.getInstData()) {
-              // for unrolling, the subShape is determined by inst_data
-              subShape = llvm::to_vector_of<int64_t>(instData.asArrayRef());
-              count = computeProduct(shape) / computeProduct(subShape);
-              layout = layout.dropInstData();
-            }
-
-            newTy = xegpu::TensorDescType::get(ctx, subShape, elemTy, encoding,
+              newTy = xegpu::TensorDescType::get(ctx, subShape, elemTy, encoding,
                                                layout);
+            }
           }
 
+          llvm::dbgs() << "\n   result: " << count << ", " << newTy << "\n";
           result.append(count, newTy);
           return success();
         });
@@ -449,5 +450,8 @@ void xegpu::doSCFStructuralTypeConversionWithTensorType(Operation *op) {
     scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
                                                          target);
     (void)mlir::applyPartialConversion(op, target, std::move(patterns));
+  llvm::dbgs() << "\n\nDumpAfter: \n";
+  op->dump();
+  llvm::dbgs() << "\n";
   }
 }

>From 64259613115e79ec92f8cc717a42ccc3d0a94b70 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Fri, 16 May 2025 20:38:14 +0000
Subject: [PATCH 24/55] add support for nD

---
 .../Transforms/XeGPUWgToSgDistribute.cpp      | 39 +++++++------------
 mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir   |  2 +-
 2 files changed, 16 insertions(+), 25 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 20fc6951c9481..68410f0f443f8 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -77,19 +77,8 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
                          const SmallVector<OpFoldResult> &originalOffsets,
                          const SmallVector<Value> &localOffset,
                          const SmallVector<int64_t> &distUnitBaseAddr) const {
-
-    Value constOffsetX =
-        rewriter.create<arith::ConstantIndexOp>(loc, distUnitBaseAddr[0]);
-    Value constOffsetY =
-        rewriter.create<arith::ConstantIndexOp>(loc, distUnitBaseAddr[1]);
-
-    Value offsetX =
-        rewriter.createOrFold<index::AddOp>(loc, localOffset[0], constOffsetX);
-    Value offsetY =
-        rewriter.createOrFold<index::AddOp>(loc, localOffset[1], constOffsetY);
-
-    size_t lastDimIndex = originalOffsets.size() - 1;
-    size_t secondLastDimIndex = lastDimIndex - 1;
+    assert(localOffset.size() == distUnitBaseAddr.size() &&
+           "localOffset and distUnitBaseAddr must have the same rank");
 
     // Convert originalOffsets to Value
     auto getValueFromOpFoldResult = [&](OpFoldResult ofr) -> Value {
@@ -102,18 +91,20 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
       llvm_unreachable("Unsupported OpFoldResult kind");
     };
 
-    Value origOffsetX =
-        getValueFromOpFoldResult(originalOffsets[secondLastDimIndex]);
-    Value origOffsetY = getValueFromOpFoldResult(originalOffsets[lastDimIndex]);
-    Value globalOffsetX =
-        rewriter.createOrFold<index::AddOp>(loc, origOffsetX, offsetX);
-    Value globalOffsetY =
-        rewriter.createOrFold<index::AddOp>(loc, origOffsetY, offsetY);
-
     SmallVector<OpFoldResult> globalOffsets(originalOffsets.begin(),
                                             originalOffsets.end());
-    globalOffsets[secondLastDimIndex] = globalOffsetX;
-    globalOffsets[lastDimIndex] = globalOffsetY;
+    size_t rank = localOffset.size();
+    for (size_t i = 0; i < rank; ++i) {
+      size_t dimIdx = originalOffsets.size() - rank + i;
+      Value constOffset =
+          rewriter.create<arith::ConstantIndexOp>(loc, distUnitBaseAddr[i]);
+      Value offset =
+          rewriter.createOrFold<index::AddOp>(loc, localOffset[i], constOffset);
+      Value origOffset = getValueFromOpFoldResult(originalOffsets[dimIdx]);
+      Value globalOffset =
+          rewriter.createOrFold<index::AddOp>(loc, origOffset, offset);
+      globalOffsets[dimIdx] = globalOffset;
+    }
 
     return globalOffsets;
   }
@@ -283,7 +274,7 @@ struct WgToSgDpasOp : public OpConversionPattern<xegpu::DpasOp> {
         tmpC = rewriter.create<xegpu::DpasOp>(
             loc, resTy, operands,
             llvm::ArrayRef<NamedAttribute>(
-                {"layout", originalLayout.dropSgLayoutAndData()}));
+                {"layout_result_0", originalLayout.dropSgLayoutAndData()}));
         newDpasOps.push_back(tmpC);
       }
     }
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
index 5feb0da1ddfae..5d9ddb3ef1e97 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
@@ -90,7 +90,7 @@ gpu.func @test_dpas(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
     // CHECK-SAME: : !xegpu.tensor_desc<8x12xf32, #xegpu.layout<lane_layout = [8, 2], lane_data = [1, 1]>>
     // CHECK-SAME: -> vector<8x12xf32>
     // CHECK: %[[DPAS:.*]] = xegpu.dpas %[[LOAD_A]], %[[LOAD_B]]
-    // CHECK-SAME: {layout =  #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>}
+    // CHECK-SAME: {layout_result_0 =  #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>}
     // CHECK-SAME: : vector<12x8xf32>, vector<8x12xf32> -> vector<12x12xf32>
     %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32>
       -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>

>From 387ac9310f2ed10260f80be7c7d8c73ac529695c Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Fri, 16 May 2025 22:40:43 +0000
Subject: [PATCH 25/55] refactor

---
 .../mlir/Dialect/XeGPU/Utils/XeGPUUtils.h     |  11 +-
 .../XeGPU/Transforms/XeGPUBlocking.cpp        |  59 +++++++-
 mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp   | 126 ++++--------------
 3 files changed, 88 insertions(+), 108 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
index b41da0ea6a276..44faef00a739e 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
@@ -17,6 +17,7 @@ class OpOperand;
 class OpResult;
 class OpBuilder;
 class ValueRange;
+class TypeConverter;
 
 namespace xegpu {
 class LayoutAttr;
@@ -96,10 +97,12 @@ Value createVectorWithShapeFromValues(OpBuilder &builder, Location loc,
                                       ValueRange values,
                                       ArrayRef<int64_t> shape);
 
-/// Do type conversion for SCF structural ops, e.g., scf.for. Since VectorType
-/// cannot carry the layout attribute, they are converted into RankedTensorType
-/// first, which will convert back to VectorType in the second round.
-void doSCFStructuralTypeConversionWithTensorType(Operation *op);
+/// Do type conversion for SCF structural ops, e.g., scf.for, using the SCF
+/// structural type conversion patterns. Since VectorType cannot carry the
+/// layout attribute needed to guide the XeGPU type conversion, vectors are
+/// first converted into RankedTensorType, where the layout can be attached;
+/// the upstream SCF patterns are then applied with the provided converter.
+void doSCFStructuralTypeConversionWithTensorType(Operation *op, TypeConverter converter);
 
 } // namespace xegpu
 
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
index 1587cbdfed2cc..d0adb860abca7 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Interfaces/LoopLikeInterface.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 namespace mlir {
@@ -207,7 +208,63 @@ void XeGPUBlockingPass::runOnOperation() {
   xegpu::setLayoutAttrs(mod, [&](Value v) { return xegpu::getLayoutAttr(v); });
 
   // Perform type conversion for SCF control folow ops
-  xegpu::doSCFStructuralTypeConversionWithTensorType(mod);
+  TypeConverter converter;
+  converter.addConversion([&](Type type) -> Type { return type; });
+  converter.addConversion(
+      [&](RankedTensorType type,
+          SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
+        Type elemTy = type.getElementType();
+        ArrayRef<int64_t> shape = type.getShape();
+
+        // init count and subShape to the default value. If the LayoutAttr
+        // is not present, it will return a VectorType with original shape.
+        int count = 1;
+        SmallVector<int64_t> subShape(shape);
+        if (auto layout = llvm::dyn_cast_if_present<xegpu::LayoutAttr>(type.getEncoding())) {
+          if (layout.isWgLayout())
+            return failure();
+          if (DenseI32ArrayAttr instData = layout.getInstData()) {
+            // for unrolling, the subShape is determined by inst_data
+            subShape = llvm::to_vector_of<int64_t>(instData.asArrayRef());
+            count = computeProduct(shape) / computeProduct(subShape);
+          }
+        }
+        auto newTy = VectorType::get(subShape, elemTy);
+        result.append(count, newTy);
+        return success();
+      });
+
+  converter.addConversion(
+      [&](xegpu::TensorDescType type,
+          SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
+        MLIRContext *ctx = type.getContext();
+        Type elemTy = type.getElementType();
+        Attribute encoding = type.getEncoding();
+        ArrayRef<int64_t> shape = type.getShape();
+
+        // init count and newTy to the default value. If the layout attribute
+        // is not present, it will return the original type.
+        int count = 1;
+        SmallVector<int64_t> subShape(shape);
+
+        xegpu::LayoutAttr layout = type.getLayoutAttr();
+
+        if (layout) {
+          if (layout.isWgLayout())
+            return failure();
+
+          if (DenseI32ArrayAttr instData = layout.getInstData()) {
+            // for unrolling, the subShape is determined by inst_data
+            subShape = llvm::to_vector_of<int64_t>(instData.asArrayRef());
+            count = computeProduct(shape) / computeProduct(subShape);
+            layout = layout.dropInstData();
+          }
+        }
+        auto newTy = xegpu::TensorDescType::get(ctx, subShape, elemTy, encoding, layout);
+        result.append(count, newTy);
+        return success();
+      });
+  xegpu::doSCFStructuralTypeConversionWithTensorType(mod, converter);
 
   xegpu::UnrollOptions options;
   options.setFilterConstraint([&](Operation *op) -> LogicalResult {
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index 14b2b909e143a..ed7d2eeb6807b 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -225,7 +225,7 @@ Value xegpu::createVectorWithShapeFromValues(OpBuilder &builder, Location loc,
   return result;
 }
 
-void xegpu::doSCFStructuralTypeConversionWithTensorType(Operation *op) {
+void xegpu::doSCFStructuralTypeConversionWithTensorType(Operation *op, TypeConverter converter) {
   MLIRContext *context = op->getContext();
 
   auto materializeCast = [&](OpBuilder &builder, Type type, ValueRange inputs,
@@ -307,109 +307,11 @@ void xegpu::doSCFStructuralTypeConversionWithTensorType(Operation *op) {
 
   { // perform the conversion from RankedTensorType to VectorType based on the
     // LayoutAttr
-    auto computeTileShapeAndCount = [&](ArrayRef<int64_t> shape,
-                                        DenseI32ArrayAttr sgDataAttr,
-                                        DenseI32ArrayAttr sgLayoutAttr) {
-      SmallVector<int64_t> tileShape;
-      auto sgLayout = llvm::to_vector_of<int64_t>(sgLayoutAttr.asArrayRef());
-      if (sgDataAttr)
-        tileShape = llvm::to_vector_of<int64_t>(sgDataAttr.asArrayRef());
-      else
-        tileShape = computeShapeRatio(shape, sgLayout).value_or(tileShape);
-      assert(tileShape.size() && "failed to compute tileShape");
-      SmallVector<int64_t> distUnit =
-          computeElementwiseMul(sgLayout, tileShape);
-      int count = computeProduct(shape) / computeProduct(distUnit);
-      return std::make_pair(tileShape, count);
-    };
-
-    TypeConverter converter;
-    converter.addConversion([&](Type type) -> Type { return type; });
-    converter.addConversion(
-        [&](RankedTensorType type,
-            SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
-          ArrayRef<int64_t> shape = type.getShape();
-          auto encoding = type.getEncoding();
-          Type elemTy = type.getElementType();
-
-          // init count and subShape to the default value. If the LayoutAttr
-          // is not present, it will return a VectorType with original shape.
-          int count = 1;
-          SmallVector<int64_t> subShape(shape);
-
-          if (auto layout =
-                  llvm::dyn_cast_if_present<xegpu::LayoutAttr>(encoding)) {
-            if (layout.isWgLayout()) {
-              // for WgToSg, the subShape is either from sgData or computed as
-              // shape/sgLayout
-              std::tie(subShape, count) = computeTileShapeAndCount(
-                  shape, layout.getSgData(), layout.getSgLayout());
-            } else if (DenseI32ArrayAttr instData = layout.getInstData()) {
-              // for unrolling, the subShape is determined by inst_data
-              subShape = llvm::to_vector_of<int64_t>(instData.asArrayRef());
-              count = computeProduct(shape) / computeProduct(subShape);
-            }
-          }
-          auto newTy = VectorType::get(subShape, elemTy);
-          result.append(count, newTy);
-          return success();
-        });
-
-    converter.addConversion(
-        [&](xegpu::TensorDescType type,
-            SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
-          MLIRContext *ctx = type.getContext();
-          Type elemTy = type.getElementType();
-          Attribute encoding = type.getEncoding();
-          ArrayRef<int64_t> shape = type.getShape();
-
-          // init count and newTy to the default value. If the layout attribute
-          // is not present, it will return the original type.
-          int count = 1;
-          Type newTy = type;
-
-          if (xegpu::LayoutAttr layout = type.getLayoutAttr()) {
-            SmallVector<int64_t> subShape(shape);
-            if (layout.isWgLayout()) {
-              // for WgToSg, the subShape is either from sgData or computed as
-              // shape/sgLayout
-              std::tie(subShape, count) = computeTileShapeAndCount(
-                  shape, layout.getSgData(), layout.getSgLayout());
-              layout = layout.dropSgLayoutAndData();
-            } else if (DenseI32ArrayAttr instData = layout.getInstData()) {
-              // for unrolling, the subShape is determined by inst_data
-              subShape = llvm::to_vector_of<int64_t>(instData.asArrayRef());
-              count = computeProduct(shape) / computeProduct(subShape);
-              layout = layout.dropInstData();
-            }
-
-            newTy = xegpu::TensorDescType::get(ctx, subShape, elemTy, encoding,
-                                               layout);
-          }
-
-          result.append(count, newTy);
-          return success();
-        });
-
-    converter.addSourceMaterialization(materializeCast);
-    converter.addTargetMaterialization([&](OpBuilder &builder, TypeRange type,
-                                           ValueRange inputs, Location loc) {
-      return builder.create<UnrealizedConversionCastOp>(loc, type, inputs)
-          .getResults();
-    });
-
-    mlir::ConversionTarget target(*context);
-    target.addDynamicallyLegalOp<UnrealizedConversionCastOp>(
-        [&](UnrealizedConversionCastOp op) {
-          auto isTensorTy = [&](Type type) {
-            return isa<RankedTensorType>(type);
-          };
-          if (llvm::any_of(op->getOperandTypes(), isTensorTy) ||
-              llvm::any_of(op->getResultTypes(), isTensorTy))
-            return false;
-          return true;
-        });
 
+    // Handle the UnrealizedConversionCastOp introduced by the first step.
+    // For vector->RankedTensorType, it will simply forward the inputs.
+    // For RankedTensorType->vector, it will update the inputs with the
+    // one from the adaptor.
     class UnrealizedConversionCastOpPattern
         : public OpConversionPattern<mlir::UnrealizedConversionCastOp> {
       using OpConversionPattern<
@@ -444,6 +346,24 @@ void xegpu::doSCFStructuralTypeConversionWithTensorType(Operation *op) {
       }
     };
 
+    converter.addSourceMaterialization(materializeCast);
+    converter.addTargetMaterialization([&](OpBuilder &builder, TypeRange type,
+                                           ValueRange inputs, Location loc) {
+      return builder.create<UnrealizedConversionCastOp>(loc, type, inputs)
+          .getResults();
+    });
+
+    mlir::ConversionTarget target(*context);
+    target.addDynamicallyLegalOp<UnrealizedConversionCastOp>(
+        [&](UnrealizedConversionCastOp op) {
+          auto isTensorTy = [&](Type type) {
+            return isa<RankedTensorType>(type);
+          };
+          if (llvm::any_of(op->getOperandTypes(), isTensorTy) ||
+              llvm::any_of(op->getResultTypes(), isTensorTy))
+            return false;
+          return true;
+        });
     mlir::RewritePatternSet patterns(context);
     patterns.insert<UnrealizedConversionCastOpPattern>(context);
     scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,

>From ebd78aedf4859179b417056a0c7f9bfcf5ab2968 Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Fri, 16 May 2025 23:27:56 +0000
Subject: [PATCH 26/55] fix naming issue

---
 mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
index d0adb860abca7..4b6a03c8716c0 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
@@ -1,4 +1,4 @@
-//===---- XeGPUBlocking.cpp ---- XeGPU Instructionlize Pass ---------------===//
+//===---- XeGPUBlocking.cpp ---- XeGPU Blocking Pass ----------------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -242,8 +242,8 @@ void XeGPUBlockingPass::runOnOperation() {
         Attribute encoding = type.getEncoding();
         ArrayRef<int64_t> shape = type.getShape();
 
-        // init count and newTy to the default value. If the layout attribute
-        // is not present, it will return the original type.
+        // init count and newTy to the default value. If the layout
+        // attribute is not present, it will return the original type.
         int count = 1;
         SmallVector<int64_t> subShape(shape);
 

>From bbf4796df3f0e80dbaeeac380ab998bbb5cdf76e Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Fri, 16 May 2025 23:28:33 +0000
Subject: [PATCH 27/55] fix format

---
 mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h | 14 ++++++++------
 .../lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp |  6 ++++--
 mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp        |  3 ++-
 3 files changed, 14 insertions(+), 9 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
index 44faef00a739e..b8e5fe5cbde32 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
@@ -97,12 +97,14 @@ Value createVectorWithShapeFromValues(OpBuilder &builder, Location loc,
                                       ValueRange values,
                                       ArrayRef<int64_t> shape);
 
-/// Do type conversion for SCF structural ops, e.g., scf.for using SCF structure type
-/// convertion patterns. Since VectorType cannot carry the layout attribute, which is
-/// needed to guide the type conversion for XeGPU, they are first converted into
-/// RankedTensorType, where the layout attribute can be attached. And then upstream
-/// SCF structural type conversion patterns are applied with the provided converter.
-void doSCFStructuralTypeConversionWithTensorType(Operation *op, TypeConverter converter);
+/// Do type conversion for SCF structural ops, e.g., scf.for using SCF structure
+/// type convertion patterns. Since VectorType cannot carry the layout
+/// attribute, which is needed to guide the type conversion for XeGPU, they are
+/// first converted into RankedTensorType, where the layout attribute can be
+/// attached. And then upstream SCF structural type conversion patterns are
+/// applied with the provided converter.
+void doSCFStructuralTypeConversionWithTensorType(Operation *op,
+                                                 TypeConverter converter);
 
 } // namespace xegpu
 
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
index 4b6a03c8716c0..19ff4bf992b07 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
@@ -220,7 +220,8 @@ void XeGPUBlockingPass::runOnOperation() {
         // is not present, it will return a VectorType with original shape.
         int count = 1;
         SmallVector<int64_t> subShape(shape);
-        if (auto layout = llvm::dyn_cast_if_present<xegpu::LayoutAttr>(type.getEncoding())) {
+        if (auto layout = llvm::dyn_cast_if_present<xegpu::LayoutAttr>(
+                type.getEncoding())) {
           if (layout.isWgLayout())
             return failure();
           if (DenseI32ArrayAttr instData = layout.getInstData()) {
@@ -260,7 +261,8 @@ void XeGPUBlockingPass::runOnOperation() {
             layout = layout.dropInstData();
           }
         }
-        auto newTy = xegpu::TensorDescType::get(ctx, subShape, elemTy, encoding, layout);
+        auto newTy =
+            xegpu::TensorDescType::get(ctx, subShape, elemTy, encoding, layout);
         result.append(count, newTy);
         return success();
       });
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index ed7d2eeb6807b..5e0e83ef2eed5 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -225,7 +225,8 @@ Value xegpu::createVectorWithShapeFromValues(OpBuilder &builder, Location loc,
   return result;
 }
 
-void xegpu::doSCFStructuralTypeConversionWithTensorType(Operation *op, TypeConverter converter) {
+void xegpu::doSCFStructuralTypeConversionWithTensorType(
+    Operation *op, TypeConverter converter) {
   MLIRContext *context = op->getContext();
 
   auto materializeCast = [&](OpBuilder &builder, Type type, ValueRange inputs,

>From 3807eeaf672c17b77b2b2fe8733709aab3f52842 Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Mon, 19 May 2025 16:06:03 +0000
Subject: [PATCH 28/55] fix overflow

---
 mlir/lib/Dialect/Utils/IndexingUtils.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Utils/IndexingUtils.cpp b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
index d9edabef6693d..8de77e2c3cb08 100644
--- a/mlir/lib/Dialect/Utils/IndexingUtils.cpp
+++ b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
@@ -24,7 +24,7 @@ SmallVector<ExprType> computeSuffixProductImpl(ArrayRef<ExprType> sizes,
   if (sizes.empty())
     return {};
   SmallVector<ExprType> strides(sizes.size(), unit);
-  for (int64_t r = strides.size() - 2; r >= 0; --r)
+  for (int64_t r = static_cast<int64_t>(strides.size()) - 2; r >= 0; --r)
     strides[r] = strides[r + 1] * sizes[r + 1];
   return strides;
 }

>From 848e5a6fe8042522034605651db36666bb6375c8 Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Mon, 19 May 2025 17:40:42 +0000
Subject: [PATCH 29/55] trunc disUnitShape

---
 mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 68410f0f443f8..79b08f065ed2c 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -150,7 +150,7 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
     SmallVector<int64_t> distUnitShape(sgLayout.size());
     SmallVector<Value> localOffset(sgLayout.size());
     for (size_t i = 0; i < sgLayout.size(); i++) {
-      distUnitShape[i] = sgLayout[i] * sgShape[i];
+      distUnitShape[i] = std::min(sgLayout[i] * sgShape[i], wgShape[i]);
       localOffset[i] =
           rewriter.createOrFold<index::MulOp>(loc, sgIds[i], sgDataDim[i]);
     }

>From 012e44bb5002774dc49c408855bd32098d862afa Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Mon, 19 May 2025 19:31:40 +0000
Subject: [PATCH 30/55] add scf control op

---
 .../Transforms/XeGPUWgToSgDistribute.cpp      | 69 ++++++++++++++++++-
 1 file changed, 66 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 90a7ff0308b68..2b60be959dedb 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -29,6 +29,28 @@ using namespace mlir;
 
 namespace {
 
+static std::pair<SmallVector<int64_t>, int> computeTileShapeAndCount(ArrayRef<int64_t> shape, xegpu::LayoutAttr layout) {
+  // init count and subShape to the default value. If the LayoutAttr
+  // is not present, it will return a VectorType with original shape.
+  int count = 1;
+  SmallVector<int64_t> tileShape(shape);
+
+  if (layout) {
+    if (DenseI32ArrayAttr sgLayoutAttr = layout.getSgLayout()) {
+      auto sgLayout = llvm::to_vector_of<int64_t>(sgLayoutAttr.asArrayRef());
+      if (DenseI32ArrayAttr sgDataAttr = layout.getSgData())
+        tileShape = llvm::to_vector_of<int64_t>(sgDataAttr.asArrayRef());
+      else
+        tileShape = computeShapeRatio(shape, sgLayout).value_or(tileShape);
+      SmallVector<int64_t> distUnit = computeElementwiseMul(sgLayout, tileShape);
+      for (size_t i = 0; i < distUnit.size(); ++i)
+        distUnit[i] = std::min(shape[i], distUnit[i]);
+      count = computeProduct(shape) / computeProduct(distUnit);
+    }
+  }
+  return std::make_pair(tileShape, count);
+}
+
 /// This pattern transforms the CreateNdDescOp to create a subgroup descriptor
 /// from a workgroup descriptor. It replaces the offsets and sizes with
 /// appropriate values for the subgroup.
@@ -141,7 +163,7 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
       sgDataDim[i] = rewriter.create<arith::ConstantIndexOp>(loc, sgShape[i]);
     }
 
-    auto deLinearizeSgId =
+    FailureOr<SmallVector<Value>> deLinearizeSgId =
         affine::delinearizeIndex(rewriter, loc, linearSgId, sgLayoutDim);
     if (failed(deLinearizeSgId))
       return failure();
@@ -150,7 +172,6 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
     // Calculate distribution unit shape and local offsets for subgroup
     SmallVector<int64_t> distUnitShape(sgLayout.size());
     SmallVector<Value> localOffset(sgLayout.size());
-    ArrayRef<int64_t> shape = tdescTy.getShape();
     for (size_t i = 0; i < sgLayout.size(); i++) {
       distUnitShape[i] = std::min(sgLayout[i] * sgShape[i], wgShape[i]);
       localOffset[i] =
@@ -316,6 +337,8 @@ struct UnrealizedConversionCastOpPattern
                   OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     if (op.getNumOperands() == 1 && op.getNumResults() == 1) {
+      llvm::dbgs() << "\n\nUnrealizedConversionCastOp: " << op
+                   << "\n is replaced with: " << flattenValues(adaptor.getInputs())[0] << "\n";
       rewriter.replaceOpWithMultiple(op, flattenValues(adaptor.getInputs()));
       return mlir::success();
     }
@@ -379,9 +402,49 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
   });
 
   target.addIllegalOp<UnrealizedConversionCastOp>();
+
   target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
 
-  xegpu::doSCFStructuralTypeConversionWithTensorType(getOperation());
+  TypeConverter converter;
+  converter.addConversion([&](Type type) -> Type { return type; });
+  converter.addConversion(
+      [&](RankedTensorType type,
+          SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
+        Type elemTy = type.getElementType();
+        ArrayRef<int64_t> shape = type.getShape();
+
+        int count;
+        SmallVector<int64_t> subShape;
+        std::tie(subShape, count) =
+            computeTileShapeAndCount(shape, dyn_cast<xegpu::LayoutAttr>(type.getEncoding()));
+
+        auto newTy = VectorType::get(subShape, elemTy);
+        result.append(count, newTy);
+        return success();
+      });
+
+  converter.addConversion(
+      [&](xegpu::TensorDescType type,
+          SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
+        Type elemTy = type.getElementType();
+        ArrayRef<int64_t> shape = type.getShape();
+
+        // init count and newTy to the default value. If the layout
+        // attribute is not present, it will return the original type.
+        int count;
+        SmallVector<int64_t> subShape;
+        xegpu::LayoutAttr layout = type.getLayoutAttr();
+        std::tie(subShape, count) = computeTileShapeAndCount(shape, layout);
+
+        if (layout)
+          layout = layout.dropSgLayoutAndData();
+
+        auto newTy = xegpu::TensorDescType::get(type.getContext(), subShape, elemTy, type.getEncoding(), layout);
+        result.append(count, newTy);
+        return success();
+      });
+
+  xegpu::doSCFStructuralTypeConversionWithTensorType(getOperation(), converter);
 
   xegpu::populateXeGPUWgToSgDistributePatterns(patterns);
   if (failed(

>From c4d5183a5faadefba46f4bb5be6f7ad92d5b89af Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Mon, 19 May 2025 21:14:40 +0000
Subject: [PATCH 31/55] fix format

---
 .../mlir/Dialect/XeGPU/Utils/XeGPUUtils.h     |  3 ++
 .../Transforms/XeGPUWgToSgDistribute.cpp      | 40 +++++++------------
 mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp   | 18 ++++-----
 3 files changed, 25 insertions(+), 36 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
index b8e5fe5cbde32..4b3de2ddc600d 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
@@ -79,6 +79,9 @@ void setLayoutAttr(OpOperand &opr, LayoutAttr layout);
 /// Set the LayoutAttr for the given OpResult by attching it to the defining op
 void setLayoutAttr(OpResult result, LayoutAttr layout);
 
+/// Flatten a set of ValueRange into a single SmallVector<Value>
+SmallVector<Value> flattenValues(ArrayRef<ValueRange> values);
+
 /// Set the LayoutAttr for each OpOperand and OpResult of the given operation.
 /// If the operation contains regions, it is also applied recursively to the
 /// contained operations
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 2b60be959dedb..85b1ba78cb9e5 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -14,8 +14,8 @@
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
-#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
 #include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
+#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
 #include "mlir/Transforms/DialectConversion.h"
 
 namespace mlir {
@@ -29,7 +29,8 @@ using namespace mlir;
 
 namespace {
 
-static std::pair<SmallVector<int64_t>, int> computeTileShapeAndCount(ArrayRef<int64_t> shape, xegpu::LayoutAttr layout) {
+static std::pair<SmallVector<int64_t>, int>
+computeTileShapeAndCount(ArrayRef<int64_t> shape, xegpu::LayoutAttr layout) {
   // init count and subShape to the default value. If the LayoutAttr
   // is not present, it will return a VectorType with original shape.
   int count = 1;
@@ -42,7 +43,8 @@ static std::pair<SmallVector<int64_t>, int> computeTileShapeAndCount(ArrayRef<in
         tileShape = llvm::to_vector_of<int64_t>(sgDataAttr.asArrayRef());
       else
         tileShape = computeShapeRatio(shape, sgLayout).value_or(tileShape);
-      SmallVector<int64_t> distUnit = computeElementwiseMul(sgLayout, tileShape);
+      SmallVector<int64_t> distUnit =
+          computeElementwiseMul(sgLayout, tileShape);
       for (size_t i = 0; i < distUnit.size(); ++i)
         distUnit[i] = std::min(shape[i], distUnit[i]);
       count = computeProduct(shape) / computeProduct(distUnit);
@@ -271,8 +273,7 @@ struct WgToSgDpasOp : public OpConversionPattern<xegpu::DpasOp> {
     if (resultTy.getRank() != 2)
       return failure();
 
-    auto originalLayout =
-        llvm::dyn_cast_or_null<xegpu::LayoutAttr>(op->getAttr("layout_result_0"));
+    auto originalLayout = xegpu::getLayoutAttr(op.getResult());
     if (!originalLayout)
       return failure();
 
@@ -294,10 +295,8 @@ struct WgToSgDpasOp : public OpConversionPattern<xegpu::DpasOp> {
             llvm::cast<VectorType>(bVec.getType()).getShape();
         VectorType resTy = VectorType::get({aVecShape[0], bVecShape[1]},
                                            resultTy.getElementType());
-        tmpC = rewriter.create<xegpu::DpasOp>(
-            loc, resTy, operands,
-            llvm::ArrayRef<NamedAttribute>(
-                {"layout_result_0", originalLayout.dropSgLayoutAndData()}));
+        tmpC = rewriter.create<xegpu::DpasOp>(loc, resTy, operands);
+        xegpu::setLayoutAttr(cast<OpResult>(tmpC), originalLayout.dropSgLayoutAndData());
         newDpasOps.push_back(tmpC);
       }
     }
@@ -320,26 +319,16 @@ struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
   }
 };
 
-static SmallVector<Value> flattenValues(ArrayRef<ValueRange> values) {
-  SmallVector<Value> result;
-  for (const auto &vals : values)
-    llvm::append_range(result, vals);
-  return result;
-}
-
 struct UnrealizedConversionCastOpPattern
     : public OpConversionPattern<mlir::UnrealizedConversionCastOp> {
   using OpConversionPattern<
       mlir::UnrealizedConversionCastOp>::OpConversionPattern;
 
   mlir::LogicalResult
-  matchAndRewrite(mlir::UnrealizedConversionCastOp op,
-                  OneToNOpAdaptor adaptor,
+  matchAndRewrite(mlir::UnrealizedConversionCastOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     if (op.getNumOperands() == 1 && op.getNumResults() == 1) {
-      llvm::dbgs() << "\n\nUnrealizedConversionCastOp: " << op
-                   << "\n is replaced with: " << flattenValues(adaptor.getInputs())[0] << "\n";
-      rewriter.replaceOpWithMultiple(op, flattenValues(adaptor.getInputs()));
+      rewriter.replaceOpWithMultiple(op, xegpu::flattenValues(adaptor.getInputs()));
       return mlir::success();
     }
     return mlir::failure();
@@ -397,7 +386,7 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
   });
 
   target.addDynamicallyLegalOp<xegpu::DpasOp>([=](xegpu::DpasOp op) -> bool {
-    auto layout = dyn_cast_or_null<xegpu::LayoutAttr>(op->getAttr("layout_result_0"));
+    auto layout = xegpu::getLayoutAttr(op.getResult());
     return isLegal(layout);
   });
 
@@ -415,8 +404,8 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
 
         int count;
         SmallVector<int64_t> subShape;
-        std::tie(subShape, count) =
-            computeTileShapeAndCount(shape, dyn_cast<xegpu::LayoutAttr>(type.getEncoding()));
+        std::tie(subShape, count) = computeTileShapeAndCount(
+            shape, dyn_cast<xegpu::LayoutAttr>(type.getEncoding()));
 
         auto newTy = VectorType::get(subShape, elemTy);
         result.append(count, newTy);
@@ -439,7 +428,8 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
         if (layout)
           layout = layout.dropSgLayoutAndData();
 
-        auto newTy = xegpu::TensorDescType::get(type.getContext(), subShape, elemTy, type.getEncoding(), layout);
+        auto newTy = xegpu::TensorDescType::get(
+            type.getContext(), subShape, elemTy, type.getEncoding(), layout);
         result.append(count, newTy);
         return success();
       });
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index 4724b14f37bdb..c0e20ef20c593 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -26,14 +26,6 @@
 
 using namespace mlir;
 
-/// convert ArrayRef<ValueRange> into SmallVector<Value>
-static SmallVector<Value> flattenValues(ArrayRef<ValueRange> values) {
-  SmallVector<Value> result;
-  for (const auto &vals : values)
-    llvm::append_range(result, vals);
-  return result;
-}
-
 FailureOr<VectorType>
 mlir::xegpu::getDistributedVectorType(xegpu::TensorDescType tdescTy) {
   auto layout = llvm::dyn_cast_if_present<LayoutAttr>(tdescTy.getLayout());
@@ -179,6 +171,13 @@ void xegpu::setLayoutAttrs(Operation *mod,
   });
 }
 
+SmallVector<Value> xegpu::flattenValues(ArrayRef<ValueRange> values) {
+  SmallVector<Value> result;
+  for (const auto &vals : values)
+    llvm::append_range(result, vals);
+  return result;
+}
+
 SmallVector<Value>
 xegpu::extractVectorsWithShapeFromValue(OpBuilder &builder, Location loc,
                                         Value value, ArrayRef<int64_t> shape) {
@@ -370,8 +369,5 @@ void xegpu::doSCFStructuralTypeConversionWithTensorType(
     scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
                                                          target);
     (void)mlir::applyPartialConversion(op, target, std::move(patterns));
-  llvm::dbgs() << "\n\nDumpAfter: \n";
-  op->dump();
-  llvm::dbgs() << "\n";
   }
 }

>From c6695d99ab557c97269406ffe0a77d0feeb99b2b Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Mon, 19 May 2025 21:15:56 +0000
Subject: [PATCH 32/55] add comments

---
 mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td | 2 +-
 mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h   | 2 ++
 mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp  | 7 ++++++-
 3 files changed, 9 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
index b3883605b74f2..7baa880c6ff08 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
@@ -39,7 +39,7 @@ def XeGPUSubgroupDistribute : Pass<"xegpu-subgroup-distribute"> {
 }
 
 def XeGPUBlocking: Pass<"xegpu-blocking"> {
-  let summary = "Instructionlize XeGPU ops";
+  let summary = "Block XeGPU ops into smaller size.";
   let description = [{
     The pass unrolls XeGPU ops working on large shapes into ops working on small shapes
     (given by the inst_data in the layout attr), such that each of them can be dispatch
diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
index b8e5fe5cbde32..4077de593b109 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
@@ -103,6 +103,8 @@ Value createVectorWithShapeFromValues(OpBuilder &builder, Location loc,
 /// first converted into RankedTensorType, where the layout attribute can be
 /// attached. And then upstream SCF structural type conversion patterns are
 /// applied with the provided converter.
+/// TODO: This is a temporary solution. We should refactor it when context-aware
+/// type conversion is available.
 void doSCFStructuralTypeConversionWithTensorType(Operation *op,
                                                  TypeConverter converter);
 
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
index 19ff4bf992b07..778ab0476b312 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
@@ -33,7 +33,12 @@ using namespace mlir;
 
 namespace {
 
-void resolveUnrealizedConversionCastOp(UnrealizedConversionCastOp castOp) {
+// reslove the unrealized conversion cast ops generated when doing SCF
+// Structural Type Conversion. It will have two formats, N:1 vector
+// cast and 1:N vector cast. vector::insert_strided_slice ops will be
+// used for the first case, and vector::extract_strided_slice ops will be
+// used for the second case.
+static void resolveUnrealizedConversionCastOp(UnrealizedConversionCastOp castOp) {
   ValueRange inputs = castOp.getInputs();
   ValueRange outputs = castOp.getOutputs();
 

>From 50e33ff069acc9e706f51ed814e1bc9961161f75 Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Tue, 20 May 2025 14:19:55 +0000
Subject: [PATCH 33/55] add dbg log

---
 mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
index 778ab0476b312..6ac66ce7e6988 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
@@ -28,6 +28,7 @@ namespace xegpu {
 
 #define DEBUG_TYPE "xegpu-blocking"
 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
 
 using namespace mlir;
 
@@ -121,6 +122,7 @@ XeGPUBlockingPass::getTileShape(OpOperand &operand) const {
     if (auto type = dyn_cast<ShapedType>(operand.get().getType()))
       return llvm::to_vector(type.getShape());
   }
+  LDBG("failed to getTileShape for operand: " << operand.get());
   return std::nullopt;
 }
 
@@ -134,6 +136,7 @@ XeGPUBlockingPass::getTileShape(OpResult result) const {
     if (auto type = dyn_cast<ShapedType>(result.getType()))
       return llvm::to_vector(type.getShape());
   }
+  LDBG("failed to getTileShape for result: " << result);
   return std::nullopt;
 }
 

>From ae22f2796b3da2267c1be06a9fdffc7466c92027 Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Tue, 20 May 2025 14:20:29 +0000
Subject: [PATCH 34/55] fix format

---
 mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
index 6ac66ce7e6988..5bde40449b926 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
@@ -39,7 +39,8 @@ namespace {
 // cast and 1:N vector cast. vector::insert_strided_slice ops will be
 // used for the first case, and vector::extract_strided_slice ops will be
 // used for the second case.
-static void resolveUnrealizedConversionCastOp(UnrealizedConversionCastOp castOp) {
+static void
+resolveUnrealizedConversionCastOp(UnrealizedConversionCastOp castOp) {
   ValueRange inputs = castOp.getInputs();
   ValueRange outputs = castOp.getOutputs();
 

>From 977685060a9b2ca8df3b648c49ce946609e571d8 Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Tue, 20 May 2025 14:29:13 +0000
Subject: [PATCH 35/55] cleanup

---
 mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
index 5bde40449b926..b4ff5856b0b6c 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
@@ -188,20 +188,20 @@ bool XeGPUBlockingPass::needsUnroll(Operation *op) const {
   for (auto &opr : op->getOpOperands()) {
     std::optional<SmallVector<int64_t>> tileShape = getTileShape(opr);
     auto shapedType = dyn_cast<ShapedType>(opr.get().getType());
-    if (!shapedType)
+    if (!shapedType || !tileShape)
       continue;
 
-    if (tileShape && !llvm::equal(*tileShape, shapedType.getShape()))
+    if (!llvm::equal(*tileShape, shapedType.getShape()))
       return true;
   }
 
   for (auto result : op->getOpResults()) {
     std::optional<SmallVector<int64_t>> tileShape = getTileShape(result);
     auto shapedType = dyn_cast<ShapedType>(result.getType());
-    if (!shapedType)
+    if (!shapedType || !tileShape)
       continue;
 
-    if (tileShape && !llvm::equal(*tileShape, shapedType.getShape()))
+    if (!llvm::equal(*tileShape, shapedType.getShape()))
       return true;
   }
   return false;

>From 6cffa443d1c11197106d076e21da9fa973592fe8 Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Tue, 20 May 2025 15:42:06 +0000
Subject: [PATCH 36/55] refactor

---
 .../XeGPU/Transforms/XeGPUBlocking.cpp        | 67 +++++++++----------
 1 file changed, 32 insertions(+), 35 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
index b4ff5856b0b6c..9c839f0c056f8 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
@@ -216,6 +216,18 @@ void XeGPUBlockingPass::runOnOperation() {
   // operation is replaced.
   xegpu::setLayoutAttrs(mod, [&](Value v) { return xegpu::getLayoutAttr(v); });
 
+  auto getTileShapeAndCount = [](llvm::ArrayRef<int64_t> shape,
+                                 xegpu::LayoutAttr layout) {
+    int count = 1;
+    SmallVector<int64_t> tileShape(shape);
+    if (layout && layout.getInstData()) {
+      DenseI32ArrayAttr instData = layout.getInstData();
+      tileShape = llvm::to_vector_of<int64_t>(instData.asArrayRef());
+      count = computeProduct(shape) / computeProduct(tileShape);
+    }
+    return std::make_pair(tileShape, count);
+  };
+
   // Perform type conversion for SCF control folow ops
   TypeConverter converter;
   converter.addConversion([&](Type type) -> Type { return type; });
@@ -225,56 +237,41 @@ void XeGPUBlockingPass::runOnOperation() {
         Type elemTy = type.getElementType();
         ArrayRef<int64_t> shape = type.getShape();
 
-        // init count and subShape to the default value. If the LayoutAttr
-        // is not present, it will return a VectorType with original shape.
-        int count = 1;
-        SmallVector<int64_t> subShape(shape);
-        if (auto layout = llvm::dyn_cast_if_present<xegpu::LayoutAttr>(
-                type.getEncoding())) {
-          if (layout.isWgLayout())
-            return failure();
-          if (DenseI32ArrayAttr instData = layout.getInstData()) {
-            // for unrolling, the subShape is determined by inst_data
-            subShape = llvm::to_vector_of<int64_t>(instData.asArrayRef());
-            count = computeProduct(shape) / computeProduct(subShape);
-          }
-        }
+        auto layout =
+            llvm::dyn_cast_if_present<xegpu::LayoutAttr>(type.getEncoding());
+        if (layout && layout.isWgLayout())
+          return failure();
+
+        int count;
+        SmallVector<int64_t> subShape;
+        std::tie(subShape, count) = getTileShapeAndCount(shape, layout);
         auto newTy = VectorType::get(subShape, elemTy);
         result.append(count, newTy);
         return success();
       });
-
   converter.addConversion(
       [&](xegpu::TensorDescType type,
           SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
-        MLIRContext *ctx = type.getContext();
         Type elemTy = type.getElementType();
-        Attribute encoding = type.getEncoding();
         ArrayRef<int64_t> shape = type.getShape();
 
-        // init count and newTy to the default value. If the layout
-        // attribute is not present, it will return the original type.
-        int count = 1;
-        SmallVector<int64_t> subShape(shape);
-
         xegpu::LayoutAttr layout = type.getLayoutAttr();
+        if (layout && layout.isWgLayout())
+          return failure();
+
+        int count;
+        SmallVector<int64_t> subShape;
+        std::tie(subShape, count) = getTileShapeAndCount(shape, layout);
 
-        if (layout) {
-          if (layout.isWgLayout())
-            return failure();
-
-          if (DenseI32ArrayAttr instData = layout.getInstData()) {
-            // for unrolling, the subShape is determined by inst_data
-            subShape = llvm::to_vector_of<int64_t>(instData.asArrayRef());
-            count = computeProduct(shape) / computeProduct(subShape);
-            layout = layout.dropInstData();
-          }
-        }
-        auto newTy =
-            xegpu::TensorDescType::get(ctx, subShape, elemTy, encoding, layout);
+        if (layout)
+          layout = layout.dropInstData();
+
+        auto newTy = xegpu::TensorDescType::get(
+            type.getContext(), subShape, elemTy, type.getEncoding(), layout);
         result.append(count, newTy);
         return success();
       });
+
   xegpu::doSCFStructuralTypeConversionWithTensorType(mod, converter);
 
   xegpu::UnrollOptions options;

>From e023c1a235a7a452570b2cdb2ccb6851df2c9b7d Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Thu, 22 May 2025 17:52:06 +0000
Subject: [PATCH 37/55] add a corner unit test

---
 .../XeGPU/Transforms/XeGPUBlocking.cpp        | 40 ++++++++++++-----
 .../Dialect/XeGPU/Transforms/XeGPUUnroll.cpp  | 14 +++---
 mlir/test/Dialect/XeGPU/xegpu-blocking.mlir   | 43 +++++++++++++++++++
 3 files changed, 78 insertions(+), 19 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
index 9c839f0c056f8..f8b5d4a9caaf9 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
@@ -185,24 +185,44 @@ bool XeGPUBlockingPass::needsUnroll(Operation *op) const {
   if (isa<LoopLikeOpInterface>(op))
     return false;
 
-  for (auto &opr : op->getOpOperands()) {
+  auto isUnrollable = [&](Value value,
+                          ArrayRef<int64_t> tileShape) -> std::optional<bool> {
+    Type valTy = value.getType();
+    if (auto tdesc = dyn_cast<xegpu::TensorDescType>(valTy)) {
+      xegpu::LayoutAttr layout = tdesc.getLayoutAttr();
+      if (!layout)
+        return std::nullopt;
+      if (layout.isWgLayout())
+        return false;
+      if (layout.getInstData())
+        return true;
+    }
+
+    auto shapedType = dyn_cast<ShapedType>(valTy);
+    if (shapedType && !llvm::equal(tileShape, shapedType.getShape()))
+      return true;
+
+    return std::nullopt;
+  };
+
+  for (OpOperand &opr : op->getOpOperands()) {
     std::optional<SmallVector<int64_t>> tileShape = getTileShape(opr);
-    auto shapedType = dyn_cast<ShapedType>(opr.get().getType());
-    if (!shapedType || !tileShape)
+    if (!tileShape)
       continue;
 
-    if (!llvm::equal(*tileShape, shapedType.getShape()))
-      return true;
+    std::optional<bool> unrollable = isUnrollable(opr.get(), *tileShape);
+    if (unrollable.has_value())
+      return unrollable.value();
   }
 
-  for (auto result : op->getOpResults()) {
+  for (OpResult result : op->getOpResults()) {
     std::optional<SmallVector<int64_t>> tileShape = getTileShape(result);
-    auto shapedType = dyn_cast<ShapedType>(result.getType());
-    if (!shapedType || !tileShape)
+    if (!tileShape)
       continue;
 
-    if (!llvm::equal(*tileShape, shapedType.getShape()))
-      return true;
+    std::optional<bool> unrollable = isUnrollable(result, *tileShape);
+    if (unrollable.has_value())
+      return unrollable.value();
   }
   return false;
 }
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index d9f69158f95eb..885477fe4cbd5 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -136,7 +136,7 @@ struct UnrollCreateNdOp : public UnrollPattern<xegpu::CreateNdDescOp> {
     ArrayRef<int64_t> shape = tdescTy.getShape();
 
     std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
-    if (!targetShape || llvm::equal(*targetShape, shape))
+    if (!targetShape)
       return failure();
 
     auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];
@@ -187,10 +187,9 @@ struct UnrollUpdateNdOffsetOp : public UnrollPattern<xegpu::UpdateNdOffsetOp> {
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
     xegpu::TensorDescType tdescTy = op.getTensorDescType();
-    ArrayRef<int64_t> shape = tdescTy.getShape();
 
     std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
-    if (!targetShape || llvm::equal(*targetShape, shape))
+    if (!targetShape)
       return failure();
 
     SmallVector<Type> convertedTdescTypes =
@@ -216,10 +215,9 @@ struct UnrollPrefetchNdOp : public UnrollPattern<xegpu::PrefetchNdOp> {
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
     xegpu::TensorDescType tdescTy = op.getTensorDescType();
-    ArrayRef<int64_t> shape = tdescTy.getShape();
 
     std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
-    if (!targetShape || llvm::equal(*targetShape, shape))
+    if (!targetShape)
       return failure();
 
     SmallVector<Type> convertedTdescTypes =
@@ -243,10 +241,9 @@ struct UnrollLoadNdOp : public UnrollPattern<xegpu::LoadNdOp> {
     Location loc = op.getLoc();
     VectorType valueTy = op.getType();
     xegpu::TensorDescType tdescTy = op.getTensorDescType();
-    ArrayRef<int64_t> shape = tdescTy.getShape();
 
     std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
-    if (!targetShape || llvm::equal(*targetShape, shape))
+    if (!targetShape)
       return failure();
 
     Type elemTy = tdescTy.getElementType();
@@ -278,10 +275,9 @@ struct UnrollStoreNdOp : public UnrollPattern<xegpu::StoreNdOp> {
     Location loc = op.getLoc();
     VectorType valueTy = op.getValueType();
     xegpu::TensorDescType tdescTy = op.getTensorDescType();
-    ArrayRef<int64_t> shape = tdescTy.getShape();
 
     std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
-    if (!targetShape || llvm::equal(*targetShape, shape))
+    if (!targetShape)
       return failure();
 
     SmallVector<Type> convertedValTypes =
diff --git a/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir b/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
index d8a5dfe7d4b13..c9866b94dc79e 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
@@ -82,6 +82,49 @@ gpu.module @test_kernel {
   }
 }
 
+// -----
+#l1 = #xegpu.layout<inst_data = [8, 16]>
+#l2 = #xegpu.layout<inst_data = [16, 16]>
+gpu.module @test_kernel {
+  gpu.func @test_gemm(%A: memref<1024x1024xf16>, %B: memref<1024x1024xf16>, %C: memref<1024x1024xf32>) {
+    %c0 = arith.constant 0 : index
+    %c8 = arith.constant 8 : index
+    %c16 = arith.constant 16 : index
+    %c32 = arith.constant 32 : index
+    %c1024 = arith.constant 1024 : index
+    %block_id_x = gpu.block_id x
+    %block_id_y = gpu.block_id y
+    %m = arith.muli %block_id_x, %c8 : index
+    %n = arith.muli %block_id_y, %c32 : index
+
+    %c_tdesc = xegpu.create_nd_tdesc %C[%m, %n] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x32xf32, #l1>
+
+    //CHECK-COUNT-2: xegpu.load_nd {{.*}} : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
+    %c_init = xegpu.load_nd %c_tdesc : !xegpu.tensor_desc<8x32xf32, #l1> -> vector<8x32xf32>
+
+    %a_tdesc = xegpu.create_nd_tdesc %A[%m, %c0] : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16, #l1>
+    %b_tdesc = xegpu.create_nd_tdesc %B[%c0, %n] : memref<1024x1024xf16> -> !xegpu.tensor_desc<16x32xf16, #l2>
+    %out:3 = scf.for %k = %c0 to %c1024 step %c16
+      iter_args(%arg0 = %a_tdesc, %arg1 = %b_tdesc, %arg2 = %c_init)
+      -> (!xegpu.tensor_desc<8x16xf16, #l1>, !xegpu.tensor_desc<16x32xf16, #l2>, vector<8x32xf32>) {
+      //CHECK: %22 = xegpu.load_nd {{.*}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+      %a = xegpu.load_nd %arg0 : !xegpu.tensor_desc<8x16xf16, #l1> -> vector<8x16xf16>
+      //CHECK-COUNT-2: xegpu.load_nd {{.*}} : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+      %b = xegpu.load_nd %arg1 : !xegpu.tensor_desc<16x32xf16, #l2> -> vector<16x32xf16>
+      %c = xegpu.dpas %a, %b, %arg2 {layout_result_0 = #l1}: vector<8x16xf16>, vector<16x32xf16>, vector<8x32xf32> -> vector<8x32xf32>
+      //CHECK: xegpu.update_nd_offset {{.*}} [%c0, %c32] : !xegpu.tensor_desc<8x16xf16>
+      %a_next_tdesc = xegpu.update_nd_offset %arg0, [%c0, %c32] : !xegpu.tensor_desc<8x16xf16, #l1>
+      //CHECK-COUNT-2: xegpu.update_nd_offset {{.*}} [%c32, %c0] : !xegpu.tensor_desc<16x16xf16>
+      %b_next_tdesc = xegpu.update_nd_offset %arg1, [%c32, %c0] : !xegpu.tensor_desc<16x32xf16, #l2>
+      scf.yield %a_next_tdesc, %b_next_tdesc, %c
+        : !xegpu.tensor_desc<8x16xf16, #l1>, !xegpu.tensor_desc<16x32xf16, #l2>, vector<8x32xf32>
+    }
+    //CHECK-COUNT-2: xegpu.store_nd {{.*}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+    xegpu.store_nd %out#2, %c_tdesc: vector<8x32xf32>, !xegpu.tensor_desc<8x32xf32, #l1>
+    gpu.return
+  }
+}
+
 // -----
 #a = #xegpu.layout<inst_data = [8, 16], lane_layout = [1, 16], lane_data = [8, 1]>
 #b = #xegpu.layout<inst_data = [16, 16], lane_layout = [1, 16], lane_data = [16, 1]>

>From d0e54ba8584c6a7ff165eee0f3da4be7c901a0bd Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Thu, 22 May 2025 18:53:34 +0000
Subject: [PATCH 38/55] stage

---
 mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 12cb8df0d30ea..95f08eaef9f02 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -94,6 +94,11 @@ computeTileShapeAndCount(ArrayRef<int64_t> shape, xegpu::LayoutAttr layout) {
 ///
 /// Since the 24x24 matrix is divided into 8x8 distribution units, there will be
 /// 9 distribution units (3x3) in total. Hence the 9 subgroup level operations.
+
+/// The pass currently has entire distribution logic in the WgToSgCreateNdOp
+/// pattern and all the other ops just follow.
+/// TODO: Decouple the distribution logic from WgToSgCreateNdOp for all the
+/// ops in the pass.
 struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
   using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
 

>From 562f1c8f67f08dc3b273b7bafe6f803724a6b0aa Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Thu, 22 May 2025 19:15:21 +0000
Subject: [PATCH 39/55] update wg to sg unit test

---
 mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp | 1 +
 mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir              | 2 +-
 mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir                 | 4 ++--
 3 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 95f08eaef9f02..ad12cf34ca7b3 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -315,6 +315,7 @@ struct WgToSgDpasOp : public OpConversionPattern<xegpu::DpasOp> {
                                            resultTy.getElementType());
         tmpC = rewriter.create<xegpu::DpasOp>(loc, resTy, operands);
         xegpu::setLayoutAttr(cast<OpResult>(tmpC), originalLayout.dropSgLayoutAndData());
+
         newDpasOps.push_back(tmpC);
       }
     }
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
index bee026eb2084d..fa1e5fbae0954 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
@@ -85,7 +85,7 @@ gpu.module @test_round_robin_assignment {
     %tdesc_c = xegpu.create_nd_tdesc %c[0, 0] : memref<8x8xf32>
       -> !xegpu.tensor_desc<8x8xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
     %dpas = xegpu.dpas %load_a, %load_b
-      {layout =  #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
+      {layout_result_0 =  #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
       : vector<8x8xf32>, vector<8x8xf32> -> vector<8x8xf32>
     gpu.return
   }
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
index c6b232fb0d43e..22374f74b133e 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
@@ -108,7 +108,7 @@ gpu.func @test_dpas(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
       : !xegpu.tensor_desc<32x24xf32, #xegpu.layout<sg_layout = [4, 2], sg_data = [8, 12], lane_layout = [8, 2], lane_data = [1, 1]>>
       -> vector<32x24xf32>
     %dpas = xegpu.dpas %load_a, %load_b
-      {layout =  #xegpu.layout<sg_layout = [2, 2], sg_data = [12, 12], lane_layout = [2, 2], lane_data = [1, 1]>}
+      {layout_result_0 =  #xegpu.layout<sg_layout = [2, 2], sg_data = [12, 12], lane_layout = [2, 2], lane_data = [1, 1]>}
       : vector<24x32xf32>, vector<32x24xf32> -> vector<24x24xf32>
     gpu.return
   }
@@ -142,7 +142,7 @@ gpu.func @test_dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
       : !xegpu.tensor_desc<32x24xf32, #xegpu.layout<sg_layout = [4, 2], lane_layout = [8, 2], lane_data = [1, 1]>>
       -> vector<32x24xf32>
     %dpas = xegpu.dpas %load_a, %load_b
-      {layout =  #xegpu.layout<sg_layout = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
+      {layout_result_0 =  #xegpu.layout<sg_layout = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
       : vector<24x32xf32>, vector<32x24xf32> -> vector<24x24xf32>
     gpu.return
   }

>From 39678106fd4ed4f8f79c23c05dbd4b29b275f66e Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Fri, 23 May 2025 20:34:27 +0000
Subject: [PATCH 40/55] fix comments

---
 .../XeGPU/Transforms/XeGPUBlocking.cpp        | 25 +++++--------------
 mlir/test/Dialect/XeGPU/xegpu-blocking.mlir   | 12 ++++-----
 2 files changed, 12 insertions(+), 25 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
index f8b5d4a9caaf9..fcf9a09a8ffc0 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
@@ -80,15 +80,14 @@ class XeGPUBlockingPass final
   void runOnOperation() override;
 
 private:
-  // Get the tile shape for a given value. If the value has a layout
-  // attribute and it is an SG layout, return the inst_data as the tile shape
-  // if inst_data is available; otherwise, return the original shape of the
-  // value. If the value does not have an SG layout, return std::nullopt.
-  std::optional<SmallVector<int64_t>>
-  getTileShape(TypedValue<ShapedType> value) const;
-
+  // Get the tile shape for a given operand by examining the layout attribute.
+  // If layout is not present or is not a subgroup level layout, it returns
+  // std::nullopt.
   std::optional<SmallVector<int64_t>> getTileShape(OpOperand &operand) const;
 
+  // Get the tile shape for a given result by examining the layout attribute.
+  // If layout is not present or is not a subgroup level layout, it returns
+  // std::nullopt.
   std::optional<SmallVector<int64_t>> getTileShape(OpResult result) const;
 
   // Get the tile shape for a given operation.
@@ -101,18 +100,6 @@ class XeGPUBlockingPass final
 };
 } // namespace
 
-std::optional<SmallVector<int64_t>>
-XeGPUBlockingPass::getTileShape(TypedValue<ShapedType> value) const {
-  assert(value && "value must be non-null");
-  xegpu::LayoutAttr layout = xegpu::getLayoutAttr(value);
-  if (layout && layout.isSgLayout()) {
-    if (auto inst_data = layout.getInstData())
-      return llvm::to_vector_of<int64_t>(inst_data.asArrayRef());
-    return llvm::to_vector(value.getType().getShape());
-  }
-  return std::nullopt;
-}
-
 std::optional<SmallVector<int64_t>>
 XeGPUBlockingPass::getTileShape(OpOperand &operand) const {
   xegpu::LayoutAttr layout = xegpu::getLayoutAttr(operand);
diff --git a/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir b/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
index c9866b94dc79e..4fe3844dc1c39 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
@@ -4,7 +4,7 @@
 #b = #xegpu.layout<inst_data = [16, 16], lane_layout = [1, 16], lane_data = [16, 1]>
 #c = #xegpu.layout<inst_data = [8, 16], lane_layout = [1, 16], lane_data = [8, 1]>
 gpu.module @test_kernel {
-  gpu.func @test_gemm(%A: memref<1024x1024xf16>, %B: memref<1024x1024xf16>, %C: memref<1024x1024xf32>) {
+  gpu.func @test_gemm_with_one_to_n_lowering(%A: memref<1024x1024xf16>, %B: memref<1024x1024xf16>, %C: memref<1024x1024xf32>) {
     %c0 = arith.constant 0 : index
     %c16 = arith.constant 16 : index
     %c32 = arith.constant 32 : index
@@ -45,7 +45,7 @@ gpu.module @test_kernel {
 #l1 = #xegpu.layout<inst_data = [8, 16]>
 #l2 = #xegpu.layout<inst_data = [16, 16]>
 gpu.module @test_kernel {
-  gpu.func @test_gemm(%A: memref<1024x1024xf16>, %B: memref<1024x1024xf16>, %C: memref<1024x1024xf32>) {
+  gpu.func @test_gemm_with_inst_data_only_attribute(%A: memref<1024x1024xf16>, %B: memref<1024x1024xf16>, %C: memref<1024x1024xf32>) {
     %c0 = arith.constant 0 : index
     %c16 = arith.constant 16 : index
     %c32 = arith.constant 32 : index
@@ -86,7 +86,7 @@ gpu.module @test_kernel {
 #l1 = #xegpu.layout<inst_data = [8, 16]>
 #l2 = #xegpu.layout<inst_data = [16, 16]>
 gpu.module @test_kernel {
-  gpu.func @test_gemm(%A: memref<1024x1024xf16>, %B: memref<1024x1024xf16>, %C: memref<1024x1024xf32>) {
+  gpu.func @test_gemm_with_one_to_one_lowering(%A: memref<1024x1024xf16>, %B: memref<1024x1024xf16>, %C: memref<1024x1024xf32>) {
     %c0 = arith.constant 0 : index
     %c8 = arith.constant 8 : index
     %c16 = arith.constant 16 : index
@@ -130,7 +130,7 @@ gpu.module @test_kernel {
 #b = #xegpu.layout<inst_data = [16, 16], lane_layout = [1, 16], lane_data = [16, 1]>
 #c = #xegpu.layout<inst_data = [8, 16], lane_layout = [1, 16], lane_data = [8, 1]>
 gpu.module @test_kernel {
-  gpu.func @test_gemm(%A: memref<1024x1024xf16>, %B: memref<1024x1024xf16>, %C: memref<1024x1024xf32>) {
+  gpu.func @test_gemm_with_elemwise_preop(%A: memref<1024x1024xf16>, %B: memref<1024x1024xf16>, %C: memref<1024x1024xf32>) {
     %c0 = arith.constant 0 : index
     %c16 = arith.constant 16 : index
     %c32 = arith.constant 32 : index
@@ -172,7 +172,7 @@ gpu.module @test_kernel {
 // -----
 #l = #xegpu.layout<inst_data = [8, 16]>
 gpu.module @test_kernel {
-  gpu.func @test_elementwise(%A: memref<1024x1024xf16>, %B: memref<1024x1024xf16>, %C: memref<1024x1024xf16>) {
+  gpu.func @test_elementwise_with_inst_data_only(%A: memref<1024x1024xf16>, %B: memref<1024x1024xf16>, %C: memref<1024x1024xf16>) {
     %c0 = arith.constant 0 : index
     %c32 = arith.constant 32 : index
     %c1024 = arith.constant 1024 : index
@@ -211,7 +211,7 @@ gpu.module @test_kernel {
 // -----
 #l = #xegpu.layout<inst_data = [8]>
 gpu.module @test_kernel {
-  gpu.func @test_elementwise(%A: memref<1024x1024xf16>, %B: memref<1024x1024xf16>, %C: memref<1024x1024xf16>) {
+  gpu.func @test_elementwise_1D(%A: memref<1024x1024xf16>, %B: memref<1024x1024xf16>, %C: memref<1024x1024xf16>) {
     %c0 = arith.constant 0 : index
     %c32 = arith.constant 32 : index
     %c1024 = arith.constant 1024 : index

>From aebc327a494876e57219e236bd040b55b8d4bc76 Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Tue, 27 May 2025 14:41:49 +0000
Subject: [PATCH 41/55] remove unnecessary reference for lambda

---
 mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
index fcf9a09a8ffc0..fefcaf7e73d41 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
@@ -172,8 +172,8 @@ bool XeGPUBlockingPass::needsUnroll(Operation *op) const {
   if (isa<LoopLikeOpInterface>(op))
     return false;
 
-  auto isUnrollable = [&](Value value,
-                          ArrayRef<int64_t> tileShape) -> std::optional<bool> {
+  auto isUnrollable = [](Value value,
+                         ArrayRef<int64_t> tileShape) -> std::optional<bool> {
     Type valTy = value.getType();
     if (auto tdesc = dyn_cast<xegpu::TensorDescType>(valTy)) {
       xegpu::LayoutAttr layout = tdesc.getLayoutAttr();
@@ -221,7 +221,7 @@ void XeGPUBlockingPass::runOnOperation() {
   // Preserve the LayoutAttr for each operand to the owner's DictionaryAttr.
   // This ensures that the LayoutAttr remains accessible even if the defining
   // operation is replaced.
-  xegpu::setLayoutAttrs(mod, [&](Value v) { return xegpu::getLayoutAttr(v); });
+  xegpu::setLayoutAttrs(mod, [](Value v) { return xegpu::getLayoutAttr(v); });
 
   auto getTileShapeAndCount = [](llvm::ArrayRef<int64_t> shape,
                                  xegpu::LayoutAttr layout) {
@@ -237,7 +237,7 @@ void XeGPUBlockingPass::runOnOperation() {
 
   // Perform type conversion for SCF control folow ops
   TypeConverter converter;
-  converter.addConversion([&](Type type) -> Type { return type; });
+  converter.addConversion([](Type type) -> Type { return type; });
   converter.addConversion(
       [&](RankedTensorType type,
           SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
@@ -283,7 +283,7 @@ void XeGPUBlockingPass::runOnOperation() {
 
   xegpu::UnrollOptions options;
   options.setFilterConstraint([&](Operation *op) -> LogicalResult {
-    return needsUnroll(op) ? success() : failure();
+    return success(needsUnroll(op));
   });
 
   options.setNativeShapeFn([&](Operation *op) { return getTileShape(op); });
@@ -315,7 +315,7 @@ void XeGPUBlockingPass::runOnOperation() {
 
   (void)applyPatternsGreedily(mod, std::move(patterns));
 
-  mod->walk([&](Operation *op) {
+  mod->walk([](Operation *op) {
     if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op))
       resolveUnrealizedConversionCastOp(castOp);
 

>From 90e7563a2b7e09b3cc506946cc8afa960316606e Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Tue, 27 May 2025 14:45:45 +0000
Subject: [PATCH 42/55] rename

---
 mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
index fefcaf7e73d41..1473ccf6feeae 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
@@ -216,12 +216,12 @@ bool XeGPUBlockingPass::needsUnroll(Operation *op) const {
 
 void XeGPUBlockingPass::runOnOperation() {
   MLIRContext *ctx = &getContext();
-  Operation *mod = getOperation();
+  Operation *op = getOperation();
 
   // Preserve the LayoutAttr for each operand to the owner's DictionaryAttr.
   // This ensures that the LayoutAttr remains accessible even if the defining
   // operation is replaced.
-  xegpu::setLayoutAttrs(mod, [](Value v) { return xegpu::getLayoutAttr(v); });
+  xegpu::setLayoutAttrs(op, [](Value v) { return xegpu::getLayoutAttr(v); });
 
   auto getTileShapeAndCount = [](llvm::ArrayRef<int64_t> shape,
                                  xegpu::LayoutAttr layout) {
@@ -279,7 +279,7 @@ void XeGPUBlockingPass::runOnOperation() {
         return success();
       });
 
-  xegpu::doSCFStructuralTypeConversionWithTensorType(mod, converter);
+  xegpu::doSCFStructuralTypeConversionWithTensorType(op, converter);
 
   xegpu::UnrollOptions options;
   options.setFilterConstraint([&](Operation *op) -> LogicalResult {
@@ -313,9 +313,9 @@ void XeGPUBlockingPass::runOnOperation() {
   populateXeGPUUnrollPatterns(patterns, options);
   vector::populateVectorUnrollPatterns(patterns, vectorOptions);
 
-  (void)applyPatternsGreedily(mod, std::move(patterns));
+  (void)applyPatternsGreedily(op, std::move(patterns));
 
-  mod->walk([](Operation *op) {
+  op->walk([](Operation *op) {
     if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op))
       resolveUnrealizedConversionCastOp(castOp);
 

>From f5bfc2f8f22e93c0168ffc4b72152bf9f88d9084 Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Tue, 27 May 2025 15:18:20 +0000
Subject: [PATCH 43/55] address comments

---
 mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp | 5 +----
 mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp         | 6 ++----
 2 files changed, 3 insertions(+), 8 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
index 1473ccf6feeae..1d034e5685ed3 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
@@ -60,10 +60,7 @@ resolveUnrealizedConversionCastOp(UnrealizedConversionCastOp castOp) {
           builder, castOp.getLoc(), inputs, shape);
       castOp->replaceAllUsesWith(ValueRange(result));
       castOp->erase();
-    }
-
-    // pack
-    if (castOp.getNumResults() > 1 && castOp.getNumOperands() == 1) {
+    } else if (castOp.getNumResults() > 1 && castOp.getNumOperands() == 1) {
       ArrayRef<int64_t> tileShape = outputTy.getShape();
       SmallVector<Value> results = xegpu::extractVectorsWithShapeFromValue(
           builder, castOp.getLoc(), inputs[0], tileShape);
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index 5e0e83ef2eed5..d8b3906468ea8 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -360,10 +360,8 @@ void xegpu::doSCFStructuralTypeConversionWithTensorType(
           auto isTensorTy = [&](Type type) {
             return isa<RankedTensorType>(type);
           };
-          if (llvm::any_of(op->getOperandTypes(), isTensorTy) ||
-              llvm::any_of(op->getResultTypes(), isTensorTy))
-            return false;
-          return true;
+          return llvm::none_of(op->getOperandTypes(), isTensorTy) &&
+                 llvm::none_of(op->getResultTypes(), isTensorTy);
         });
     mlir::RewritePatternSet patterns(context);
     patterns.insert<UnrealizedConversionCastOpPattern>(context);

>From 598fbcede72a9269cd14e4241ab6da9eb829edbe Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Tue, 27 May 2025 15:18:43 +0000
Subject: [PATCH 44/55] fix format

---
 mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
index 1d034e5685ed3..2ad757d7ed25d 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
@@ -279,9 +279,8 @@ void XeGPUBlockingPass::runOnOperation() {
   xegpu::doSCFStructuralTypeConversionWithTensorType(op, converter);
 
   xegpu::UnrollOptions options;
-  options.setFilterConstraint([&](Operation *op) -> LogicalResult {
-    return success(needsUnroll(op));
-  });
+  options.setFilterConstraint(
+      [&](Operation *op) -> LogicalResult { return success(needsUnroll(op)); });
 
   options.setNativeShapeFn([&](Operation *op) { return getTileShape(op); });
 

>From ff11a0572326b85208acd04809651d1631a0e74e Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Tue, 27 May 2025 15:59:54 +0000
Subject: [PATCH 45/55] add comments

---
 mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index 3f5fe2cce4636..84c1dc1373ee5 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -295,6 +295,7 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout"> {
     }
 
     LayoutAttr dropSgLayoutAndData() {
+      // avoid every field of the attribute is nullptr, which may lead to segment fault
       if (!getInstData() && !getLaneLayout())
         return nullptr;
       return LayoutAttr::get(getContext(), nullptr, nullptr, getInstData(),
@@ -302,6 +303,7 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout"> {
     }
 
     LayoutAttr dropInstData() {
+      // avoid every field of the attribute is nullptr, which may lead to segment fault
       if (!getSgLayout() && !getLaneLayout())
         return nullptr;
       return LayoutAttr::get(getContext(), getSgLayout(), getSgData(), nullptr,

>From 9f7f715a19eee82028121ad1b8f234104950c5f7 Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Tue, 27 May 2025 16:31:41 +0000
Subject: [PATCH 46/55] add comments

---
 .../XeGPU/Transforms/XeGPUBlocking.cpp        | 52 ++++++++++++-------
 1 file changed, 33 insertions(+), 19 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
index 2ad757d7ed25d..7e627bfc81ac3 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
@@ -43,30 +43,44 @@ static void
 resolveUnrealizedConversionCastOp(UnrealizedConversionCastOp castOp) {
   ValueRange inputs = castOp.getInputs();
   ValueRange outputs = castOp.getOutputs();
-
-  if (inputs.size() == 1 && outputs.size() == 1) {
-    castOp->replaceAllUsesWith(inputs);
+  if (inputs.empty() || outputs.empty()) {
+    LDBG("erase unrealized conversion cast op has no inputs/outputs.");
     castOp->erase();
+    return;
   }
 
   VectorType inputTy = dyn_cast<VectorType>(inputs[0].getType());
   VectorType outputTy = dyn_cast<VectorType>(outputs[0].getType());
-  if (inputTy && outputTy) {
-    OpBuilder builder(castOp);
-    // unpack
-    if (inputs.size() > 1 && outputs.size() == 1) {
-      ArrayRef<int64_t> shape = outputTy.getShape();
-      Value result = xegpu::createVectorWithShapeFromValues(
-          builder, castOp.getLoc(), inputs, shape);
-      castOp->replaceAllUsesWith(ValueRange(result));
-      castOp->erase();
-    } else if (castOp.getNumResults() > 1 && castOp.getNumOperands() == 1) {
-      ArrayRef<int64_t> tileShape = outputTy.getShape();
-      SmallVector<Value> results = xegpu::extractVectorsWithShapeFromValue(
-          builder, castOp.getLoc(), inputs[0], tileShape);
-      castOp->replaceAllUsesWith(results);
-      castOp->erase();
-    }
+  if (!inputTy || !outputTy) {
+    LDBG("skip unrealized conversion cast op has non-vector inputs/outputs.");
+    return;
+  }
+
+  // We only interest in the case where all inputs and outputs have the
+  // identical types
+  if (llvm::any_of(castOp->getOperandTypes(),
+                   [&](Type t) { return t != inputTy; }) ||
+      llvm::any_of(castOp->getResultTypes(),
+                   [&](Type t) { return t != outputTy; })) {
+    LDBG("skip unrealized conversion cast op not emulating pack/unpack.");
+    return;
+  }
+
+  OpBuilder builder(castOp);
+  if (inputs.size() > 1 && outputs.size() == 1) {
+    // the castOp is emulating an unpack op
+    ArrayRef<int64_t> shape = outputTy.getShape();
+    Value result = xegpu::createVectorWithShapeFromValues(
+        builder, castOp.getLoc(), inputs, shape);
+    castOp->replaceAllUsesWith(ValueRange(result));
+    castOp->erase();
+  } else if (castOp.getNumResults() > 1 && castOp.getNumOperands() == 1) {
+    // the castOp is emulating a pack op
+    ArrayRef<int64_t> tileShape = outputTy.getShape();
+    SmallVector<Value> results = xegpu::extractVectorsWithShapeFromValue(
+        builder, castOp.getLoc(), inputs[0], tileShape);
+    castOp->replaceAllUsesWith(results);
+    castOp->erase();
   }
 }
 

>From b164d7b4d4224c4c53d6e9fa34bb238251172dbc Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Tue, 27 May 2025 16:57:59 +0000
Subject: [PATCH 47/55] address comments

---
 mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index d8b3906468ea8..7cede355b7561 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -165,17 +165,17 @@ void xegpu::setLayoutAttr(OpResult result, LayoutAttr layout) {
     owner->setAttr(name, layout);
 }
 
-void xegpu::setLayoutAttrs(Operation *mod,
+void xegpu::setLayoutAttrs(Operation *op,
                            function_ref<LayoutAttr(Value)> getLayoutImpl) {
-  mod->walk([&](Operation *op) {
-    for (OpResult result : op->getOpResults()) {
-      auto layout = getLayoutImpl(result);
-      setLayoutAttr(result, layout);
-    }
-    for (OpOperand &opr : op->getOpOperands()) {
+  op->walk([&](Operation *nestOp) {
+    for (OpOperand &opr : nestOp->getOpOperands()) {
       auto layout = getLayoutImpl(opr.get());
       setLayoutAttr(opr, layout);
     }
+    for (OpResult result : nestOp->getOpResults()) {
+      auto layout = getLayoutImpl(result);
+      setLayoutAttr(result, layout);
+    }
   });
 }
 

>From 554f4b414b3b29d9b4befd4beeee39f5a275e128 Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Tue, 27 May 2025 18:17:59 +0000
Subject: [PATCH 48/55] refactor

---
 .../XeGPU/Transforms/XeGPUBlocking.cpp        | 64 ++++++++-----------
 1 file changed, 28 insertions(+), 36 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
index 7e627bfc81ac3..50f056dafe0d9 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
@@ -180,49 +180,41 @@ XeGPUBlockingPass::getTileShape(Operation *op) const {
 }
 
 bool XeGPUBlockingPass::needsUnroll(Operation *op) const {
-  if (isa<LoopLikeOpInterface>(op))
+  // skip the op if any of its operands or results has workgroup level layouts
+  bool hasWgLayoutOperands =
+      llvm::any_of(op->getOpOperands(), [](OpOperand &opr) {
+        xegpu::LayoutAttr layout = xegpu::getLayoutAttr(opr);
+        return layout && layout.isWgLayout();
+      });
+  bool hasWgLayoutResults =
+      llvm::any_of(op->getOpResults(), [](OpResult result) {
+        xegpu::LayoutAttr layout = xegpu::getLayoutAttr(result);
+        return layout && layout.isWgLayout();
+      });
+  if (hasWgLayoutOperands || hasWgLayoutResults)
     return false;
 
-  auto isUnrollable = [](Value value,
-                         ArrayRef<int64_t> tileShape) -> std::optional<bool> {
+  auto isUnrollable = [](Value value, ArrayRef<int64_t> tileShape) {
     Type valTy = value.getType();
-    if (auto tdesc = dyn_cast<xegpu::TensorDescType>(valTy)) {
-      xegpu::LayoutAttr layout = tdesc.getLayoutAttr();
-      if (!layout)
-        return std::nullopt;
-      if (layout.isWgLayout())
-        return false;
-      if (layout.getInstData())
-        return true;
+    if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(valTy)) {
+      xegpu::LayoutAttr layout = tdescTy.getLayoutAttr();
+      return layout && layout.getInstData();
     }
-
     auto shapedType = dyn_cast<ShapedType>(valTy);
-    if (shapedType && !llvm::equal(tileShape, shapedType.getShape()))
-      return true;
-
-    return std::nullopt;
+    return shapedType && !llvm::equal(tileShape, shapedType.getShape());
   };
 
-  for (OpOperand &opr : op->getOpOperands()) {
-    std::optional<SmallVector<int64_t>> tileShape = getTileShape(opr);
-    if (!tileShape)
-      continue;
-
-    std::optional<bool> unrollable = isUnrollable(opr.get(), *tileShape);
-    if (unrollable.has_value())
-      return unrollable.value();
-  }
-
-  for (OpResult result : op->getOpResults()) {
-    std::optional<SmallVector<int64_t>> tileShape = getTileShape(result);
-    if (!tileShape)
-      continue;
-
-    std::optional<bool> unrollable = isUnrollable(result, *tileShape);
-    if (unrollable.has_value())
-      return unrollable.value();
-  }
-  return false;
+  bool hasUnrollableOperands =
+      llvm::any_of(op->getOpOperands(), [&](OpOperand &opr) {
+        std::optional<SmallVector<int64_t>> tileShape = getTileShape(opr);
+        return tileShape.has_value() && isUnrollable(opr.get(), *tileShape);
+      });
+  bool hasUnrollableResults =
+      llvm::any_of(op->getOpResults(), [&](OpResult result) {
+        std::optional<SmallVector<int64_t>> tileShape = getTileShape(result);
+        return tileShape.has_value() && isUnrollable(result, *tileShape);
+      });
+  return hasUnrollableOperands || hasUnrollableResults;
 }
 
 void XeGPUBlockingPass::runOnOperation() {

>From d9f2e813c722b4ec56cfe9137e6e218dc2e42d8d Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Tue, 27 May 2025 19:54:09 +0000
Subject: [PATCH 49/55] refactor getTileShape with template

---
 .../mlir/Dialect/XeGPU/Utils/XeGPUUtils.h     |  6 +--
 .../XeGPU/Transforms/XeGPUBlocking.cpp        | 46 ++++++++-----------
 mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp   |  9 ++--
 3 files changed, 27 insertions(+), 34 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
index 4077de593b109..a58d0122d0421 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
@@ -57,10 +57,10 @@ FailureOr<VectorType> getDistributedVectorType(VectorType originalType,
                                                LayoutAttr layout);
 
 /// Return the attribute name for the OpOperand to attach LayoutAttr
-std::string getLayoutName(OpOperand &opr);
+std::string getLayoutName(const OpOperand &opr);
 
 /// Return the attribute name for the OpResult to attach LayoutAttr
-std::string getLayoutName(OpResult res);
+std::string getLayoutName(const OpResult res);
 
 /// Retrieves the LayoutAttr associated with a given Value. For TensorDescType
 /// values, the LayoutAttr is extracted from the TensorDescType itself. For
@@ -71,7 +71,7 @@ LayoutAttr getLayoutAttr(Value value);
 /// Retrieves the LayoutAttr associated with a given OpOperand. It will
 /// first check the operand_layout_{id} of the owner operation. If not found,
 /// it will check the operand itself and its defining op.
-LayoutAttr getLayoutAttr(OpOperand &opr);
+LayoutAttr getLayoutAttr(const OpOperand &opr);
 
 /// Sets the LayoutAttr for a given OpOperand by attaching it to the owner
 void setLayoutAttr(OpOperand &opr, LayoutAttr layout);
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
index 50f056dafe0d9..022bf14492588 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
@@ -91,15 +91,14 @@ class XeGPUBlockingPass final
   void runOnOperation() override;
 
 private:
-  // Get the tile shape for a given operand by examining the layout attribute.
-  // If layout is not present or is not a subgroup level layout, it returns
-  // std::nullopt.
-  std::optional<SmallVector<int64_t>> getTileShape(OpOperand &operand) const;
-
-  // Get the tile shape for a given result by examining the layout attribute.
-  // If layout is not present or is not a subgroup level layout, it returns
-  // std::nullopt.
-  std::optional<SmallVector<int64_t>> getTileShape(OpResult result) const;
+  // Get the tile shape for a given OpOperand or OpResult by examining the
+  // corresponding layout attribute. If layout is not present or is not a
+  // subgroup level layout, it returns std::nullopt.
+  template <typename T,
+            typename = std::enable_if_t<std::is_same_v<T, OpOperand> ||
+                                        std::is_same_v<T, OpResult>>>
+  std::optional<SmallVector<int64_t>>
+  getTileShape(const T &operandOrResult) const;
 
   // Get the tile shape for a given operation.
   std::optional<SmallVector<int64_t>> getTileShape(Operation *op) const;
@@ -111,31 +110,24 @@ class XeGPUBlockingPass final
 };
 } // namespace
 
+template <typename T, typename>
 std::optional<SmallVector<int64_t>>
-XeGPUBlockingPass::getTileShape(OpOperand &operand) const {
-  xegpu::LayoutAttr layout = xegpu::getLayoutAttr(operand);
-  if (layout && layout.isSgLayout()) {
-    if (auto inst_data = layout.getInstData())
-      return llvm::to_vector_of<int64_t>(inst_data.asArrayRef());
-
-    if (auto type = dyn_cast<ShapedType>(operand.get().getType()))
-      return llvm::to_vector(type.getShape());
-  }
-  LDBG("failed to getTileShape for operand: " << operand.get());
-  return std::nullopt;
-}
-
-std::optional<SmallVector<int64_t>>
-XeGPUBlockingPass::getTileShape(OpResult result) const {
-  xegpu::LayoutAttr layout = xegpu::getLayoutAttr(result);
+XeGPUBlockingPass::getTileShape(const T &operandOrResult) const {
+  Value value;
+  if constexpr (std::is_same_v<T, OpOperand>)
+    value = operandOrResult.get();
+  else
+    value = (Value)operandOrResult;
+
+  xegpu::LayoutAttr layout = xegpu::getLayoutAttr(operandOrResult);
   if (layout && layout.isSgLayout()) {
     if (auto inst_data = layout.getInstData())
       return llvm::to_vector_of<int64_t>(inst_data.asArrayRef());
 
-    if (auto type = dyn_cast<ShapedType>(result.getType()))
+    if (auto type = dyn_cast<ShapedType>(value.getType()))
       return llvm::to_vector(type.getShape());
   }
-  LDBG("failed to getTileShape for result: " << result);
+  LDBG("failed to getTileShape for: " << value);
   return std::nullopt;
 }
 
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index 7cede355b7561..39c274850c7cc 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -101,12 +101,13 @@ mlir::xegpu::getDistributedVectorType(VectorType originalType,
   return xegpu::getDistributedVectorType(helperTdescTy);
 }
 
-std::string xegpu::getLayoutName(OpOperand &opr) {
+std::string xegpu::getLayoutName(const OpOperand &opr) {
   const StringRef prefix("layout_operand_");
-  return llvm::formatv("{0}{1}", prefix, opr.getOperandNumber()).str();
+  unsigned idx = const_cast<OpOperand &>(opr).getOperandNumber();
+  return llvm::formatv("{0}{1}", prefix, idx).str();
 }
 
-std::string xegpu::getLayoutName(OpResult res) {
+std::string xegpu::getLayoutName(const OpResult res) {
   const StringRef prefix = "layout_result_";
   return llvm::formatv("{0}{1}", prefix, res.getResultNumber()).str();
 }
@@ -143,7 +144,7 @@ xegpu::LayoutAttr xegpu::getLayoutAttr(Value value) {
   return nullptr;
 }
 
-xegpu::LayoutAttr xegpu::getLayoutAttr(OpOperand &opr) {
+xegpu::LayoutAttr xegpu::getLayoutAttr(const OpOperand &opr) {
   Operation *op = opr.getOwner();
   std::string layoutName = xegpu::getLayoutName(opr);
   if (op->hasAttr(layoutName))

>From 18e49f6bbf2e8d6fd0fd0fa4a429998778772d5c Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Tue, 27 May 2025 20:01:28 +0000
Subject: [PATCH 50/55] add qualifiers

---
 mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h | 4 ++--
 mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp        | 6 +++---
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
index a58d0122d0421..942664deba9dd 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
@@ -66,7 +66,7 @@ std::string getLayoutName(const OpResult res);
 /// values, the LayoutAttr is extracted from the TensorDescType itself. For
 /// other values, it is obtained from the attributes of the defining operation.
 /// Returns nullptr if no LayoutAttr is found.
-LayoutAttr getLayoutAttr(Value value);
+LayoutAttr getLayoutAttr(const Value value);
 
 /// Retrieves the LayoutAttr associated with a given OpOperand. It will
 /// first check the operand_layout_{id} of the owner operation. If not found,
@@ -74,7 +74,7 @@ LayoutAttr getLayoutAttr(Value value);
 LayoutAttr getLayoutAttr(const OpOperand &opr);
 
 /// Sets the LayoutAttr for a given OpOperand by attaching it to the owner
-void setLayoutAttr(OpOperand &opr, LayoutAttr layout);
+void setLayoutAttr(const OpOperand &opr, const LayoutAttr layout);
 
 /// Set the LayoutAttr for the given OpResult by attching it to the defining op
 void setLayoutAttr(OpResult result, LayoutAttr layout);
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index 39c274850c7cc..69d653a4a45bb 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -112,7 +112,7 @@ std::string xegpu::getLayoutName(const OpResult res) {
   return llvm::formatv("{0}{1}", prefix, res.getResultNumber()).str();
 }
 
-xegpu::LayoutAttr xegpu::getLayoutAttr(Value value) {
+xegpu::LayoutAttr xegpu::getLayoutAttr(const Value value) {
   if (!value)
     return nullptr;
 
@@ -152,14 +152,14 @@ xegpu::LayoutAttr xegpu::getLayoutAttr(const OpOperand &opr) {
   return getLayoutAttr(opr.get());
 }
 
-void xegpu::setLayoutAttr(OpOperand &opr, LayoutAttr layout) {
+void xegpu::setLayoutAttr(const OpOperand &opr, const LayoutAttr layout) {
   auto owner = opr.getOwner();
   std::string name = xegpu::getLayoutName(opr);
   if (layout && !owner->hasAttrOfType<LayoutAttr>(name))
     owner->setAttr(name, layout);
 }
 
-void xegpu::setLayoutAttr(OpResult result, LayoutAttr layout) {
+void xegpu::setLayoutAttr(const OpResult result, const LayoutAttr layout) {
   Operation *owner = result.getOwner();
   std::string name = xegpu::getLayoutName(result);
   if (layout && !owner->hasAttr(name))

>From 1f218f49c87e4f83e82580a7918e56904ae96677 Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Tue, 27 May 2025 20:03:04 +0000
Subject: [PATCH 51/55] add qualifiers

---
 mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
index 942664deba9dd..ff9089ad9db18 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
@@ -77,7 +77,7 @@ LayoutAttr getLayoutAttr(const OpOperand &opr);
 void setLayoutAttr(const OpOperand &opr, const LayoutAttr layout);
 
 /// Set the LayoutAttr for the given OpResult by attching it to the defining op
-void setLayoutAttr(OpResult result, LayoutAttr layout);
+void setLayoutAttr(const OpResult result, const LayoutAttr layout);
 
 /// Set the LayoutAttr for each OpOperand and OpResult of the given operation.
 /// If the operation contains regions, it is also applied recursively to the

>From f869b13f990809d8ba08a956d981c29677ff94f7 Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Tue, 27 May 2025 20:15:38 +0000
Subject: [PATCH 52/55] refactor setLayoutAttrs

---
 mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h | 11 ++++++-----
 mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp        | 14 ++++----------
 2 files changed, 10 insertions(+), 15 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
index ff9089ad9db18..e215a03b6d909 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
@@ -73,11 +73,12 @@ LayoutAttr getLayoutAttr(const Value value);
 /// it will check the operand itself and its defining op.
 LayoutAttr getLayoutAttr(const OpOperand &opr);
 
-/// Sets the LayoutAttr for a given OpOperand by attaching it to the owner
-void setLayoutAttr(const OpOperand &opr, const LayoutAttr layout);
-
-/// Set the LayoutAttr for the given OpResult by attching it to the defining op
-void setLayoutAttr(const OpResult result, const LayoutAttr layout);
+/// Sets the LayoutAttr for a given OpOperand or OpResult by attaching
+/// it to the owner's dictionary attributes
+template <typename T,
+          typename = std::enable_if_t<std::is_same_v<T, OpOperand> ||
+                                      std::is_same_v<T, OpResult>>>
+void setLayoutAttr(const T &operandOrResult, const LayoutAttr layout);
 
 /// Set the LayoutAttr for each OpOperand and OpResult of the given operation.
 /// If the operation contains regions, it is also applied recursively to the
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index 69d653a4a45bb..56b5b6c2a0ac1 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -152,20 +152,14 @@ xegpu::LayoutAttr xegpu::getLayoutAttr(const OpOperand &opr) {
   return getLayoutAttr(opr.get());
 }
 
-void xegpu::setLayoutAttr(const OpOperand &opr, const LayoutAttr layout) {
-  auto owner = opr.getOwner();
-  std::string name = xegpu::getLayoutName(opr);
+template <typename T, typename>
+void xegpu::setLayoutAttr(const T &operandOrResult, const LayoutAttr layout) {
+  Operation *owner = operandOrResult.getOwner();
+  std::string name = xegpu::getLayoutName(operandOrResult);
   if (layout && !owner->hasAttrOfType<LayoutAttr>(name))
     owner->setAttr(name, layout);
 }
 
-void xegpu::setLayoutAttr(const OpResult result, const LayoutAttr layout) {
-  Operation *owner = result.getOwner();
-  std::string name = xegpu::getLayoutName(result);
-  if (layout && !owner->hasAttr(name))
-    owner->setAttr(name, layout);
-}
-
 void xegpu::setLayoutAttrs(Operation *op,
                            function_ref<LayoutAttr(Value)> getLayoutImpl) {
   op->walk([&](Operation *nestOp) {

>From de7585536d58d5b383221e21590fe75d0bdeea5a Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Tue, 27 May 2025 20:26:58 +0000
Subject: [PATCH 53/55] cleanup unnecessary reference symbols

---
 mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index 56b5b6c2a0ac1..ea01a22aa5473 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -224,16 +224,16 @@ void xegpu::doSCFStructuralTypeConversionWithTensorType(
     Operation *op, TypeConverter converter) {
   MLIRContext *context = op->getContext();
 
-  auto materializeCast = [&](OpBuilder &builder, Type type, ValueRange inputs,
-                             Location loc) -> Value {
+  auto materializeCast = [](OpBuilder &builder, Type type, ValueRange inputs,
+                            Location loc) -> Value {
     return builder.create<UnrealizedConversionCastOp>(loc, type, inputs)
         .getResult(0);
   };
 
   { // convert VectorType to RankedTensorType for SCF Structural ops
     TypeConverter converter;
-    converter.addConversion([&](Type type) -> Type { return type; });
-    converter.addConversion([&](VectorType type) -> Type {
+    converter.addConversion([](Type type) -> Type { return type; });
+    converter.addConversion([](VectorType type) -> Type {
       return RankedTensorType::get(type.getShape(), type.getElementType());
     });
     converter.addSourceMaterialization(materializeCast);
@@ -251,7 +251,7 @@ void xegpu::doSCFStructuralTypeConversionWithTensorType(
   { // propagate the layout attribute to RankedTensorType by checking
     // BuiltInUnrealizedCastOps
     // for VectorType to RankedTensorType cast.
-    op->walk([&](UnrealizedConversionCastOp castOp) {
+    op->walk([](UnrealizedConversionCastOp castOp) {
       if (castOp.getNumOperands() != 1 || castOp.getNumResults() != 1)
         return WalkResult::skip();
 
@@ -289,7 +289,7 @@ void xegpu::doSCFStructuralTypeConversionWithTensorType(
     });
 
     // using yieldOp as anchor to update the result type of its ParentOp
-    op->walk([&](scf::YieldOp yieldOp) {
+    op->walk([](scf::YieldOp yieldOp) {
       Operation *parentOp = yieldOp->getParentOp();
       for (OpResult r : parentOp->getOpResults()) {
         unsigned idx = r.getResultNumber();
@@ -351,8 +351,8 @@ void xegpu::doSCFStructuralTypeConversionWithTensorType(
 
     mlir::ConversionTarget target(*context);
     target.addDynamicallyLegalOp<UnrealizedConversionCastOp>(
-        [&](UnrealizedConversionCastOp op) {
-          auto isTensorTy = [&](Type type) {
+        [](UnrealizedConversionCastOp op) {
+          auto isTensorTy = [](Type type) {
             return isa<RankedTensorType>(type);
           };
           return llvm::none_of(op->getOperandTypes(), isTensorTy) &&

>From beacf8abb64dc353f3c05ffc61233aff233fff9f Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Wed, 28 May 2025 14:21:03 +0000
Subject: [PATCH 54/55] update naming

---
 mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h | 4 ++--
 mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp        | 8 ++++----
 2 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
index e215a03b6d909..f9327d63869c0 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
@@ -57,10 +57,10 @@ FailureOr<VectorType> getDistributedVectorType(VectorType originalType,
                                                LayoutAttr layout);
 
 /// Return the attribute name for the OpOperand to attach LayoutAttr
-std::string getLayoutName(const OpOperand &opr);
+std::string getLayoutName(const OpOperand &operand);
 
 /// Return the attribute name for the OpResult to attach LayoutAttr
-std::string getLayoutName(const OpResult res);
+std::string getLayoutName(const OpResult result);
 
 /// Retrieves the LayoutAttr associated with a given Value. For TensorDescType
 /// values, the LayoutAttr is extracted from the TensorDescType itself. For
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index ea01a22aa5473..974aac94f9699 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -101,15 +101,15 @@ mlir::xegpu::getDistributedVectorType(VectorType originalType,
   return xegpu::getDistributedVectorType(helperTdescTy);
 }
 
-std::string xegpu::getLayoutName(const OpOperand &opr) {
+std::string xegpu::getLayoutName(const OpOperand &operand) {
   const StringRef prefix("layout_operand_");
-  unsigned idx = const_cast<OpOperand &>(opr).getOperandNumber();
+  unsigned idx = const_cast<OpOperand &>(operand).getOperandNumber();
   return llvm::formatv("{0}{1}", prefix, idx).str();
 }
 
-std::string xegpu::getLayoutName(const OpResult res) {
+std::string xegpu::getLayoutName(const OpResult result) {
   const StringRef prefix = "layout_result_";
-  return llvm::formatv("{0}{1}", prefix, res.getResultNumber()).str();
+  return llvm::formatv("{0}{1}", prefix, result.getResultNumber()).str();
 }
 
 xegpu::LayoutAttr xegpu::getLayoutAttr(const Value value) {

>From c4c7abdd15c949ab044ba5a235f5a344725d73d1 Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Wed, 28 May 2025 20:38:15 +0000
Subject: [PATCH 55/55] refactor

---
 .../XeGPU/Transforms/XeGPUBlocking.cpp        | 30 ++++++++-----------
 1 file changed, 13 insertions(+), 17 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
index 022bf14492588..fa666d8fa50c0 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
@@ -18,6 +18,7 @@
 #include "mlir/Pass/PassManager.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/STLExtras.h"
 
 namespace mlir {
 namespace xegpu {
@@ -43,29 +44,22 @@ static void
 resolveUnrealizedConversionCastOp(UnrealizedConversionCastOp castOp) {
   ValueRange inputs = castOp.getInputs();
   ValueRange outputs = castOp.getOutputs();
-  if (inputs.empty() || outputs.empty()) {
-    LDBG("erase unrealized conversion cast op has no inputs/outputs.");
-    castOp->erase();
-    return;
-  }
 
-  VectorType inputTy = dyn_cast<VectorType>(inputs[0].getType());
-  VectorType outputTy = dyn_cast<VectorType>(outputs[0].getType());
-  if (!inputTy || !outputTy) {
-    LDBG("skip unrealized conversion cast op has non-vector inputs/outputs.");
-    return;
-  }
+  auto hasIdenticalVectorTypes = [](ValueRange values) {
+    auto types = values.getTypes();
+    return llvm::all_of(types, [&](Type type) {
+      return isa<VectorType>(type) && type == types.front();
+    });
+  };
 
-  // We only interest in the case where all inputs and outputs have the
-  // identical types
+  // We are only interested in the case where all inputs and outputs have the
-  if (llvm::any_of(castOp->getOperandTypes(),
-                   [&](Type t) { return t != inputTy; }) ||
-      llvm::any_of(castOp->getResultTypes(),
-                   [&](Type t) { return t != outputTy; })) {
+  // identical VectorTypes
+  if (!hasIdenticalVectorTypes(inputs) || !hasIdenticalVectorTypes(outputs)) {
     LDBG("skip unrealized conversion cast op not emulating pack/unpack.");
     return;
   }
 
+  VectorType outputTy = dyn_cast<VectorType>(outputs[0].getType());
   OpBuilder builder(castOp);
   if (inputs.size() > 1 && outputs.size() == 1) {
     // the castOp is emulating an unpack op
@@ -183,8 +177,10 @@ bool XeGPUBlockingPass::needsUnroll(Operation *op) const {
         xegpu::LayoutAttr layout = xegpu::getLayoutAttr(result);
         return layout && layout.isWgLayout();
       });
-  if (hasWgLayoutOperands || hasWgLayoutResults)
+  if (hasWgLayoutOperands || hasWgLayoutResults) {
+    LDBG("skip unrolling for op with workgroup level layout: " << *op);
     return false;
+  }
 
   auto isUnrollable = [](Value value, ArrayRef<int64_t> tileShape) {
     Type valTy = value.getType();



More information about the Mlir-commits mailing list