[Mlir-commits] [mlir] [mlir][XeGPU] Add XeGPU Workgroup to Subgroup Distribution Pass (PR #140805)

Nishant Patel llvmlistbot at llvm.org
Tue May 20 14:24:36 PDT 2025


https://github.com/nbpatel created https://github.com/llvm/llvm-project/pull/140805

This PR adds the XeGPU workgroup (wg) to subgroup (sg) distribution pass. The pass transforms XeGPU wg-level operations into sg-level operations based on the sg_layout and sg_data attributes of the layout. The PR adds transformation patterns for the following ops (a minimal before/after sketch follows the list):

1. CreateNdDesc
2. LoadNd
3. StoreNd
4. PrefetchNd
5. UpdateNdOffset
6. Dpas
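
A minimal sketch of the rewrite, adapted from the tests added in this PR
(the %off_y/%off_x offsets are placeholders computed from the subgroup id):

    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32>
      -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4],
           sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>

becomes one descriptor per subgroup, with sg_layout and sg_data dropped
from the layout attribute:

    %tdesc = xegpu.create_nd_tdesc %src[%off_y, %off_x] : memref<24x32xf32>
      -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8],
           lane_data = [1, 1]>>

The new tests exercise the pass via mlir-opt --xegpu-wg-to-sg.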

>From 1ed4cb5b381898728f850da43a10826493fce94b Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Sat, 10 May 2025 17:04:39 +0000
Subject: [PATCH 01/18] Add XeGPUWgToSg pass

---
 .../mlir/Dialect/XeGPU/Transforms/Passes.td   |  31 +-
 .../Dialect/XeGPU/Transforms/Transforms.h     |   4 +
 .../Dialect/XeGPU/Transforms/CMakeLists.txt   |   1 +
 .../Dialect/XeGPU/Transforms/XeGPUWgToSg.cpp  | 374 ++++++++++++++++++
 .../test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir |  65 +++
 mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir   |  81 ++++
 6 files changed, 544 insertions(+), 12 deletions(-)
 create mode 100644 mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSg.cpp
 create mode 100644 mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
 create mode 100644 mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir

diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
index 3e81f2d0ed786..bdea88cfd7022 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
@@ -6,7 +6,6 @@
 //
 //===----------------------------------------------------------------------===//
 
-
 #ifndef MLIR_DIALECT_XEGPU_TRANSFORMS_PASSES_TD
 #define MLIR_DIALECT_XEGPU_TRANSFORMS_PASSES_TD
 
@@ -18,9 +17,7 @@ def XeGPUFoldAliasOps : Pass<"xegpu-fold-alias-ops"> {
     The pass folds aliasing ops into XeGPU ops that they operate on the original
     source references.
   }];
-  let dependentDialects = [
-      "memref::MemRefDialect", "xegpu::XeGPUDialect"
-  ];
+  let dependentDialects = ["memref::MemRefDialect", "xegpu::XeGPUDialect"];
 }
 
 def XeGPUSubgroupDistribute : Pass<"xegpu-subgroup-distribute"> {
@@ -28,14 +25,24 @@ def XeGPUSubgroupDistribute : Pass<"xegpu-subgroup-distribute"> {
   let description = [{
     The pass distributes subgroup level (SIMD) XeGPU ops to work items.
   }];
-  let dependentDialects = [
-      "memref::MemRefDialect", "xegpu::XeGPUDialect", "vector::VectorDialect"
-  ];
-  let options = [
-    Option<"printOnly", "print-analysis-only", "bool",
-         /*default=*/"false",
-         "Print the result of the subgroup map propagation analysis and exit.">
-  ];
+  let dependentDialects = ["memref::MemRefDialect", "xegpu::XeGPUDialect",
+                           "vector::VectorDialect"];
+  let options = [Option<
+      "printOnly", "print-analysis-only", "bool",
+      /*default=*/"false",
+      "Print the result of the subgroup map propagation analysis and exit.">];
+}
+
+def XeGPUWgToSg : Pass<"xegpu-wg-to-sg", "::mlir::gpu::GPUModuleOp"> {
+  let summary = "Transform WorkGroup level XeGPU code to SubGroup level";
+  let description = [{
+    This transform pass distributes the workgroup level computation to
+    multiple subgroups based on the sg_layout and sg_data attributes.
+  }];
+
+  let dependentDialects = ["memref::MemRefDialect", "xegpu::XeGPUDialect",
+                           "vector::VectorDialect", "arith::ArithDialect",
+                           "gpu::GPUDialect", "index::IndexDialect"];
 }
 
 #endif // MLIR_DIALECT_XEGPU_TRANSFORMS_PASSES_TD
diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h b/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
index 3e94021c7a1ea..388ba32e1eebb 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
@@ -9,6 +9,8 @@
 #ifndef MLIR_DIALECT_XEGPU_TRANSFORMS_TRANSFORMS_H
 #define MLIR_DIALECT_XEGPU_TRANSFORMS_TRANSFORMS_H
 
+#include "mlir/Transforms/DialectConversion.h"
+
 namespace mlir {
 class RewritePatternSet;
 
@@ -18,6 +20,8 @@ namespace xegpu {
 void populateXeGPUFoldAliasOpsPatterns(RewritePatternSet &patterns);
 /// Appends patterns for XeGPU SIMT distribution into `patterns`.
 void populateXeGPUSubgroupDistributePatterns(RewritePatternSet &patterns);
+void populateXeGPUWgToSgPatterns(RewritePatternSet &patterns,
+                                 ConversionTarget &target);
 
 } // namespace xegpu
 } // namespace mlir
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
index 901e02d3c9cf5..b258921cc87fd 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_mlir_dialect_library(MLIRXeGPUTransforms
   XeGPUFoldAliasOps.cpp
   XeGPUSubgroupDistribute.cpp
+  XeGPUWgToSg.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/XeGPU
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSg.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSg.cpp
new file mode 100644
index 0000000000000..7969d37d67f04
--- /dev/null
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSg.cpp
@@ -0,0 +1,374 @@
+//===- XeGPUWgToSg.cpp - XeGPU WorkGroup to Subgroup Pass -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
+
+#include "mlir/Dialect/Index/IR/IndexDialect.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/Support/Debug.h"
+#include <mlir/Dialect/GPU/IR/GPUDialect.h>
+#include <mlir/Dialect/Index/IR/IndexOps.h>
+#include <numeric>
+
+namespace mlir {
+namespace xegpu {
+#define GEN_PASS_DEF_XEGPUWGTOSG
+#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
+} // namespace xegpu
+} // namespace mlir
+
+#define DEBUG_TYPE "xegpu-wg-to-sg"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
+
+using namespace mlir;
+
+namespace {
+
+// clang-format off
+/// This pattern transform the CreateNdDescOp to create a subgroup descriptor
+/// from a workgroup descriptor. It replaces the offsets and sizes with
+/// appropriate values for the subgroup.
+/// It uses round-robin distribution to create the subgroup descriptor.
+
+/// The following create_nd_tdesc operation:
+///    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x24xf32>
+///       -> !xegpu.tensor_desc<24x24xf32, #xegpu.layout<sg_layout = [4, 4],
+///           sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+/// is converted to 9 subgroup level operations based on the sg_layout & sg_data:
+///    %tdesc = xegpu.create_nd_tdesc %src[off1, off2] : memref<24x24xf32> ->
+///           !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
+///
+/// The sg_layout and sg_data are dropped from the layout attribute as they are no longer needed.
+///
+/// 24x24 matrix distribution example:
+/// sg_layout = [4, 4], sg_data = [2, 2]
+/// Each 8x8 matrix within the 24x24 matrix is called a distribution unit.
+/// dist_unit_shape = [8, 8] --> sg_layout[i] * sg_data[i]
+///
+/// +------------------------+
+/// | 8x8 | 8x8 | 8x8 |      <- 3 tiles across
+/// |-----+-----+-----|
+/// | 8x8 | 8x8 | 8x8 |      <- 3 tiles down
+/// |-----+-----+-----|
+/// | 8x8 | 8x8 | 8x8 |
+/// +------------------------+
+///
+/// Each 8x8 tile is further subdivided among subgroups:
+/// +------------------------+
+/// | 2x2 2x2 2x2 2x2 |  <- 4 subgroups across (each handles 2 columns)
+/// | 2x2 2x2 2x2 2x2 |  <- 4 subgroups down (each handles 2 rows)
+/// | 2x2 2x2 2x2 2x2 |
+/// | 2x2 2x2 2x2 2x2 |
+/// +------------------------+
+///
+/// Since the 24x24 matrix is divided into 8x8 distribution units, there will be 9
+/// distribution units (3x3) in total. Hence the 9 subgroup level operations.
+/// Each 8x8 matrix within the 24x24 matrix is called a distribution unit.
+// clang-format on
+struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
+  using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
+
+  // Helper to extract mixed offsets into a Value array
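+  // E.g., mixed offsets (0, %d) become {arith.constant 0 : index, %d}.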
+  SmallVector<Value> extractOffsets(ConversionPatternRewriter &rewriter,
+                                    xegpu::CreateNdDescOp op) const {
+    llvm::SmallVector<Value> offsets;
+    auto staticOffsets = op.getStaticOffsets();
+    auto dynamicOffsets = op.getOffsets();
+
+    for (size_t i = 0, j = 0; i != staticOffsets.size(); i++) {
+      if (ShapedType::isDynamic(staticOffsets[i])) {
+        offsets.push_back(dynamicOffsets[j++]);
+      } else {
+        offsets.push_back(rewriter.create<arith::ConstantIndexOp>(
+            op.getLoc(), staticOffsets[i]));
+      }
+    }
+    return offsets;
+  }
+
+  // Convert linear subgroup ID to 2D coordinates
+  // TODO: Delinearize for nD
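+  // E.g., with sg_layout = [4, 4] (sgDimY = 4), linear subgroup id 6 maps to
+  // coordinates (6 / 4, 6 % 4) = (1, 2).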
+  SmallVector<Value> delinearizeSubgroupId(ConversionPatternRewriter &rewriter,
+                                           Location loc, Value sgID,
+                                           Value sgDimX, Value sgDimY) const {
+    return {rewriter.create<index::DivUOp>(loc, sgID, sgDimY),
+            rewriter.create<index::RemUOp>(loc, sgID, sgDimY)};
+  }
+
+  // Create a constant index value
+  Value createConstantIndex(ConversionPatternRewriter &rewriter, Location loc,
+                            int64_t value) const {
+    return rewriter.create<arith::ConstantIndexOp>(loc, value);
+  }
+
+  // Calculate global offset for each subgroup
+  SmallVector<OpFoldResult>
+  calculateGlobalOffsets(ConversionPatternRewriter &rewriter, Location loc,
+                         const SmallVector<Value> &originalOffsets,
+                         const SmallVector<Value> &localOffset,
+                         const SmallVector<int64_t> &distUnitBaseAddr) const {
+
+    Value constOffsetX =
+        createConstantIndex(rewriter, loc, distUnitBaseAddr[0]);
+    Value constOffsetY =
+        createConstantIndex(rewriter, loc, distUnitBaseAddr[1]);
+
+    // Compute offsets within entire tile
+    Value offsetX =
+        rewriter.createOrFold<index::AddOp>(loc, localOffset[0], constOffsetX);
+    Value offsetY =
+        rewriter.createOrFold<index::AddOp>(loc, localOffset[1], constOffsetY);
+
+    // Add to global offsets
+    size_t lastDimIndex = originalOffsets.size() - 1;
+    size_t secondLastDimIndex = lastDimIndex - 1;
+
+    Value globalOffsetX = rewriter.createOrFold<index::AddOp>(
+        loc, originalOffsets[secondLastDimIndex], offsetX);
+    Value globalOffsetY = rewriter.createOrFold<index::AddOp>(
+        loc, originalOffsets[lastDimIndex], offsetY);
+
+    // Create final offset list
+    SmallVector<OpFoldResult> globalOffsets(originalOffsets.begin(),
+                                            originalOffsets.end());
+    globalOffsets[secondLastDimIndex] = globalOffsetX;
+    globalOffsets[lastDimIndex] = globalOffsetY;
+
+    return globalOffsets;
+  }
+
+  LogicalResult
+  matchAndRewrite(xegpu::CreateNdDescOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    MLIRContext *ctx = op.getContext();
+    xegpu::TensorDescType tdescTy = op.getType();
+    auto layout = dyn_cast<xegpu::LayoutAttr>(tdescTy.getLayout());
+    Type elemTy = tdescTy.getElementType();
+    ArrayRef<int64_t> wgShape = tdescTy.getShape();
+    ArrayRef<int64_t> sgShape =
+        llvm::to_vector_of<int64_t>(layout.getSgData().asArrayRef());
+    ArrayRef<int64_t> sgLayout =
+        llvm::to_vector_of<int64_t>(layout.getSgLayout().asArrayRef());
+
+    // Get the subgroup ID
+    auto linearSgId = rewriter.create<gpu::SubgroupIdOp>(loc, nullptr);
+
+    // Create constants for layout dimensions
+    SmallVector<Value> sgLayoutDim(sgLayout.size());
+    SmallVector<Value> sgDataDim(sgShape.size());
+
+    for (size_t i = 0; i < sgLayout.size(); i++) {
+      sgLayoutDim[i] = createConstantIndex(rewriter, loc, sgLayout[i]);
+      sgDataDim[i] = createConstantIndex(rewriter, loc, sgShape[i]);
+    }
+
+    // Delinearize the 1D subgroup id into nd coordinates
+    SmallVector<Value> sgIds = delinearizeSubgroupId(
+        rewriter, loc, linearSgId, sgLayoutDim[0], sgLayoutDim[1]);
+
+    // Calculate distribution unit shape and local offsets for subgroup
+    SmallVector<int64_t> distUnitShape(sgLayout.size());
+    SmallVector<Value> localOffset(sgLayout.size());
+    for (size_t i = 0; i < sgLayout.size(); i++) {
+      distUnitShape[i] = sgLayout[i] * sgShape[i];
+      localOffset[i] =
+          rewriter.createOrFold<index::MulOp>(loc, sgIds[i], sgDataDim[i]);
+    }
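+    // E.g., for the 24x24 doc example (sg_layout = [4, 4], sg_data = [2, 2]):
+    // subgroup (1, 2) gets localOffset = [1 * 2, 2 * 2] = [2, 4] within each
+    // 8x8 distribution unit.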
+
+    SmallVector<Value> originalOffsets = extractOffsets(rewriter, op);
+
+    xegpu::TensorDescType newTdescTy =
+        xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(),
+                                   layout.dropSgLayoutAndData());
+    SmallVector<Value> newCreateNdOps;
+    for (const SmallVector<int64_t> &distUnitBaseAddr :
+         StaticTileOffsetRange(wgShape, distUnitShape)) {
+      SmallVector<OpFoldResult> globalOffsets = calculateGlobalOffsets(
+          rewriter, loc, originalOffsets, localOffset, distUnitBaseAddr);
+
+      auto newCreateNdOp = rewriter.create<xegpu::CreateNdDescOp>(
+          loc, newTdescTy, op.getSource(), globalOffsets, op.getMixedSizes(),
+          op.getMixedStrides());
+      newCreateNdOps.push_back(newCreateNdOp);
+    }
+
+    rewriter.replaceOpWithMultiple(op, {newCreateNdOps});
+    return success();
+  }
+};
+
+/// This pattern transforms the LoadNdOp to load from a subgroup descriptor
+/// It creates a LoadNdOp op to load the new subgroup src tensor descriptors.
+struct WgToSgLoadNdOp : public OpConversionPattern<xegpu::LoadNdOp> {
+  using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(xegpu::LoadNdOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    SmallVector<Value> newLoadOps;
+    for (auto src : adaptor.getTensorDesc()) {
+      xegpu::TensorDescType tdescTy =
+          dyn_cast<xegpu::TensorDescType>(src.getType());
+      ArrayRef<int64_t> srcShape = tdescTy.getShape();
+      VectorType newResTy = VectorType::get(srcShape, tdescTy.getElementType());
+      auto newLoadOp = rewriter.create<xegpu::LoadNdOp>(op.getLoc(), newResTy,
+                                                        src, op->getAttrs());
+      newLoadOps.push_back(newLoadOp);
+    }
+    rewriter.replaceOpWithMultiple(op, {newLoadOps});
+    return mlir::success();
+  }
+};
+
+/// This pattern transforms the StoreNdOp to store to a subgroup descriptor.
+/// It creates a StoreNdOp to store the updated values to the new subgroup
+/// src tensor descriptors.
+struct WgToSgStoreNdOp : public OpConversionPattern<xegpu::StoreNdOp> {
+  using OpConversionPattern<xegpu::StoreNdOp>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(xegpu::StoreNdOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    for (auto [v, t] : llvm::zip(adaptor.getValue(), adaptor.getTensorDesc()))
+      rewriter.create<xegpu::StoreNdOp>(op.getLoc(), v, t, op.getL1HintAttr(),
+                                        op.getL2HintAttr(), op.getL3HintAttr());
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+/// This pattern transforms the UpdateNdOffsetOp to update the offsets of a
+/// subgroup descriptor. It creates an UpdateNdOffsetOp to update the
+/// offsets of the new subgroup src tensor descriptors.
+struct WgToSgUpdateNdOffsetOp
+    : public OpConversionPattern<xegpu::UpdateNdOffsetOp> {
+  using OpConversionPattern<xegpu::UpdateNdOffsetOp>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(xegpu::UpdateNdOffsetOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    llvm::SmallVector<Value> newUpdateTileOffsetOps;
+    for (auto tDesc : adaptor.getTensorDesc()) {
+      auto newUpdateTileOffsetOp = rewriter.create<xegpu::UpdateNdOffsetOp>(
+          op.getLoc(), tDesc.getType(), tDesc, op.getOffsets(),
+          op.getConstOffsets());
+      newUpdateTileOffsetOps.push_back(newUpdateTileOffsetOp);
+    }
+
+    rewriter.replaceOpWithMultiple(op, {newUpdateTileOffsetOps});
+    return success();
+  }
+};
+
+/// This pattern transforms the DpasOp to work at subgroup level.
+struct WgToSgDpasOp : public OpConversionPattern<xegpu::DpasOp> {
+  using OpConversionPattern<xegpu::DpasOp>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(xegpu::DpasOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    VectorType resultTy = op.getResult().getType();
+    if (resultTy.getRank() != 2)
+      return failure();
+
+    auto originalLayout =
+        llvm::dyn_cast_or_null<xegpu::LayoutAttr>(op->getAttr("layout"));
+    if (!originalLayout)
+      return failure();
+
+    SmallVector<Value> newDpasOps;
+    size_t i = 0;
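+    // Pair every LHS fragment with every RHS fragment: an M-fragment A and an
+    // N-fragment B produce M x N subgroup-level dpas ops (see the
+    // CHECK-COUNT-144 = 12 x 12 case in the round-robin test).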
+    for (auto aVec : adaptor.getLhs()) {
+      for (auto bVec : adaptor.getRhs()) {
+
+        llvm::SmallVector<Value> operands({aVec, bVec});
+        Value tmpC;
+        if (op.getAcc()) {
+          tmpC = adaptor.getAcc()[i++];
+          operands.push_back(tmpC);
+        }
+
+        ArrayRef<int64_t> aVecShape =
+            llvm::cast<VectorType>(aVec.getType()).getShape();
+        ArrayRef<int64_t> bVecShape =
+            llvm::cast<VectorType>(bVec.getType()).getShape();
+        VectorType resTy = VectorType::get({aVecShape[0], bVecShape[1]},
+                                           resultTy.getElementType());
+        tmpC = rewriter.create<xegpu::DpasOp>(
+            loc, resTy, operands,
+            llvm::ArrayRef<NamedAttribute>(
+                {"layout", originalLayout.dropSgLayoutAndData()}));
+        newDpasOps.push_back(tmpC);
+      }
+    }
+    rewriter.replaceOpWithMultiple(op, {newDpasOps});
+    return mlir::success();
+  }
+};
+
+} // namespace
+
+namespace mlir {
+namespace xegpu {
+void populateXeGPUWgToSgPatterns(RewritePatternSet &patterns) {
+  patterns.add<WgToSgCreateNdOp, WgToSgLoadNdOp, WgToSgStoreNdOp,
+               WgToSgUpdateNdOffsetOp, WgToSgDpasOp>(patterns.getContext());
+}
+} // namespace xegpu
+} // namespace mlir
+
+namespace {
+struct XeGPUWgToSgPass : public xegpu::impl::XeGPUWgToSgBase<XeGPUWgToSgPass> {
+  void runOnOperation() override;
+};
+} // namespace
+
+void XeGPUWgToSgPass::runOnOperation() {
+  MLIRContext *ctx = &getContext();
+  RewritePatternSet patterns(ctx);
+  ConversionTarget target(*ctx);
+
+  auto getTensorDescType = [](Operation *op) -> xegpu::TensorDescType {
+    if (auto createOp = dyn_cast<xegpu::CreateNdDescOp>(op))
+      return createOp.getType();
+    if (auto loadOp = dyn_cast<xegpu::LoadNdOp>(op))
+      return loadOp.getTensorDescType();
+    if (auto storeOp = dyn_cast<xegpu::StoreNdOp>(op))
+      return storeOp.getTensorDescType();
+    if (auto updateOp = dyn_cast<xegpu::UpdateNdOffsetOp>(op))
+      return updateOp.getType();
+    return xegpu::TensorDescType();
+  };
+
+  auto isLegal = [&](xegpu::LayoutAttr layout) -> bool {
+    return !layout || layout.getSgLayout() == nullptr;
+  };
+
+  target.addDynamicallyLegalOp<xegpu::CreateNdDescOp, xegpu::LoadNdOp,
+                               xegpu::StoreNdOp, xegpu::UpdateNdOffsetOp>(
+      [=](Operation *op) -> bool {
+        auto tdescTy = getTensorDescType(op);
+        auto layout = dyn_cast_or_null<xegpu::LayoutAttr>(tdescTy.getLayout());
+        return isLegal(layout);
+      });
+
+  target.addDynamicallyLegalOp<xegpu::DpasOp>([=](xegpu::DpasOp op) -> bool {
+    auto layout = dyn_cast_or_null<xegpu::LayoutAttr>(op->getAttr("layout"));
+    return isLegal(layout);
+  });
+
+  target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
+
+  xegpu::populateXeGPUWgToSgPatterns(patterns);
+  if (failed(
+          applyPartialConversion(getOperation(), target, std::move(patterns))))
+    return signalPassFailure();
+}
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
new file mode 100644
index 0000000000000..d0f225c3e7304
--- /dev/null
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
@@ -0,0 +1,65 @@
+// RUN: mlir-opt --xegpu-wg-to-sg -split-input-file %s | FileCheck %s
+
+gpu.module @test_round_robin_assignment {
+  // CHECK: test_create_nd_tdesc
+  // CHECK: %[[ARG_0:.*]]: memref<24x32xf32>
+  gpu.func @test_create_nd_tdesc(%src: memref<24x32xf32>) {
+      // CHECK-COUNT-12: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][%{{.*}}, %{{.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
+      %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+      gpu.return
+    }
+
+  // CHECK: test_load_nd_tdesc
+  // CHECK: %[[ARG_0:.*]]: memref<24x32xf32>
+  gpu.func @test_load_nd_tdesc(%src: memref<24x32xf32>) {
+      %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+      // CHECK-COUNT-12: %[[LOAD:.*]] = xegpu.load_nd %{{.*}} : !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>> -> vector<2x2xf32>
+      %load =  xegpu.load_nd %tdesc: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>> -> vector<24x32xf32>
+      gpu.return
+    }
+
+  // CHECK: test_store_nd
+  // CHECK: %[[ARG_0:.*]]: memref<24x32xf32>
+  gpu.func @test_store_nd(%src: memref<24x32xf32>) {
+      %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+      // CHECK-COUNT-12: xegpu.store_nd %{{.*}}, %{{.*}} : vector<2x2xf32>, !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
+      %load = xegpu.load_nd %tdesc: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>> -> vector<24x32xf32>
+      xegpu.store_nd %load, %tdesc: vector<24x32xf32>, !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+      gpu.return
+  }
+
+  // CHECK: test_update_nd
+  // CHECK: %[[ARG_0:.*]]: memref<24x32xf32>
+  gpu.func @test_update_nd(%src: memref<24x32xf32>){
+    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->  !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+    // CHECK-COUNT-12: %[[UPDATE:.*]] = xegpu.update_nd_offset %{{.*}}, [0, 16] : !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
+    %update = xegpu.update_nd_offset %tdesc, [0, 16] :  !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+    gpu.return
+  }
+
+  // CHECK: test_dpas
+  // CHECK: %[[ARG_0:.*]]: memref<24x32xf32>
+  // CHECK: %[[ARG_1:.*]]: memref<32x24xf32>
+  // CHECK: %[[ARG_2:.*]]: memref<24x24xf32>
+  gpu.func @test_dpas(%a: memref<24x32xf32>, %b: memref<32x24xf32>, %c: memref<24x24xf32>) {
+    // CHECK-COUNT-12: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][%{{.*}},
+    // %{{.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<2x2xf32,
+    // #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>> CHECK-COUNT-12:
+    // %[[TDESC1:.*]] = xegpu.create_nd_tdesc %[[ARG_1]][%{{.*}}, %{{.*}}] :
+    // memref<32x24xf32> -> !xegpu.tensor_desc<2x2xf32,
+    // #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>> CHECK-COUNT-9:
+    // %[[TDESC2:.*]] = xegpu.create_nd_tdesc %{{.*}}[%{{.*}}, %{{.*}}] :
+    // memref<24x24xf32> -> !xegpu.tensor_desc<2x2xf32,
+    // #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>> CHECK-COUNT-144:
+    // %[[DPAS:.*]] = xegpu.dpas %{{.*}}, %{{.*}} {layout =
+    // #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} :
+    // vector<2x2xf32>, vector<2x2xf32> -> vector<2x2xf32>
+    %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+    %load_a =  xegpu.load_nd %tdesc_a: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>> -> vector<24x32xf32>
+    %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<32x24xf32> -> !xegpu.tensor_desc<32x24xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+    %load_b =  xegpu.load_nd %tdesc_b:  !xegpu.tensor_desc<32x24xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>> -> vector<32x24xf32>
+    %tdesc_c = xegpu.create_nd_tdesc %c[0, 0] : memref<24x24xf32> -> !xegpu.tensor_desc<24x24xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+    %dpas = xegpu.dpas %load_a, %load_b {layout =  #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>} : vector<24x32xf32>, vector<32x24xf32> -> vector<24x24xf32>
+    gpu.return
+  }
+}
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
new file mode 100644
index 0000000000000..c4c8881e65597
--- /dev/null
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
@@ -0,0 +1,81 @@
+// RUN: mlir-opt --xegpu-wg-to-sg -split-input-file %s | FileCheck %s
+
+gpu.module @test_1_1_assignment {
+  // CHECK: test_create_nd_tdesc
+  // CHECK: %[[ARG_0:.*]]: memref<24x32xf32>
+  gpu.func @test_create_nd_tdesc(%src: memref<24x32xf32>) {  
+  // CHECK: %[[SGID:.*]] = gpu.subgroup_id
+  // CHECK: %[[C12:.*]] = arith.constant 12 : index
+  // CHECK: %[[C4:.*]] = arith.constant 4 : index
+  // CHECK: %[[C8:.*]] = arith.constant 8 : index
+  // CHECK: %[[DIV:.*]] = index.divu %[[SGID]], %[[C4]]
+  // CHECK: %[[REM:.*]] = index.remu %[[SGID]], %[[C4]]
+  // CHECK: %[[MUL1:.*]] = index.mul %[[DIV]], %[[C12]]
+  // CHECK: %[[MUL2:.*]] = index.mul %[[REM]], %[[C8]]
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK: %[[ADD1:.*]] = index.add %[[MUL1]], %[[C0]]
+  // CHECK: %[[ADD2:.*]] = index.add %[[MUL2]], %[[C0]]
+  // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][%[[ADD1]], %[[ADD2]]] : memref<24x32xf32> -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
+  // CHECK: gpu.return
+  %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+  gpu.return
+  }
+
+  // CHECK: test_load_nd_tdesc
+  // CHECK: %[[ARG_0:.*]]: memref<24x32xf32>
+  gpu.func @test_load_nd_tdesc(%src: memref<24x32xf32>) {
+    // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
+    // CHECK: %[[LOAD:.*]] = xegpu.load_nd %[[TDESC]] : !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>> -> vector<12x8xf32>
+    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+    %load =  xegpu.load_nd %tdesc: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>> -> vector<24x32xf32>
+    gpu.return
+  }
+
+  // CHECK: test_store_nd
+  // CHECK: %[[ARG_0:.*]]: memref<24x32xf32>
+  gpu.func @test_store_nd(%src: memref<24x32xf32>) {
+    // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
+    // CHECK: %[[LOAD:.*]] = xegpu.load_nd %[[TDESC]] : !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>> -> vector<12x8xf32>
+    // CHECK: xegpu.store_nd %[[LOAD]], %[[TDESC]] : vector<12x8xf32>, !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
+    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+    %load = xegpu.load_nd %tdesc: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>> -> vector<24x32xf32>
+    xegpu.store_nd %load, %tdesc: vector<24x32xf32>, !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+    gpu.return
+}
+
+// CHECK: test_update_nd
+// CHECK: %[[ARG_0:.*]]: memref<24x32xf32>
+gpu.func @test_update_nd(%src: memref<24x32xf32>){
+  // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
+  // CHECK: %[[UPDATE:.*]] = xegpu.update_nd_offset %[[TDESC]], [0, 16] : !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
+  %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+  %update = xegpu.update_nd_offset %tdesc, [0, 16] : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+  gpu.return
+}
+
+// CHECK: test_dpas
+// CHECK: %[[ARG_0:.*]]: memref<24x32xf32>
+// CHECK: %[[ARG_1:.*]]: memref<32x24xf32>
+gpu.func @test_dpas(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
+    // CHECK: %[[TDESC_A:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}},
+    // {{%.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<12x8xf32,
+    // #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>> CHECK:
+    // %[[LOAD_A:.*]] = xegpu.load_nd %[[TDESC_A]] :
+    // !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8],
+    // lane_data = [1, 1]>> -> vector<12x8xf32> CHECK: %[[TDESC_B:.*]] =
+    // xegpu.create_nd_tdesc %[[ARG_1]][{{%.*}}, {{%.*}}] : memref<32x24xf32> ->
+    // !xegpu.tensor_desc<8x12xf32, #xegpu.layout<lane_layout = [8, 2],
+    // lane_data = [1, 1]>> CHECK: %[[LOAD_B:.*]] = xegpu.load_nd %[[TDESC_B]] :
+    // !xegpu.tensor_desc<8x12xf32, #xegpu.layout<lane_layout = [8, 2],
+    // lane_data = [1, 1]>> -> vector<8x12xf32> CHECK: %[[DPAS:.*]] = xegpu.dpas
+    // %[[LOAD_A]], %[[LOAD_B]] {layout =  #xegpu.layout<lane_layout = [2, 2],
+    // lane_data = [1, 1]>} : vector<12x8xf32>, vector<8x12xf32> ->
+    // vector<12x12xf32>
+    %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+    %load_a =  xegpu.load_nd %tdesc_a: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>> -> vector<24x32xf32>
+    %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<32x24xf32> -> !xegpu.tensor_desc<32x24xf32, #xegpu.layout<sg_layout = [4, 2], sg_data = [8, 12], lane_layout = [8, 2], lane_data = [1, 1]>>
+    %load_b =  xegpu.load_nd %tdesc_b: !xegpu.tensor_desc<32x24xf32, #xegpu.layout<sg_layout = [4, 2], sg_data = [8, 12], lane_layout = [8, 2], lane_data = [1, 1]>> -> vector<32x24xf32>
+    %dpas = xegpu.dpas %load_a, %load_b {layout =  #xegpu.layout<sg_layout = [2, 2], sg_data = [12, 12], lane_layout = [2, 2], lane_data = [1, 1]>} : vector<24x32xf32>, vector<32x24xf32> -> vector<24x24xf32>
+    gpu.return
+  }
+}

>From b3bf12f082eb08aa3f82503142140fc686e0e950 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Sun, 11 May 2025 15:49:35 +0000
Subject: [PATCH 02/18] Add prefetch_nd op

---
 .../Dialect/XeGPU/Transforms/XeGPUWgToSg.cpp  | 52 ++++++++++++-------
 .../test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir | 25 ++++-----
 mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir   | 29 ++++++-----
 3 files changed, 60 insertions(+), 46 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSg.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSg.cpp
index 7969d37d67f04..5eabb04e3b858 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSg.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSg.cpp
@@ -34,11 +34,10 @@ using namespace mlir;
 namespace {
 
 // clang-format off
-/// This pattern transform the CreateNdDescOp to create a subgroup descriptor
+/// This pattern transforms the CreateNdDescOp to create a subgroup descriptor
 /// from a workgroup descriptor. It replaces the offsets and sizes with
 /// appropriate values for the subgroup.
-/// It uses round-robin distribution to create the subgroup descriptor.
-
+/// It uses round-robin assignment to distribute the work to the subgroups.
 /// The following create_nd_tdesc operation:
 ///    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x24xf32>
 ///       -> !xegpu.tensor_desc<24x24xf32, #xegpu.layout<sg_layout = [4, 4],
@@ -47,7 +46,7 @@ namespace {
 ///    %tdesc = xegpu.create_nd_tdesc %src[off1, off2] : memref<24x24xf32> ->
 ///           !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
 ///
-/// The sg_layout and sg_data are dropped from the layout attribute as they are no longer needed.
+/// The sg_layout and sg_data attributes are dropped after the pass as they are no longer needed.
 ///
 /// 24x24 matrix distribution example:
 /// sg_layout = [4, 4], sg_data = [2, 2]
@@ -72,7 +71,6 @@ namespace {
 ///
 /// Since the 24x24 matrix is divided into 8x8 distribution units, there will be 9
 /// distribution units (3x3) in total. Hence the 9 subgroup level operations.
-/// Each 8x8 matrix within the 24x24 matrix is called a distribution unit.
 // clang-format on
 struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
   using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
@@ -110,7 +108,7 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
     return rewriter.create<arith::ConstantIndexOp>(loc, value);
   }
 
-  // Calculate global offset for each subgroup
+  // Calculate offset for each subgroup
   SmallVector<OpFoldResult>
   calculateGlobalOffsets(ConversionPatternRewriter &rewriter, Location loc,
                          const SmallVector<Value> &originalOffsets,
@@ -122,13 +120,11 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
     Value constOffsetY =
         createConstantIndex(rewriter, loc, distUnitBaseAddr[1]);
 
-    // Compute offsets within entire tile
     Value offsetX =
         rewriter.createOrFold<index::AddOp>(loc, localOffset[0], constOffsetX);
     Value offsetY =
         rewriter.createOrFold<index::AddOp>(loc, localOffset[1], constOffsetY);
 
-    // Add to global offsets
     size_t lastDimIndex = originalOffsets.size() - 1;
     size_t secondLastDimIndex = lastDimIndex - 1;
 
@@ -137,7 +133,6 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
     Value globalOffsetY = rewriter.createOrFold<index::AddOp>(
         loc, originalOffsets[lastDimIndex], offsetY);
 
-    // Create final offset list
     SmallVector<OpFoldResult> globalOffsets(originalOffsets.begin(),
                                             originalOffsets.end());
     globalOffsets[secondLastDimIndex] = globalOffsetX;
@@ -172,7 +167,7 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
       sgDataDim[i] = createConstantIndex(rewriter, loc, sgShape[i]);
     }
 
-    // Delinearize the 1D subgroup id into nd coordinates
+    // Delinearize the 1D subgroup id into 2D
     SmallVector<Value> sgIds = delinearizeSubgroupId(
         rewriter, loc, linearSgId, sgLayoutDim[0], sgLayoutDim[1]);
 
@@ -207,8 +202,7 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
   }
 };
 
-/// This pattern transforms the LoadNdOp to load from a subgroup descriptor
-/// It creates a LoadNdOp op to load the new subgroup src tensor descriptors.
+/// This pattern transforms the LoadNdOp to load subgroup data.
 struct WgToSgLoadNdOp : public OpConversionPattern<xegpu::LoadNdOp> {
   using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
   LogicalResult
@@ -310,7 +304,22 @@ struct WgToSgDpasOp : public OpConversionPattern<xegpu::DpasOp> {
       }
     }
     rewriter.replaceOpWithMultiple(op, {newDpasOps});
-    return mlir::success();
+    return success();
+  }
+};
+
+/// This pattern transforms the PrefetchNdOp to prefetch the subgroup data.
+struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
+  using OpConversionPattern<xegpu::PrefetchNdOp>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(xegpu::PrefetchNdOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    for (auto src : adaptor.getTensorDesc()) {
+      rewriter.create<xegpu::PrefetchNdOp>(op.getLoc(), TypeRange(), src,
+                                           op->getAttrs());
+    }
+    rewriter.eraseOp(op);
+    return success();
   }
 };
 
@@ -320,7 +329,8 @@ namespace mlir {
 namespace xegpu {
 void populateXeGPUWgToSgPatterns(RewritePatternSet &patterns) {
   patterns.add<WgToSgCreateNdOp, WgToSgLoadNdOp, WgToSgStoreNdOp,
-               WgToSgUpdateNdOffsetOp, WgToSgDpasOp>(patterns.getContext());
+               WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp>(
+      patterns.getContext());
 }
 } // namespace xegpu
 } // namespace mlir
@@ -345,6 +355,8 @@ void XeGPUWgToSgPass::runOnOperation() {
       return storeOp.getTensorDescType();
     if (auto updateOp = dyn_cast<xegpu::UpdateNdOffsetOp>(op))
       return updateOp.getType();
+    if (auto prefetchOp = dyn_cast<xegpu::PrefetchNdOp>(op))
+      return prefetchOp.getTensorDescType();
     return xegpu::TensorDescType();
   };
 
@@ -353,12 +365,12 @@ void XeGPUWgToSgPass::runOnOperation() {
   };
 
   target.addDynamicallyLegalOp<xegpu::CreateNdDescOp, xegpu::LoadNdOp,
-                               xegpu::StoreNdOp, xegpu::UpdateNdOffsetOp>(
-      [=](Operation *op) -> bool {
-        auto tdescTy = getTensorDescType(op);
-        auto layout = dyn_cast_or_null<xegpu::LayoutAttr>(tdescTy.getLayout());
-        return isLegal(layout);
-      });
+                               xegpu::StoreNdOp, xegpu::UpdateNdOffsetOp,
+                               xegpu::PrefetchNdOp>([=](Operation *op) -> bool {
+    auto tdescTy = getTensorDescType(op);
+    auto layout = dyn_cast_or_null<xegpu::LayoutAttr>(tdescTy.getLayout());
+    return isLegal(layout);
+  });
 
   target.addDynamicallyLegalOp<xegpu::DpasOp>([=](xegpu::DpasOp op) -> bool {
     auto layout = dyn_cast_or_null<xegpu::LayoutAttr>(op->getAttr("layout"));
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
index d0f225c3e7304..de2c548ec7ebb 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
@@ -42,18 +42,10 @@ gpu.module @test_round_robin_assignment {
   // CHECK: %[[ARG_1:.*]]: memref<32x24xf32>
   // CHECK: %[[ARG_2:.*]]: memref<24x24xf32>
   gpu.func @test_dpas(%a: memref<24x32xf32>, %b: memref<32x24xf32>, %c: memref<24x24xf32>) {
-    // CHECK-COUNT-12: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][%{{.*}},
-    // %{{.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<2x2xf32,
-    // #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>> CHECK-COUNT-12:
-    // %[[TDESC1:.*]] = xegpu.create_nd_tdesc %[[ARG_1]][%{{.*}}, %{{.*}}] :
-    // memref<32x24xf32> -> !xegpu.tensor_desc<2x2xf32,
-    // #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>> CHECK-COUNT-9:
-    // %[[TDESC2:.*]] = xegpu.create_nd_tdesc %{{.*}}[%{{.*}}, %{{.*}}] :
-    // memref<24x24xf32> -> !xegpu.tensor_desc<2x2xf32,
-    // #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>> CHECK-COUNT-144:
-    // %[[DPAS:.*]] = xegpu.dpas %{{.*}}, %{{.*}} {layout =
-    // #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} :
-    // vector<2x2xf32>, vector<2x2xf32> -> vector<2x2xf32>
+    // CHECK-COUNT-12: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][%{{.*}}, %{{.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
+    // CHECK-COUNT-12: %[[TDESC1:.*]] = xegpu.create_nd_tdesc %[[ARG_1]][%{{.*}}, %{{.*}}] : memref<32x24xf32> -> !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
+    // CHECK-COUNT-9: %[[TDESC2:.*]] = xegpu.create_nd_tdesc %{{.*}}[%{{.*}}, %{{.*}}] : memref<24x24xf32> -> !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
+    // CHECK-COUNT-144: %[[DPAS:.*]] = xegpu.dpas %{{.*}}, %{{.*}} {layout = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>, vector<2x2xf32> -> vector<2x2xf32>
     %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
     %load_a =  xegpu.load_nd %tdesc_a: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>> -> vector<24x32xf32>
     %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<32x24xf32> -> !xegpu.tensor_desc<32x24xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
@@ -62,4 +54,13 @@ gpu.module @test_round_robin_assignment {
     %dpas = xegpu.dpas %load_a, %load_b {layout =  #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>} : vector<24x32xf32>, vector<32x24xf32> -> vector<24x24xf32>
     gpu.return
   }
+
+  // CHECK: test_prefetch_nd_tdesc
+  // CHECK: %[[ARG_0:.*]]: memref<24x32xf32>
+  gpu.func @test_prefetch_nd_tdesc(%src: memref<24x32xf32>) {
+    // CHECK-COUNT-12: xegpu.prefetch_nd %{{.*}} : !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
+    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+    xegpu.prefetch_nd %tdesc: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+    gpu.return
+  }
 }
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
index c4c8881e65597..1cae2c822d826 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
@@ -57,20 +57,11 @@ gpu.func @test_update_nd(%src: memref<24x32xf32>){
 // CHECK: %[[ARG_0:.*]]: memref<24x32xf32>
 // CHECK: %[[ARG_1:.*]]: memref<32x24xf32>
 gpu.func @test_dpas(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
-    // CHECK: %[[TDESC_A:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}},
-    // {{%.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<12x8xf32,
-    // #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>> CHECK:
-    // %[[LOAD_A:.*]] = xegpu.load_nd %[[TDESC_A]] :
-    // !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8],
-    // lane_data = [1, 1]>> -> vector<12x8xf32> CHECK: %[[TDESC_B:.*]] =
-    // xegpu.create_nd_tdesc %[[ARG_1]][{{%.*}}, {{%.*}}] : memref<32x24xf32> ->
-    // !xegpu.tensor_desc<8x12xf32, #xegpu.layout<lane_layout = [8, 2],
-    // lane_data = [1, 1]>> CHECK: %[[LOAD_B:.*]] = xegpu.load_nd %[[TDESC_B]] :
-    // !xegpu.tensor_desc<8x12xf32, #xegpu.layout<lane_layout = [8, 2],
-    // lane_data = [1, 1]>> -> vector<8x12xf32> CHECK: %[[DPAS:.*]] = xegpu.dpas
-    // %[[LOAD_A]], %[[LOAD_B]] {layout =  #xegpu.layout<lane_layout = [2, 2],
-    // lane_data = [1, 1]>} : vector<12x8xf32>, vector<8x12xf32> ->
-    // vector<12x12xf32>
+    // CHECK: %[[TDESC_A:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>> 
+    // CHECK: %[[LOAD_A:.*]] = xegpu.load_nd %[[TDESC_A]] : !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>> -> vector<12x8xf32>
+    // CHECK: %[[TDESC_B:.*]] = xegpu.create_nd_tdesc %[[ARG_1]][{{%.*}}, {{%.*}}] : memref<32x24xf32> -> !xegpu.tensor_desc<8x12xf32, #xegpu.layout<lane_layout = [8, 2], lane_data = [1, 1]>> 
+    // CHECK: %[[LOAD_B:.*]] = xegpu.load_nd %[[TDESC_B]] : !xegpu.tensor_desc<8x12xf32, #xegpu.layout<lane_layout = [8, 2], lane_data = [1, 1]>> -> vector<8x12xf32>
+    // CHECK: %[[DPAS:.*]] = xegpu.dpas %[[LOAD_A]], %[[LOAD_B]] {layout =  #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<12x8xf32>, vector<8x12xf32> -> vector<12x12xf32>
     %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
     %load_a =  xegpu.load_nd %tdesc_a: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>> -> vector<24x32xf32>
     %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<32x24xf32> -> !xegpu.tensor_desc<32x24xf32, #xegpu.layout<sg_layout = [4, 2], sg_data = [8, 12], lane_layout = [8, 2], lane_data = [1, 1]>>
@@ -78,4 +69,14 @@ gpu.func @test_dpas(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
     %dpas = xegpu.dpas %load_a, %load_b {layout =  #xegpu.layout<sg_layout = [2, 2], sg_data = [12, 12], lane_layout = [2, 2], lane_data = [1, 1]>} : vector<24x32xf32>, vector<32x24xf32> -> vector<24x24xf32>
     gpu.return
   }
+
+  // CHECK: test_prefetch_nd_tdesc
+  // CHECK: %[[ARG_0:.*]]: memref<24x32xf32>
+  gpu.func @test_prefetch_nd_tdesc(%src: memref<24x32xf32>) {
+    // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
+    // CHECK: xegpu.prefetch_nd %[[TDESC]] : !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
+    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+    xegpu.prefetch_nd %tdesc: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+    gpu.return
+  }
 }

>From 6a8647fa764e710f5aaeb51b46ae2ea398a959a3 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Sun, 11 May 2025 22:06:24 +0000
Subject: [PATCH 03/18] Remove braces for single statement for and if

---
 mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSg.cpp | 8 +++-----
 1 file changed, 3 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSg.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSg.cpp
index 5eabb04e3b858..836f307ece9e1 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSg.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSg.cpp
@@ -83,12 +83,11 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
     auto dynamicOffsets = op.getOffsets();
 
     for (size_t i = 0, j = 0; i != staticOffsets.size(); i++) {
-      if (ShapedType::isDynamic(staticOffsets[i])) {
+      if (ShapedType::isDynamic(staticOffsets[i]))
         offsets.push_back(dynamicOffsets[j++]);
-      } else {
+      else
         offsets.push_back(rewriter.create<arith::ConstantIndexOp>(
             op.getLoc(), staticOffsets[i]));
-      }
     }
     return offsets;
   }
@@ -314,10 +313,9 @@ struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
   LogicalResult
   matchAndRewrite(xegpu::PrefetchNdOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    for (auto src : adaptor.getTensorDesc()) {
+    for (auto src : adaptor.getTensorDesc())
       rewriter.create<xegpu::PrefetchNdOp>(op.getLoc(), TypeRange(), src,
                                            op->getAttrs());
-    }
     rewriter.eraseOp(op);
     return success();
   }

>From c6589299e4e1375e91b61fbf0edb9f3d1f7a89c4 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Mon, 12 May 2025 14:30:54 +0000
Subject: [PATCH 04/18] Clean up

---
 mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h | 5 +----
 mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSg.cpp       | 3 +++
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h b/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
index 388ba32e1eebb..5c12973edbed8 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
@@ -9,8 +9,6 @@
 #ifndef MLIR_DIALECT_XEGPU_TRANSFORMS_TRANSFORMS_H
 #define MLIR_DIALECT_XEGPU_TRANSFORMS_TRANSFORMS_H
 
-#include "mlir/Transforms/DialectConversion.h"
-
 namespace mlir {
 class RewritePatternSet;
 
@@ -20,8 +18,7 @@ namespace xegpu {
 void populateXeGPUFoldAliasOpsPatterns(RewritePatternSet &patterns);
 /// Appends patterns for XeGPU SIMT distribution into `patterns`.
 void populateXeGPUSubgroupDistributePatterns(RewritePatternSet &patterns);
-void populateXeGPUWgToSgPatterns(RewritePatternSet &patterns,
-                                 ConversionTarget &target);
+void populateXeGPUWgToSgPatterns(RewritePatternSet &patterns);
 
 } // namespace xegpu
 } // namespace mlir
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSg.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSg.cpp
index 836f307ece9e1..512cdca251c42 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSg.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSg.cpp
@@ -147,6 +147,8 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
     MLIRContext *ctx = op.getContext();
     xegpu::TensorDescType tdescTy = op.getType();
     auto layout = dyn_cast<xegpu::LayoutAttr>(tdescTy.getLayout());
+    if (!layout)
+      return failure();
     Type elemTy = tdescTy.getElementType();
     ArrayRef<int64_t> wgShape = tdescTy.getShape();
     ArrayRef<int64_t> sgShape =
@@ -154,6 +156,7 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
     ArrayRef<int64_t> sgLayout =
         llvm::to_vector_of<int64_t>(layout.getSgLayout().asArrayRef());
 
+    // TODO: Handle order attribute
     // Get the subgroup ID
     auto linearSgId = rewriter.create<gpu::SubgroupIdOp>(loc, nullptr);
 

>From 7f4e202ef2dbca83f19fe69eb486b315bf2d1853 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Tue, 13 May 2025 20:59:54 +0000
Subject: [PATCH 05/18] Fix CI

---
 mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSg.cpp | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSg.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSg.cpp
index 512cdca251c42..68a5f7faa2fbd 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSg.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSg.cpp
@@ -151,9 +151,9 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
       return failure();
     Type elemTy = tdescTy.getElementType();
     ArrayRef<int64_t> wgShape = tdescTy.getShape();
-    ArrayRef<int64_t> sgShape =
+    SmallVector<int64_t> sgShape =
         llvm::to_vector_of<int64_t>(layout.getSgData().asArrayRef());
-    ArrayRef<int64_t> sgLayout =
+    SmallVector<int64_t> sgLayout =
         llvm::to_vector_of<int64_t>(layout.getSgLayout().asArrayRef());
 
     // TODO: Handle order attribute
@@ -188,7 +188,7 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
         xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(),
                                    layout.dropSgLayoutAndData());
     SmallVector<Value> newCreateNdOps;
-    for (const SmallVector<int64_t> &distUnitBaseAddr :
+    for (SmallVector<int64_t> distUnitBaseAddr :
          StaticTileOffsetRange(wgShape, distUnitShape)) {
       SmallVector<OpFoldResult> globalOffsets = calculateGlobalOffsets(
           rewriter, loc, originalOffsets, localOffset, distUnitBaseAddr);

>From 2153a8a281726cf31b60f973907842e790bddc64 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Wed, 14 May 2025 18:26:48 +0000
Subject: [PATCH 06/18] Address feedback

---
 .../Dialect/XeGPU/Transforms/XeGPUWgToSg.cpp  | 21 +++--
 .../test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir | 66 +++++++++-------
 mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir   | 76 ++++++++++++-------
 3 files changed, 94 insertions(+), 69 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSg.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSg.cpp
index 68a5f7faa2fbd..f8478289d3b91 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSg.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSg.cpp
@@ -1,4 +1,4 @@
-//===- XeGPUWgToSg.cpp - XeGPU WorkGroup to Subgroup Pass -------===//
+//===- XeGPUWgToSg.cpp - XeGPU Workgroup to Subgroup Pass -----------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -25,15 +25,10 @@ namespace xegpu {
 } // namespace xegpu
 } // namespace mlir
 
-#define DEBUG_TYPE "xegpu-wg-to-sg"
-#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
-#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
-
 using namespace mlir;
 
 namespace {
 
-// clang-format off
 /// This pattern transforms the CreateNdDescOp to create a subgroup descriptor
 /// from a workgroup descriptor. It replaces the offsets and sizes with
 /// appropriate values for the subgroup.
@@ -42,11 +37,14 @@ namespace {
 ///    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x24xf32>
 ///       -> !xegpu.tensor_desc<24x24xf32, #xegpu.layout<sg_layout = [4, 4],
 ///           sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
-/// is converted to 9 subgroup level operations based on the sg_layout & sg_data:
+/// is converted to 9 subgroup level operations based on the sg_layout &
+/// sg_data:
 ///    %tdesc = xegpu.create_nd_tdesc %src[off1, off2] : memref<24x24xf32> ->
-///           !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
+///           !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2],
+///           lane_data = [1, 1]>>
 ///
-/// The sg_layout and sg_data attributes are dropped after the pass as they are no longer needed.
+/// The sg_layout and sg_data attributes are dropped after the pass as they are
+/// no longer needed.
 ///
 /// 24x24 matrix distribution example:
 /// sg_layout = [4, 4], sg_data = [2, 2]
@@ -69,9 +67,8 @@ namespace {
 /// | 2x2 2x2 2x2 2x2 |
 /// +------------------------+
 ///
-/// Since the 24x24 matrix is divided into 8x8 distribution units, there will be 9
-/// distribution units (3x3) in total. Hence the 9 subgroup level operations.
-// clang-format on
+/// Since the 24x24 matrix is divided into 8x8 distribution units, there will be
+/// 9 distribution units (3x3) in total. Hence the 9 subgroup level operations.
 struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
   using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
 
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
index de2c548ec7ebb..3096759e3ac8c 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
@@ -1,62 +1,70 @@
 // RUN: mlir-opt --xegpu-wg-to-sg -split-input-file %s | FileCheck %s
 
 gpu.module @test_round_robin_assignment {
-  // CHECK: test_create_nd_tdesc
-  // CHECK: %[[ARG_0:.*]]: memref<24x32xf32>
+  // CHECK-LABEL: test_create_nd_tdesc
+  // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
   gpu.func @test_create_nd_tdesc(%src: memref<24x32xf32>) {
-      // CHECK-COUNT-12: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][%{{.*}}, %{{.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
+      // CHECK-COUNT-12: xegpu.create_nd_tdesc %[[ARG_0]][%{{.*}}, %{{.*}}] : memref<24x32xf32>
+      // CHECK-SAME: -> !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
       %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
       gpu.return
     }
 
-  // CHECK: test_load_nd_tdesc
-  // CHECK: %[[ARG_0:.*]]: memref<24x32xf32>
+  // CHECK-LABEL: test_load_nd_tdesc
+  // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
   gpu.func @test_load_nd_tdesc(%src: memref<24x32xf32>) {
       %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
-      // CHECK-COUNT-12: %[[LOAD:.*]] = xegpu.load_nd %{{.*}} : !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>> -> vector<2x2xf32>
+      // CHECK-COUNT-12: xegpu.load_nd %{{.*}}
+      // CHECK-SAME-COUNT-12: : !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
+      // CHECK-SAME-COUNT-12: -> vector<2x2xf32>
       %load =  xegpu.load_nd %tdesc: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>> -> vector<24x32xf32>
       gpu.return
     }
 
-  // CHECK: test_store_nd
-  // CHECK: %[[ARG_0:.*]]: memref<24x32xf32>
+  // CHECK-LABEL: test_store_nd
+  // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
   gpu.func @test_store_nd(%src: memref<24x32xf32>) {
       %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
-      // CHECK-COUNT-12: xegpu.store_nd %{{.*}}, %{{.*}} : vector<2x2xf32>, !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
+      // CHECK-COUNT-12: xegpu.store_nd %{{.*}}, %{{.*}}
+      // CHECK-SAME-COUNT-12: : vector<2x2xf32>, !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
       %load = xegpu.load_nd %tdesc: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>> -> vector<24x32xf32>
       xegpu.store_nd %load, %tdesc: vector<24x32xf32>, !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
       gpu.return
   }
 
-  // CHECK: test_update_nd
-  // CHECK: %[[ARG_0:.*]]: memref<24x32xf32>
+  // CHECK-LABEL: test_update_nd
+  // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
   gpu.func @test_update_nd(%src: memref<24x32xf32>){
     %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->  !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
-    // CHECK-COUNT-12: %[[UPDATE:.*]] = xegpu.update_nd_offset %{{.*}}, [0, 16] : !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
+    // CHECK-COUNT-12: xegpu.update_nd_offset %{{.*}}, [0, 16]
+    // CHECK-SAME-COUNT-12: : !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
     %update = xegpu.update_nd_offset %tdesc, [0, 16] :  !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
     gpu.return
   }
 
-  // CHECK: test_dpas
-  // CHECK: %[[ARG_0:.*]]: memref<24x32xf32>
-  // CHECK: %[[ARG_1:.*]]: memref<32x24xf32>
-  // CHECK: %[[ARG_2:.*]]: memref<24x24xf32>
-  gpu.func @test_dpas(%a: memref<24x32xf32>, %b: memref<32x24xf32>, %c: memref<24x24xf32>) {
-    // CHECK-COUNT-12: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][%{{.*}}, %{{.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
-    // CHECK-COUNT-12: %[[TDESC1:.*]] = xegpu.create_nd_tdesc %[[ARG_1]][%{{.*}}, %{{.*}}] : memref<32x24xf32> -> !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
-    // CHECK-COUNT-9: %[[TDESC2:.*]] = xegpu.create_nd_tdesc %{{.*}}[%{{.*}}, %{{.*}}] : memref<24x24xf32> -> !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
-    // CHECK-COUNT-144: %[[DPAS:.*]] = xegpu.dpas %{{.*}}, %{{.*}} {layout = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>, vector<2x2xf32> -> vector<2x2xf32>
-    %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
-    %load_a =  xegpu.load_nd %tdesc_a: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>> -> vector<24x32xf32>
-    %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<32x24xf32> -> !xegpu.tensor_desc<32x24xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
-    %load_b =  xegpu.load_nd %tdesc_b:  !xegpu.tensor_desc<32x24xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>> -> vector<32x24xf32>
-    %tdesc_c = xegpu.create_nd_tdesc %c[0, 0] : memref<24x24xf32> -> !xegpu.tensor_desc<24x24xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
-    %dpas = xegpu.dpas %load_a, %load_b {layout =  #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>} : vector<24x32xf32>, vector<32x24xf32> -> vector<24x24xf32>
+  // CHECK-LABEL: test_dpas
+  // CHECK-SAME: (%[[ARG_0:.*]]: memref<8x8xf32>, %[[ARG_1:.*]]: memref<8x8xf32>, %[[ARG_2:.*]]: memref<8x8xf32>)
+  gpu.func @test_dpas(%a: memref<8x8xf32>, %b: memref<8x8xf32>, %c: memref<8x8xf32>) {
+    // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_0]][%{{.*}}, %{{.*}}] : memref<8x8xf32>
+    // CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
+    // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_1]][%{{.*}}, %{{.*}}] : memref<8x8xf32>
+    // CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
+    // CHECK-COUNT-4:  xegpu.create_nd_tdesc %{{.*}}[%{{.*}}, %{{.*}}] : memref<8x8xf32>
+    // CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
+    // CHECK-COUNT-16: xegpu.dpas %{{.*}}, %{{.*}}
+    // CHECK-SAME-COUNT-16: {layout = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>}
+    // CHECK-SAME-COUNT-16: : vector<2x2xf32>, vector<2x2xf32> -> vector<2x2xf32>
+    %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<8x8xf32> -> !xegpu.tensor_desc<8x8xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+    %load_a =  xegpu.load_nd %tdesc_a: !xegpu.tensor_desc<8x8xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>> -> vector<8x8xf32>
+    %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<8x8xf32> -> !xegpu.tensor_desc<8x8xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+    %load_b =  xegpu.load_nd %tdesc_b:  !xegpu.tensor_desc<8x8xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>> -> vector<8x8xf32>
+    %tdesc_c = xegpu.create_nd_tdesc %c[0, 0] : memref<8x8xf32> -> !xegpu.tensor_desc<8x8xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+    %dpas = xegpu.dpas %load_a, %load_b {layout =  #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>} : vector<8x8xf32>, vector<8x8xf32> -> vector<8x8xf32>
     gpu.return
   }
 
-  // CHECK: test_prefetch_nd_tdesc
-  // CHECK: %[[ARG_0:.*]]: memref<24x32xf32>
+  // CHECK-LABEL: test_prefetch_nd_tdesc
+  // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
   gpu.func @test_prefetch_nd_tdesc(%src: memref<24x32xf32>) {
     // CHECK-COUNT-12: xegpu.prefetch_nd %{{.*}} : !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
     %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
index 1cae2c822d826..fdc10289b44f0 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
@@ -1,8 +1,8 @@
 // RUN: mlir-opt --xegpu-wg-to-sg -split-input-file %s | FileCheck %s
 
 gpu.module @test_1_1_assignment {
-  // CHECK: test_create_nd_tdesc
-  // CHECK: %[[ARG_0:.*]]: memref<24x32xf32>
+  // CHECK-LABEL: test_create_nd_tdesc
+  // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
   gpu.func @test_create_nd_tdesc(%src: memref<24x32xf32>) {  
   // CHECK: %[[SGID:.*]] = gpu.subgroup_id
   // CHECK: %[[C12:.*]] = arith.constant 12 : index
@@ -15,53 +15,71 @@ gpu.module @test_1_1_assignment {
   // CHECK: %[[C0:.*]] = arith.constant 0 : index
   // CHECK: %[[ADD1:.*]] = index.add %[[MUL1]], %[[C0]]
   // CHECK: %[[ADD2:.*]] = index.add %[[MUL2]], %[[C0]]
-  // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][%[[ADD1]], %[[ADD2]]] : memref<24x32xf32> -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
+  // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][%[[ADD1]], %[[ADD2]]] : memref<24x32xf32>
+  // CHECK-SAME: -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
   // CHECK: gpu.return
   %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
   gpu.return
   }
 
-  // CHECK: test_load_nd_tdesc
-  // CHECK: %[[ARG_0:.*]]: memref<24x32xf32>
+  // CHECK-LABEL: test_load_nd_tdesc
+  // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
   gpu.func @test_load_nd_tdesc(%src: memref<24x32xf32>) {
-    // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
-    // CHECK: %[[LOAD:.*]] = xegpu.load_nd %[[TDESC]] : !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>> -> vector<12x8xf32>
+    // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<24x32xf32>
+    // CHECK-SAME: -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
+    // CHECK: %[[LOAD:.*]] = xegpu.load_nd %[[TDESC]]
+    // CHECK-SAME: : !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
+    // CHECK-SAME: -> vector<12x8xf32>
     %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
     %load =  xegpu.load_nd %tdesc: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>> -> vector<24x32xf32>
     gpu.return
   }
 
-  // CHECK: test_store_nd
-  // CHECK: %[[ARG_0:.*]]: memref<24x32xf32>
+  // CHECK-LABEL: test_store_nd
+  // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
   gpu.func @test_store_nd(%src: memref<24x32xf32>) {
-    // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
-    // CHECK: %[[LOAD:.*]] = xegpu.load_nd %[[TDESC]] : !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>> -> vector<12x8xf32>
-    // CHECK: xegpu.store_nd %[[LOAD]], %[[TDESC]] : vector<12x8xf32>, !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
+    // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<24x32xf32>
+    // CHECK-SAME: -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
+    // CHECK: %[[LOAD:.*]] = xegpu.load_nd %[[TDESC]]
+    // CHECK-SAME: : !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
+    // CHECK-SAME: -> vector<12x8xf32>
+    // CHECK: xegpu.store_nd %[[LOAD]], %[[TDESC]]
+    // CHECK-SAME: : vector<12x8xf32>, !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
     %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
     %load = xegpu.load_nd %tdesc: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>> -> vector<24x32xf32>
     xegpu.store_nd %load, %tdesc: vector<24x32xf32>, !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
     gpu.return
 }
 
-// CHECK: test_update_nd
-// CHECK: %[[ARG_0:.*]]: memref<24x32xf32>
+// CHECK-LABEL: test_update_nd
+// CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
 gpu.func @test_update_nd(%src: memref<24x32xf32>){
-  // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
-  // CHECK: %[[UPDATE:.*]] = xegpu.update_nd_offset %[[TDESC]], [0, 16] : !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
+  // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<24x32xf32>
+  // CHECK-SAME: -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
+  // CHECK: %[[UPDATE:.*]] = xegpu.update_nd_offset %[[TDESC]], [0, 16]
+  // CHECK-SAME: : !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
   %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
   %update = xegpu.update_nd_offset %tdesc, [0, 16] : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
   gpu.return
 }
 
-// CHECK: test_dpas
-// CHECK: %[[ARG_0:.*]]: memref<24x32xf32>
-// CHECK: %[[ARG_1:.*]]: memref<32x24xf32>
+// CHECK-LABEL: test_dpas
+// CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
+// CHECK-SAME: %[[ARG_1:.*]]: memref<32x24xf32>
 gpu.func @test_dpas(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
-    // CHECK: %[[TDESC_A:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>> 
-    // CHECK: %[[LOAD_A:.*]] = xegpu.load_nd %[[TDESC_A]] : !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>> -> vector<12x8xf32>
-    // CHECK: %[[TDESC_B:.*]] = xegpu.create_nd_tdesc %[[ARG_1]][{{%.*}}, {{%.*}}] : memref<32x24xf32> -> !xegpu.tensor_desc<8x12xf32, #xegpu.layout<lane_layout = [8, 2], lane_data = [1, 1]>> 
-    // CHECK: %[[LOAD_B:.*]] = xegpu.load_nd %[[TDESC_B]] : !xegpu.tensor_desc<8x12xf32, #xegpu.layout<lane_layout = [8, 2], lane_data = [1, 1]>> -> vector<8x12xf32>
-    // CHECK: %[[DPAS:.*]] = xegpu.dpas %[[LOAD_A]], %[[LOAD_B]] {layout =  #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<12x8xf32>, vector<8x12xf32> -> vector<12x12xf32>
+    // CHECK: %[[TDESC_A:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<24x32xf32>
+    // CHECK-SAME: -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
+    // CHECK: %[[LOAD_A:.*]] = xegpu.load_nd %[[TDESC_A]]
+    // CHECK-SAME: : !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
+    // CHECK-SAME: -> vector<12x8xf32>
+    // CHECK: %[[TDESC_B:.*]] = xegpu.create_nd_tdesc %[[ARG_1]][{{%.*}}, {{%.*}}] : memref<32x24xf32>
+    // CHECK-SAME: -> !xegpu.tensor_desc<8x12xf32, #xegpu.layout<lane_layout = [8, 2], lane_data = [1, 1]>>
+    // CHECK: %[[LOAD_B:.*]] = xegpu.load_nd %[[TDESC_B]]
+    // CHECK-SAME: : !xegpu.tensor_desc<8x12xf32, #xegpu.layout<lane_layout = [8, 2], lane_data = [1, 1]>>
+    // CHECK-SAME: -> vector<8x12xf32>
+    // CHECK: %[[DPAS:.*]] = xegpu.dpas %[[LOAD_A]], %[[LOAD_B]]
+    // CHECK-SAME: {layout =  #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>}
+    // CHECK-SAME: : vector<12x8xf32>, vector<8x12xf32> -> vector<12x12xf32>
     %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
     %load_a =  xegpu.load_nd %tdesc_a: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>> -> vector<24x32xf32>
     %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<32x24xf32> -> !xegpu.tensor_desc<32x24xf32, #xegpu.layout<sg_layout = [4, 2], sg_data = [8, 12], lane_layout = [8, 2], lane_data = [1, 1]>>
@@ -70,11 +88,13 @@ gpu.func @test_dpas(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
     gpu.return
   }
 
-  // CHECK: test_prefetch_nd_tdesc
-  // CHECK: %[[ARG_0:.*]]: memref<24x32xf32>
+  // CHECK-LABEL: test_prefetch_nd_tdesc
+  // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
   gpu.func @test_prefetch_nd_tdesc(%src: memref<24x32xf32>) {
-    // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
-    // CHECK: xegpu.prefetch_nd %[[TDESC]] : !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
+    // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<24x32xf32>
+    // CHECK-SAME: -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
+    // CHECK: xegpu.prefetch_nd %[[TDESC]]
+    // CHECK-SAME: : !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
     %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
     xegpu.prefetch_nd %tdesc: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
     gpu.return

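A note on the reworked test_dpas counts in this patch: with sg_layout = [2, 2] and sg_data = [2, 2], one distribution unit covers 4x4 elements, so each 8x8 operand splits into 2x2 = 4 units, which is where the CHECK-COUNT-4 per create_nd_tdesc comes from. Each subgroup then combines its 4 A tiles with its 4 B tiles pairwise, giving the 4 x 4 = 16 subgroup-level dpas ops behind CHECK-COUNT-16.
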
>From 46686f5e36744c2639e4eb4bfe30f84b7580f306 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Thu, 15 May 2025 21:53:42 +0000
Subject: [PATCH 07/18] change name to XeGPUWgToSgDistribute

---
 .../mlir/Dialect/XeGPU/Transforms/Passes.td   |  2 +-
 .../Dialect/XeGPU/Transforms/Transforms.h     |  2 +-
 .../Dialect/XeGPU/Transforms/CMakeLists.txt   |  2 +-
 ...PUWgToSg.cpp => XeGPUWgToSgDistribute.cpp} | 32 ++++++++-----------
 .../test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir |  2 +-
 mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir   |  2 +-
 6 files changed, 18 insertions(+), 24 deletions(-)
 rename mlir/lib/Dialect/XeGPU/Transforms/{XeGPUWgToSg.cpp => XeGPUWgToSgDistribute.cpp} (94%)

diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
index bdea88cfd7022..0be9fceb25ef1 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
@@ -33,7 +33,7 @@ def XeGPUSubgroupDistribute : Pass<"xegpu-subgroup-distribute"> {
       "Print the result of the subgroup map propagation analysis and exit.">];
 }
 
-def XeGPUWgToSg : Pass<"xegpu-wg-to-sg", "::mlir::gpu::GPUModuleOp"> {
+def XeGPUWgToSgDistribute : Pass<"xegpu-wg-to-sg-distribute", "::mlir::gpu::GPUModuleOp"> {
   let summary = "Transform WorkGroup level XeGPU code to SubGroup level";
   let description = [{
     This transform pass distributes the workgroup level computation to
diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h b/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
index 1029c66f97461..44b81796b1313 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h
@@ -62,7 +62,7 @@ void populateXeGPUFoldAliasOpsPatterns(RewritePatternSet &patterns);
 
 /// Appends patterns for XeGPU SIMT distribution into `patterns`.
 void populateXeGPUSubgroupDistributePatterns(RewritePatternSet &patterns);
-void populateXeGPUWgToSgPatterns(RewritePatternSet &patterns);
+void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns);
 
 /// Collect a set of patterns to unroll xegpu operations to a smaller shapes.
 /// Users can control whether an operation to be unrolled or not, as well as
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
index 81938ba1d5ba3..837303b04e9d7 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
@@ -2,7 +2,7 @@ add_mlir_dialect_library(MLIRXeGPUTransforms
   XeGPUFoldAliasOps.cpp
   XeGPUSubgroupDistribute.cpp
   XeGPUUnroll.cpp
-  XeGPUWgToSg.cpp
+  XeGPUWgToSgDistribute.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/XeGPU
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSg.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
similarity index 94%
rename from mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSg.cpp
rename to mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index f8478289d3b91..6406809b8b9c7 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSg.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -1,4 +1,4 @@
-//===- XeGPUWgToSg.cpp - XeGPU Workgroup to Subgroup Pass -----------------===//
+//===- XeGPUWgToSgDistribute.cpp - XeGPU Workgroup to Subgroup Pass -------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -13,14 +13,12 @@
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
 #include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
 #include "mlir/Transforms/DialectConversion.h"
-#include "llvm/Support/Debug.h"
 #include <mlir/Dialect/GPU/IR/GPUDialect.h>
 #include <mlir/Dialect/Index/IR/IndexOps.h>
-#include <numeric>
 
 namespace mlir {
 namespace xegpu {
-#define GEN_PASS_DEF_XEGPUWGTOSG
+#define GEN_PASS_DEF_XEGPUWGTOSGDISTRIBUTE
 #include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
 } // namespace xegpu
 } // namespace mlir
@@ -98,12 +96,6 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
             rewriter.create<index::RemUOp>(loc, sgID, sgDimY)};
   }
 
-  // Create a constant index value
-  Value createConstantIndex(ConversionPatternRewriter &rewriter, Location loc,
-                            int64_t value) const {
-    return rewriter.create<arith::ConstantIndexOp>(loc, value);
-  }
-
   // Calculate offset for each subgroup
   SmallVector<OpFoldResult>
   calculateGlobalOffsets(ConversionPatternRewriter &rewriter, Location loc,
@@ -112,9 +104,9 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
                          const SmallVector<int64_t> &distUnitBaseAddr) const {
 
     Value constOffsetX =
-        createConstantIndex(rewriter, loc, distUnitBaseAddr[0]);
+        rewriter.create<arith::ConstantIndexOp>(loc, distUnitBaseAddr[0]);
     Value constOffsetY =
-        createConstantIndex(rewriter, loc, distUnitBaseAddr[1]);
+        rewriter.create<arith::ConstantIndexOp>(loc, distUnitBaseAddr[1]);
 
     Value offsetX =
         rewriter.createOrFold<index::AddOp>(loc, localOffset[0], constOffsetX);
@@ -162,8 +154,9 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
     SmallVector<Value> sgDataDim(sgShape.size());
 
     for (size_t i = 0; i < sgLayout.size(); i++) {
-      sgLayoutDim[i] = createConstantIndex(rewriter, loc, sgLayout[i]);
-      sgDataDim[i] = createConstantIndex(rewriter, loc, sgShape[i]);
+      sgLayoutDim[i] =
+          rewriter.create<arith::ConstantIndexOp>(loc, sgLayout[i]);
+      sgDataDim[i] = rewriter.create<arith::ConstantIndexOp>(loc, sgShape[i]);
     }
 
     // Delinearize the 1D subgroup id into 2d
@@ -278,8 +271,8 @@ struct WgToSgDpasOp : public OpConversionPattern<xegpu::DpasOp> {
       return failure();
 
     SmallVector<Value> newDpasOps;
-    size_t i = 0;
     for (auto aVec : adaptor.getLhs()) {
+      size_t i = 0;
       for (auto bVec : adaptor.getRhs()) {
 
         llvm::SmallVector<Value> operands({aVec, bVec});
@@ -325,7 +318,7 @@ struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
 
 namespace mlir {
 namespace xegpu {
-void populateXeGPUWgToSgPatterns(RewritePatternSet &patterns) {
+void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
   patterns.add<WgToSgCreateNdOp, WgToSgLoadNdOp, WgToSgStoreNdOp,
                WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp>(
       patterns.getContext());
@@ -334,12 +327,13 @@ void populateXeGPUWgToSgPatterns(RewritePatternSet &patterns) {
 } // namespace mlir
 
 namespace {
-struct XeGPUWgToSgPass : public xegpu::impl::XeGPUWgToSgBase<XeGPUWgToSgPass> {
+struct XeGPUWgToSgDistributePass
+    : public xegpu::impl::XeGPUWgToSgDistributeBase<XeGPUWgToSgDistributePass> {
   void runOnOperation() override;
 };
 } // namespace
 
-void XeGPUWgToSgPass::runOnOperation() {
+void XeGPUWgToSgDistributePass::runOnOperation() {
   MLIRContext *ctx = &getContext();
   RewritePatternSet patterns(ctx);
   ConversionTarget target(*ctx);
@@ -377,7 +371,7 @@ void XeGPUWgToSgPass::runOnOperation() {
 
   target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
 
-  xegpu::populateXeGPUWgToSgPatterns(patterns);
+  xegpu::populateXeGPUWgToSgDistributePatterns(patterns);
   if (failed(
           applyPartialConversion(getOperation(), target, std::move(patterns))))
     return signalPassFailure();
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
index 3096759e3ac8c..321cc0510a24c 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt --xegpu-wg-to-sg -split-input-file %s | FileCheck %s
+// RUN: mlir-opt --xegpu-wg-to-sg-distribute -split-input-file %s | FileCheck %s
 
 gpu.module @test_round_robin_assignment {
   // CHECK-LABEL: test_create_nd_tdesc
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
index fdc10289b44f0..3bd95ee775db3 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt --xegpu-wg-to-sg -split-input-file %s | FileCheck %s
+// RUN: mlir-opt --xegpu-wg-to-sg-distribute -split-input-file %s | FileCheck %s
 
 gpu.module @test_1_1_assignment {
   // CHECK-LABEL: test_create_nd_tdesc

>From b8da87e3d9f85c12a89cccf1092dcbbac22732b2 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Fri, 16 May 2025 05:14:56 +0000
Subject: [PATCH 08/18] Use getMixedOffsets

---
 .../Transforms/XeGPUWgToSgDistribute.cpp      | 43 +++++++++----------
 1 file changed, 20 insertions(+), 23 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 6406809b8b9c7..7c5a6d362c3d1 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -70,23 +70,6 @@ namespace {
 struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
   using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
 
-  // Helper to extract mixed offsets into a Value array
-  SmallVector<Value> extractOffsets(ConversionPatternRewriter &rewriter,
-                                    xegpu::CreateNdDescOp op) const {
-    llvm::SmallVector<Value> offsets;
-    auto staticOffsets = op.getStaticOffsets();
-    auto dynamicOffsets = op.getOffsets();
-
-    for (size_t i = 0, j = 0; i != staticOffsets.size(); i++) {
-      if (ShapedType::isDynamic(staticOffsets[i]))
-        offsets.push_back(dynamicOffsets[j++]);
-      else
-        offsets.push_back(rewriter.create<arith::ConstantIndexOp>(
-            op.getLoc(), staticOffsets[i]));
-    }
-    return offsets;
-  }
-
   // Convert linear subgroup ID to 2D coordinates
   // TODO: Delinearize for nD
   SmallVector<Value> delinearizeSubgroupId(ConversionPatternRewriter &rewriter,
@@ -99,7 +82,7 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
   // Calculate offset for each subgroup
   SmallVector<OpFoldResult>
   calculateGlobalOffsets(ConversionPatternRewriter &rewriter, Location loc,
-                         const SmallVector<Value> &originalOffsets,
+                         const SmallVector<OpFoldResult> &originalOffsets,
                          const SmallVector<Value> &localOffset,
                          const SmallVector<int64_t> &distUnitBaseAddr) const {
 
@@ -116,10 +99,24 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
     size_t lastDimIndex = originalOffsets.size() - 1;
     size_t secondLastDimIndex = lastDimIndex - 1;
 
-    Value globalOffsetX = rewriter.createOrFold<index::AddOp>(
-        loc, originalOffsets[secondLastDimIndex], offsetX);
-    Value globalOffsetY = rewriter.createOrFold<index::AddOp>(
-        loc, originalOffsets[lastDimIndex], offsetY);
+    // Convert originalOffsets to Value
+    auto getValueFromOpFoldResult = [&](OpFoldResult ofr) -> Value {
+      if (auto val = ofr.dyn_cast<Value>())
+        return val;
+      if (auto attr = ofr.dyn_cast<Attribute>()) {
+        int64_t staticOffset = cast<IntegerAttr>(attr).getInt();
+        return rewriter.create<arith::ConstantIndexOp>(loc, staticOffset);
+      }
+      llvm_unreachable("Unsupported OpFoldResult kind");
+    };
+
+    Value origOffsetX =
+        getValueFromOpFoldResult(originalOffsets[secondLastDimIndex]);
+    Value origOffsetY = getValueFromOpFoldResult(originalOffsets[lastDimIndex]);
+    Value globalOffsetX =
+        rewriter.createOrFold<index::AddOp>(loc, origOffsetX, offsetX);
+    Value globalOffsetY =
+        rewriter.createOrFold<index::AddOp>(loc, origOffsetY, offsetY);
 
     SmallVector<OpFoldResult> globalOffsets(originalOffsets.begin(),
                                             originalOffsets.end());
@@ -172,7 +169,7 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
           rewriter.createOrFold<index::MulOp>(loc, sgIds[i], sgDataDim[i]);
     }
 
-    SmallVector<Value> originalOffsets = extractOffsets(rewriter, op);
+    SmallVector<OpFoldResult> originalOffsets = op.getMixedOffsets();
 
     xegpu::TensorDescType newTdescTy =
         xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(),

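The getValueFromOpFoldResult lambda introduced here is a common MLIR idiom for materializing an OpFoldResult as an SSA value; upstream also provides getValueOrCreateConstantIndexOp (mlir/Dialect/Arith/Utils/Utils.h) with the same behavior, which the lambda could plausibly be replaced by. A minimal standalone sketch (the helper name is illustrative):

  #include "mlir/Dialect/Arith/IR/Arith.h"
  #include "mlir/IR/Builders.h"
  #include "mlir/IR/OpDefinition.h"

  using namespace mlir;

  // Sketch: return the Value an OpFoldResult already wraps, or
  // materialize its static integer as an arith.constant of index type.
  static Value materializeIndex(OpBuilder &b, Location loc, OpFoldResult ofr) {
    if (auto val = dyn_cast<Value>(ofr))
      return val;
    auto intAttr = cast<IntegerAttr>(cast<Attribute>(ofr));
    return b.create<arith::ConstantIndexOp>(loc, intAttr.getInt());
  }
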
>From b3ba670e96a0cdd0afcb74953764779cdcc6fb66 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Fri, 16 May 2025 17:58:57 +0000
Subject: [PATCH 09/18] Address feedback

---
 .../mlir/Dialect/XeGPU/Transforms/Passes.td   |  2 +-
 .../Transforms/XeGPUWgToSgDistribute.cpp      | 27 ++++-----
 .../test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir | 59 ++++++++++++++-----
 mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir   | 56 ++++++++++++------
 4 files changed, 95 insertions(+), 49 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
index 0be9fceb25ef1..6f585f9ceb29b 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
@@ -33,7 +33,7 @@ def XeGPUSubgroupDistribute : Pass<"xegpu-subgroup-distribute"> {
       "Print the result of the subgroup map propagation analysis and exit.">];
 }
 
-def XeGPUWgToSgDistribute : Pass<"xegpu-wg-to-sg-distribute", "::mlir::gpu::GPUModuleOp"> {
+def XeGPUWgToSgDistribute : Pass<"xegpu-wg-to-sg-distribute"> {
   let summary = "Transform WorkGroup level XeGPU code to SubGroup level";
   let description = [{
     This transform pass distributes the workgroup level computation to
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 7c5a6d362c3d1..20fc6951c9481 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -7,14 +7,15 @@
 //===----------------------------------------------------------------------===//
 #include "mlir/Dialect/XeGPU/Transforms/Passes.h"
 
+#include "mlir/Dialect/Affine/Utils.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/Index/IR/IndexDialect.h"
+#include "mlir/Dialect/Index/IR/IndexOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
 #include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
 #include "mlir/Transforms/DialectConversion.h"
-#include <mlir/Dialect/GPU/IR/GPUDialect.h>
-#include <mlir/Dialect/Index/IR/IndexOps.h>
 
 namespace mlir {
 namespace xegpu {
@@ -70,15 +71,6 @@ namespace {
 struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
   using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
 
-  // Convert linear subgroup ID to 2D coordinates
-  // TODO: Delinearize for nD
-  SmallVector<Value> delinearizeSubgroupId(ConversionPatternRewriter &rewriter,
-                                           Location loc, Value sgID,
-                                           Value sgDimX, Value sgDimY) const {
-    return {rewriter.create<index::DivUOp>(loc, sgID, sgDimY),
-            rewriter.create<index::RemUOp>(loc, sgID, sgDimY)};
-  }
-
   // Calculate offset for each subgroup
   SmallVector<OpFoldResult>
   calculateGlobalOffsets(ConversionPatternRewriter &rewriter, Location loc,
@@ -144,7 +136,8 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
 
     // TODO : Handle order attribute
     // Get the subgroup ID
-    auto linearSgId = rewriter.create<gpu::SubgroupIdOp>(loc, nullptr);
+    auto linearSgId =
+        rewriter.create<gpu::SubgroupIdOp>(loc, /*upper_bound=*/nullptr);
 
     // Create constants for layout dimensions
     SmallVector<Value> sgLayoutDim(sgLayout.size());
@@ -156,9 +149,11 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
       sgDataDim[i] = rewriter.create<arith::ConstantIndexOp>(loc, sgShape[i]);
     }
 
-    // Delinearize the 1D subgroup id into 2d
-    SmallVector<Value> sgIds = delinearizeSubgroupId(
-        rewriter, loc, linearSgId, sgLayoutDim[0], sgLayoutDim[1]);
+    auto deLinearizeSgId =
+        affine::delinearizeIndex(rewriter, loc, linearSgId, sgLayoutDim);
+    if (failed(deLinearizeSgId))
+      return failure();
+    SmallVector<Value> sgIds = *deLinearizeSgId;
 
     // Calculate distribution unit shape and local offsets for subgroup
     SmallVector<int64_t> distUnitShape(sgLayout.size());
@@ -267,9 +262,9 @@ struct WgToSgDpasOp : public OpConversionPattern<xegpu::DpasOp> {
     if (!originalLayout)
       return failure();
 
+    size_t i = 0;
     SmallVector<Value> newDpasOps;
     for (auto aVec : adaptor.getLhs()) {
-      size_t i = 0;
       for (auto bVec : adaptor.getRhs()) {
 
         llvm::SmallVector<Value> operands({aVec, bVec});
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
index 321cc0510a24c..23fdffc220ecb 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
@@ -6,7 +6,9 @@ gpu.module @test_round_robin_assignment {
   gpu.func @test_create_nd_tdesc(%src: memref<24x32xf32>) {
       // CHECK-COUNT-12: xegpu.create_nd_tdesc %[[ARG_0]][%{{.*}}, %{{.*}}] : memref<24x32xf32>
       // CHECK-SAME: -> !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
-      %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+      // CHECK-NOT: xegpu.create_nd_tdesc
+      %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32>
+        -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
       gpu.return
     }
 
@@ -17,18 +19,26 @@ gpu.module @test_round_robin_assignment {
       // CHECK-COUNT-12: xegpu.load_nd %{{.*}}
       // CHECK-SAME-COUNT-12: : !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
       // CHECK-SAME-COUNT-12: -> vector<2x2xf32>
-      %load =  xegpu.load_nd %tdesc: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>> -> vector<24x32xf32>
+      // CHECK-NOT: xegpu.load_nd
+      %load =  xegpu.load_nd %tdesc
+        : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+        -> vector<24x32xf32>
       gpu.return
     }
 
   // CHECK-LABEL: test_store_nd
   // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
   gpu.func @test_store_nd(%src: memref<24x32xf32>) {
-      %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+      %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32>
+        -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
       // CHECK-COUNT-12: xegpu.store_nd %{{.*}}, %{{.*}}
       // CHECK-SAME-COUNT-12: : vector<2x2xf32>, !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
-      %load = xegpu.load_nd %tdesc: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>> -> vector<24x32xf32>
-      xegpu.store_nd %load, %tdesc: vector<24x32xf32>, !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+      // CHECK-NOT: xegpu.store_nd
+      %load = xegpu.load_nd %tdesc
+        : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+        -> vector<24x32xf32>
+      xegpu.store_nd %load, %tdesc
+        : vector<24x32xf32>, !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
       gpu.return
   }
 
@@ -38,7 +48,9 @@ gpu.module @test_round_robin_assignment {
     %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->  !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
     // CHECK-COUNT-12: xegpu.update_nd_offset %{{.*}}, [0, 16]
     // CHECK-SAME-COUNT-12: : !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
-    %update = xegpu.update_nd_offset %tdesc, [0, 16] :  !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+    // CHECK-NOT: xegpu.update_nd_offset
+    %update = xegpu.update_nd_offset %tdesc, [0, 16]
+      : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
     gpu.return
   }
 
@@ -47,28 +59,45 @@ gpu.module @test_round_robin_assignment {
   gpu.func @test_dpas(%a: memref<8x8xf32>, %b: memref<8x8xf32>, %c: memref<8x8xf32>) {
     // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_0]][%{{.*}}, %{{.*}}] : memref<8x8xf32>
     // CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
+    // CHECK-NOT: xegpu.create_nd_tdesc
     // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_1]][%{{.*}}, %{{.*}}] : memref<8x8xf32>
     // CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
+    // CHECK-NOT: xegpu.create_nd_tdesc
     // CHECK-COUNT-4:  xegpu.create_nd_tdesc %{{.*}}[%{{.*}}, %{{.*}}] : memref<8x8xf32>
     // CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
+    // CHECK-NOT: xegpu.create_nd_tdesc
     // CHECK-COUNT-16: xegpu.dpas %{{.*}}, %{{.*}}
     // CHECK-SAME-COUNT-16: {layout = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>}
     // CHECK-SAME-COUNT-16: : vector<2x2xf32>, vector<2x2xf32> -> vector<2x2xf32>
-    %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<8x8xf32> -> !xegpu.tensor_desc<8x8xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
-    %load_a =  xegpu.load_nd %tdesc_a: !xegpu.tensor_desc<8x8xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>> -> vector<8x8xf32>
-    %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<8x8xf32> -> !xegpu.tensor_desc<8x8xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
-    %load_b =  xegpu.load_nd %tdesc_b:  !xegpu.tensor_desc<8x8xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>> -> vector<8x8xf32>
-    %tdesc_c = xegpu.create_nd_tdesc %c[0, 0] : memref<8x8xf32> -> !xegpu.tensor_desc<8x8xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
-    %dpas = xegpu.dpas %load_a, %load_b {layout =  #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>} : vector<8x8xf32>, vector<8x8xf32> -> vector<8x8xf32>
+    // CHECK-NOT: xegpu.dpas
+    %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<8x8xf32>
+      -> !xegpu.tensor_desc<8x8xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+    %load_a =  xegpu.load_nd %tdesc_a
+      : !xegpu.tensor_desc<8x8xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+      -> vector<8x8xf32>
+    %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<8x8xf32>
+      -> !xegpu.tensor_desc<8x8xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+    %load_b =  xegpu.load_nd %tdesc_b
+      : !xegpu.tensor_desc<8x8xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+      -> vector<8x8xf32>
+    %tdesc_c = xegpu.create_nd_tdesc %c[0, 0] : memref<8x8xf32>
+      -> !xegpu.tensor_desc<8x8xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+    %dpas = xegpu.dpas %load_a, %load_b
+      {layout =  #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
+      : vector<8x8xf32>, vector<8x8xf32> -> vector<8x8xf32>
     gpu.return
   }
 
   // CHECK-LABEL: test_prefetch_nd_tdesc
   // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
   gpu.func @test_prefetch_nd_tdesc(%src: memref<24x32xf32>) {
-    // CHECK-COUNT-12: xegpu.prefetch_nd %{{.*}} : !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
-    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
-    xegpu.prefetch_nd %tdesc: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+    // CHECK-COUNT-12: xegpu.prefetch_nd %{{.*}}
+    // CHECK-SAME-COUNT-12: !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
+    // CHECK-NOT: xegpu.prefetch_nd
+    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32>
+      -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+    xegpu.prefetch_nd %tdesc
+      : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
     gpu.return
   }
 }
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
index 3bd95ee775db3..5feb0da1ddfae 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
@@ -1,5 +1,7 @@
 // RUN: mlir-opt --xegpu-wg-to-sg-distribute -split-input-file %s | FileCheck %s
 
+// CHECK: #map = affine_map<()[s0] -> (s0 floordiv 4)>
+// CHECK: #map1 = affine_map<()[s0] -> (s0 mod 4)>
 gpu.module @test_1_1_assignment {
   // CHECK-LABEL: test_create_nd_tdesc
   // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
@@ -8,8 +10,8 @@ gpu.module @test_1_1_assignment {
   // CHECK: %[[C12:.*]] = arith.constant 12 : index
   // CHECK: %[[C4:.*]] = arith.constant 4 : index
   // CHECK: %[[C8:.*]] = arith.constant 8 : index
-  // CHECK: %[[DIV:.*]] = index.divu %[[SGID]], %[[C4]]
-  // CHECK: %[[REM:.*]] = index.remu %[[SGID]], %[[C4]]
+  // CHECK: %[[DIV:.*]] = affine.apply #map()[%[[SGID]]]
+  // CHECK: %[[REM:.*]] = affine.apply #map1()[%[[SGID]]]
   // CHECK: %[[MUL1:.*]] = index.mul %[[DIV]], %[[C12]]
   // CHECK: %[[MUL2:.*]] = index.mul %[[REM]], %[[C8]]
   // CHECK: %[[C0:.*]] = arith.constant 0 : index
@@ -18,7 +20,8 @@ gpu.module @test_1_1_assignment {
   // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][%[[ADD1]], %[[ADD2]]] : memref<24x32xf32>
   // CHECK-SAME: -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
   // CHECK: gpu.return
-  %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+  %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32>
+    -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
   gpu.return
   }
 
@@ -30,8 +33,11 @@ gpu.module @test_1_1_assignment {
     // CHECK: %[[LOAD:.*]] = xegpu.load_nd %[[TDESC]]
     // CHECK-SAME: : !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
     // CHECK-SAME: -> vector<12x8xf32>
-    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
-    %load =  xegpu.load_nd %tdesc: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>> -> vector<24x32xf32>
+    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32>
+      -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+    %load =  xegpu.load_nd %tdesc
+      : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+      -> vector<24x32xf32>
     gpu.return
   }
 
@@ -45,9 +51,13 @@ gpu.module @test_1_1_assignment {
     // CHECK-SAME: -> vector<12x8xf32>
     // CHECK: xegpu.store_nd %[[LOAD]], %[[TDESC]]
     // CHECK-SAME: : vector<12x8xf32>, !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
-    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
-    %load = xegpu.load_nd %tdesc: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>> -> vector<24x32xf32>
-    xegpu.store_nd %load, %tdesc: vector<24x32xf32>, !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32>
+      -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+    %load = xegpu.load_nd %tdesc
+      : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+      -> vector<24x32xf32>
+    xegpu.store_nd %load, %tdesc
+      : vector<24x32xf32>, !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
     gpu.return
 }
 
@@ -58,8 +68,10 @@ gpu.func @test_update_nd(%src: memref<24x32xf32>){
   // CHECK-SAME: -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
   // CHECK: %[[UPDATE:.*]] = xegpu.update_nd_offset %[[TDESC]], [0, 16]
   // CHECK-SAME: : !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
-  %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
-  %update = xegpu.update_nd_offset %tdesc, [0, 16] : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+  %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32>
+    -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+  %update = xegpu.update_nd_offset %tdesc, [0, 16]
+    : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
   gpu.return
 }
 
@@ -80,11 +92,19 @@ gpu.func @test_dpas(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
     // CHECK: %[[DPAS:.*]] = xegpu.dpas %[[LOAD_A]], %[[LOAD_B]]
     // CHECK-SAME: {layout =  #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>}
     // CHECK-SAME: : vector<12x8xf32>, vector<8x12xf32> -> vector<12x12xf32>
-    %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
-    %load_a =  xegpu.load_nd %tdesc_a: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>> -> vector<24x32xf32>
-    %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<32x24xf32> -> !xegpu.tensor_desc<32x24xf32, #xegpu.layout<sg_layout = [4, 2], sg_data = [8, 12], lane_layout = [8, 2], lane_data = [1, 1]>>
-    %load_b =  xegpu.load_nd %tdesc_b: !xegpu.tensor_desc<32x24xf32, #xegpu.layout<sg_layout = [4, 2], sg_data = [8, 12], lane_layout = [8, 2], lane_data = [1, 1]>> -> vector<32x24xf32>
-    %dpas = xegpu.dpas %load_a, %load_b {layout =  #xegpu.layout<sg_layout = [2, 2], sg_data = [12, 12], lane_layout = [2, 2], lane_data = [1, 1]>} : vector<24x32xf32>, vector<32x24xf32> -> vector<24x24xf32>
+    %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32>
+      -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+    %load_a =  xegpu.load_nd %tdesc_a
+      : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+      -> vector<24x32xf32>
+    %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<32x24xf32>
+      -> !xegpu.tensor_desc<32x24xf32, #xegpu.layout<sg_layout = [4, 2], sg_data = [8, 12], lane_layout = [8, 2], lane_data = [1, 1]>>
+    %load_b =  xegpu.load_nd %tdesc_b
+      : !xegpu.tensor_desc<32x24xf32, #xegpu.layout<sg_layout = [4, 2], sg_data = [8, 12], lane_layout = [8, 2], lane_data = [1, 1]>>
+      -> vector<32x24xf32>
+    %dpas = xegpu.dpas %load_a, %load_b
+      {layout =  #xegpu.layout<sg_layout = [2, 2], sg_data = [12, 12], lane_layout = [2, 2], lane_data = [1, 1]>}
+      : vector<24x32xf32>, vector<32x24xf32> -> vector<24x24xf32>
     gpu.return
   }
 
@@ -95,8 +115,10 @@ gpu.func @test_dpas(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
     // CHECK-SAME: -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
     // CHECK: xegpu.prefetch_nd %[[TDESC]]
     // CHECK-SAME: : !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
-    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
-    xegpu.prefetch_nd %tdesc: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32>
+      -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+    xegpu.prefetch_nd %tdesc
+      : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
     gpu.return
   }
 }

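Replacing the hand-written index.divu/index.remu pair with affine::delinearizeIndex generalizes the subgroup-id split to any rank and lets affine folding simplify the arithmetic, which is why the test now checks affine_map<()[s0] -> (s0 floordiv 4)> and (s0 mod 4) instead of the index ops. A usage sketch mirroring the call in the pattern (the wrapper name is illustrative):

  #include "mlir/Dialect/Affine/Utils.h"
  #include "mlir/IR/Builders.h"

  using namespace mlir;

  // Split a linear subgroup id into per-dimension ids for a given
  // sg_layout basis; for basis {4, 4} this produces
  // {id floordiv 4, id mod 4}, matching the affine_map checks above.
  static FailureOr<SmallVector<Value>>
  splitSubgroupId(OpBuilder &b, Location loc, Value linearId,
                  ArrayRef<Value> sgLayoutDims) {
    return affine::delinearizeIndex(b, loc, linearId, sgLayoutDims);
  }
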
>From 64259613115e79ec92f8cc717a42ccc3d0a94b70 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Fri, 16 May 2025 20:38:14 +0000
Subject: [PATCH 10/18] add support for nD

---
 .../Transforms/XeGPUWgToSgDistribute.cpp      | 39 +++++++------------
 mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir   |  2 +-
 2 files changed, 16 insertions(+), 25 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 20fc6951c9481..68410f0f443f8 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -77,19 +77,8 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
                          const SmallVector<OpFoldResult> &originalOffsets,
                          const SmallVector<Value> &localOffset,
                          const SmallVector<int64_t> &distUnitBaseAddr) const {
-
-    Value constOffsetX =
-        rewriter.create<arith::ConstantIndexOp>(loc, distUnitBaseAddr[0]);
-    Value constOffsetY =
-        rewriter.create<arith::ConstantIndexOp>(loc, distUnitBaseAddr[1]);
-
-    Value offsetX =
-        rewriter.createOrFold<index::AddOp>(loc, localOffset[0], constOffsetX);
-    Value offsetY =
-        rewriter.createOrFold<index::AddOp>(loc, localOffset[1], constOffsetY);
-
-    size_t lastDimIndex = originalOffsets.size() - 1;
-    size_t secondLastDimIndex = lastDimIndex - 1;
+    assert(localOffset.size() == distUnitBaseAddr.size() &&
+           "localOffset and distUnitBaseAddr must have the same rank");
 
     // Convert originalOffsets to Value
     auto getValueFromOpFoldResult = [&](OpFoldResult ofr) -> Value {
@@ -102,18 +91,20 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
       llvm_unreachable("Unsupported OpFoldResult kind");
     };
 
-    Value origOffsetX =
-        getValueFromOpFoldResult(originalOffsets[secondLastDimIndex]);
-    Value origOffsetY = getValueFromOpFoldResult(originalOffsets[lastDimIndex]);
-    Value globalOffsetX =
-        rewriter.createOrFold<index::AddOp>(loc, origOffsetX, offsetX);
-    Value globalOffsetY =
-        rewriter.createOrFold<index::AddOp>(loc, origOffsetY, offsetY);
-
     SmallVector<OpFoldResult> globalOffsets(originalOffsets.begin(),
                                             originalOffsets.end());
-    globalOffsets[secondLastDimIndex] = globalOffsetX;
-    globalOffsets[lastDimIndex] = globalOffsetY;
+    size_t rank = localOffset.size();
+    for (size_t i = 0; i < rank; ++i) {
+      size_t dimIdx = originalOffsets.size() - rank + i;
+      Value constOffset =
+          rewriter.create<arith::ConstantIndexOp>(loc, distUnitBaseAddr[i]);
+      Value offset =
+          rewriter.createOrFold<index::AddOp>(loc, localOffset[i], constOffset);
+      Value origOffset = getValueFromOpFoldResult(originalOffsets[dimIdx]);
+      Value globalOffset =
+          rewriter.createOrFold<index::AddOp>(loc, origOffset, offset);
+      globalOffsets[dimIdx] = globalOffset;
+    }
 
     return globalOffsets;
   }
@@ -283,7 +274,7 @@ struct WgToSgDpasOp : public OpConversionPattern<xegpu::DpasOp> {
         tmpC = rewriter.create<xegpu::DpasOp>(
             loc, resTy, operands,
             llvm::ArrayRef<NamedAttribute>(
-                {"layout", originalLayout.dropSgLayoutAndData()}));
+                {"layout_result_0", originalLayout.dropSgLayoutAndData()}));
         newDpasOps.push_back(tmpC);
       }
     }
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
index 5feb0da1ddfae..5d9ddb3ef1e97 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
@@ -90,7 +90,7 @@ gpu.func @test_dpas(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
     // CHECK-SAME: : !xegpu.tensor_desc<8x12xf32, #xegpu.layout<lane_layout = [8, 2], lane_data = [1, 1]>>
     // CHECK-SAME: -> vector<8x12xf32>
     // CHECK: %[[DPAS:.*]] = xegpu.dpas %[[LOAD_A]], %[[LOAD_B]]
-    // CHECK-SAME: {layout =  #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>}
+    // CHECK-SAME: {layout_result_0 =  #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>}
     // CHECK-SAME: : vector<12x8xf32>, vector<8x12xf32> -> vector<12x12xf32>
     %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32>
       -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>

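The loop above generalizes the previous hard-coded secondLastDimIndex/lastDimIndex
handling to any rank: only the innermost `rank` dimensions carry distribution
offsets, and any leading dimensions of the original offsets pass through
unchanged. (The patch also renames the result-layout attribute on the
subgroup-level dpas from "layout" to "layout_result_0".) As a standalone sketch
of the per-dimension arithmetic, with plain integers standing in for MLIR
Values (function and variable names are illustrative, not part of the patch):

  #include <cassert>
  #include <cstddef>
  #include <cstdint>
  #include <vector>

  // As of this patch: global[d] = original[d] + localOffset[i] + unitBase[i]
  // for the innermost `rank` dimensions (patch 14 later adds a modulo wrap).
  std::vector<int64_t>
  computeGlobalOffsets(const std::vector<int64_t> &originalOffsets,
                       const std::vector<int64_t> &localOffset,
                       const std::vector<int64_t> &distUnitBaseAddr) {
    assert(localOffset.size() == distUnitBaseAddr.size() &&
           "localOffset and distUnitBaseAddr must have the same rank");
    std::vector<int64_t> globalOffsets = originalOffsets;
    size_t rank = localOffset.size();
    for (size_t i = 0; i < rank; ++i) {
      // Distributed dimensions are aligned to the innermost ones.
      size_t dimIdx = originalOffsets.size() - rank + i;
      globalOffsets[dimIdx] += localOffset[i] + distUnitBaseAddr[i];
    }
    return globalOffsets;
  }
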
>From dfb25ad5976d50beac0c2da0c50cc0c8e8bf36f9 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Mon, 19 May 2025 18:40:38 +0000
Subject: [PATCH 11/18] Add TODO and failing test case

---
 .../XeGPU/Transforms/XeGPUWgToSgDistribute.cpp        |  5 +++++
 mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir           | 11 ++++++++++-
 2 files changed, 15 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 68410f0f443f8..d0b37cf1f69be 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -68,6 +68,11 @@ namespace {
 ///
 /// Since the 24x24 matrix is divided into 8x8 distribution units, there will be
 /// 9 distribution units (3x3) in total. Hence the 9 subgroup level operations.
+
+/// The pass currently keeps the entire distribution logic in the
+/// WgToSgCreateNdOp pattern; all other op patterns simply follow.
+/// TODO: Decouple the distribution logic from WgToSgCreateNdOp so that
+/// each op in the pass can be distributed independently.
 struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
   using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
 
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
index 5d9ddb3ef1e97..94c5f59447e8b 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
@@ -121,4 +121,13 @@ gpu.func @test_dpas(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
       : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
     gpu.return
   }
-}
+
+  // CHECK-LABEL: test_dpas_with_no_create_nd_desc
+  gpu.func @test_dpas_with_no_create_nd_desc(%a: vector<24x32xf32>, %b: vector<32x24xf32>) {
+    // CHECK-NOT: vector<12x12xf32>
+    %dpas = xegpu.dpas %a, %b
+      {layout =  #xegpu.layout<sg_layout = [2, 2], sg_data = [12, 12], lane_layout = [2, 2], lane_data = [1, 1]>}
+      : vector<24x32xf32>, vector<32x24xf32> -> vector<24x24xf32>
+    gpu.return
+  }
+}
\ No newline at end of file

>From cd4f9fcdbd49b8905071d6abc402951bca69cd88 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Mon, 19 May 2025 18:46:21 +0000
Subject: [PATCH 12/18] Add newline

---
 mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
index 94c5f59447e8b..b7a9cd89030db 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
@@ -130,4 +130,4 @@ gpu.func @test_dpas(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
       : vector<24x32xf32>, vector<32x24xf32> -> vector<24x24xf32>
     gpu.return
   }
-}
\ No newline at end of file
+}

>From 4612f646b6f82ba835b0f5a1f9994124fe7641c0 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Mon, 19 May 2025 20:43:02 +0000
Subject: [PATCH 13/18] Use min shape for dist_unit

---
 mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index d0b37cf1f69be..82928a6c38771 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -155,7 +155,7 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
     SmallVector<int64_t> distUnitShape(sgLayout.size());
     SmallVector<Value> localOffset(sgLayout.size());
     for (size_t i = 0; i < sgLayout.size(); i++) {
-      distUnitShape[i] = sgLayout[i] * sgShape[i];
+      distUnitShape[i] = std::min(sgLayout[i] * sgShape[i], wgShape[i]);
       localOffset[i] =
           rewriter.createOrFold<index::MulOp>(loc, sgIds[i], sgDataDim[i]);
     }

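The clamp above matters because StaticTileOffsetRange iterates wgShape in
steps of distUnitShape: a distribution unit larger than the tensor in any
dimension would step past the shape. With the min, sgLayout * sgData >=
wgShape collapses to exactly one unit in that dimension, while smaller
products give the round-robin case. A sketch with illustrative names:

  #include <algorithm>
  #include <cstddef>
  #include <cstdint>
  #include <vector>

  // Distribution unit per dimension: sg_layout * sg_data, capped at the
  // workgroup shape so the tile iteration never exceeds the tensor.
  std::vector<int64_t>
  computeDistUnitShape(const std::vector<int64_t> &sgLayout,
                       const std::vector<int64_t> &sgShape,
                       const std::vector<int64_t> &wgShape) {
    std::vector<int64_t> unit(sgLayout.size());
    for (size_t i = 0; i < sgLayout.size(); ++i)
      unit[i] = std::min(sgLayout[i] * sgShape[i], wgShape[i]);
    return unit;
  }

For example, sg_layout = [4, 4] with sg_data = [2, 2] over a 24x32 tensor
gives an 8x8 unit, so the tensor holds 3x4 = 12 distribution units, matching
the CHECK-COUNT-12 expectations in the round-robin tests.
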
>From a092bd34d2d28462592e72b4d47b5decb64d8756 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Mon, 19 May 2025 22:44:11 +0000
Subject: [PATCH 14/18] Address feedback

---
 .../Transforms/XeGPUWgToSgDistribute.cpp      | 42 +++++++++++++-----
 mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir   | 43 ++++++++++++++++++-
 2 files changed, 73 insertions(+), 12 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 82928a6c38771..53b76e1c580a7 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -81,7 +81,8 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
   calculateGlobalOffsets(ConversionPatternRewriter &rewriter, Location loc,
                          const SmallVector<OpFoldResult> &originalOffsets,
                          const SmallVector<Value> &localOffset,
-                         const SmallVector<int64_t> &distUnitBaseAddr) const {
+                         const SmallVector<int64_t> &distUnitBaseAddr,
+                         const SmallVector<int64_t> &distUnitShape) const {
     assert(localOffset.size() == distUnitBaseAddr.size() &&
            "localOffset and distUnitBaseAddr must have the same rank");
 
@@ -105,9 +106,13 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
           rewriter.create<arith::ConstantIndexOp>(loc, distUnitBaseAddr[i]);
       Value offset =
           rewriter.createOrFold<index::AddOp>(loc, localOffset[i], constOffset);
+      Value modValue =
+          rewriter.create<arith::ConstantIndexOp>(loc, distUnitShape[i]);
+      Value offsetMod =
+          rewriter.createOrFold<index::RemUOp>(loc, offset, modValue);
       Value origOffset = getValueFromOpFoldResult(originalOffsets[dimIdx]);
       Value globalOffset =
-          rewriter.createOrFold<index::AddOp>(loc, origOffset, offset);
+          rewriter.createOrFold<index::AddOp>(loc, origOffset, offsetMod);
       globalOffsets[dimIdx] = globalOffset;
     }
 
@@ -125,10 +130,27 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
       return failure();
     Type elemTy = tdescTy.getElementType();
     ArrayRef<int64_t> wgShape = tdescTy.getShape();
-    SmallVector<int64_t> sgShape =
-        llvm::to_vector_of<int64_t>(layout.getSgData().asArrayRef());
-    SmallVector<int64_t> sgLayout =
-        llvm::to_vector_of<int64_t>(layout.getSgLayout().asArrayRef());
+    SmallVector<int64_t> sgLayout;
+    if (auto sgLayoutAttr = layout.getSgLayout()) {
+      sgLayout = llvm::to_vector_of<int64_t>(sgLayoutAttr.asArrayRef());
+    } else {
+      // sgLayout must be present for workgroup-level distribution.
+      op.emitError("sgLayout attribute is required in layout");
+      return failure();
+    }
+
+    SmallVector<int64_t> sgShape;
+    if (auto sgDataAttr = layout.getSgData()) {
+      sgShape = llvm::to_vector_of<int64_t>(sgDataAttr.asArrayRef());
+    } else {
+      assert(wgShape.size() == sgLayout.size() &&
+             "sgLayout and wgShape must have the same rank");
+      sgShape.reserve(wgShape.size());
+      for (size_t i = 0; i < wgShape.size(); ++i) {
+        assert(sgLayout[i] != 0 && "sgLayout elements must be non-zero");
+        sgShape.push_back(wgShape[i] / sgLayout[i]);
+      }
+    }
 
     // TODO : Handle order attribute
     // Get the subgroup ID
@@ -168,8 +190,9 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
     SmallVector<Value> newCreateNdOps;
     for (SmallVector<int64_t> distUnitBaseAddr :
          StaticTileOffsetRange(wgShape, distUnitShape)) {
-      SmallVector<OpFoldResult> globalOffsets = calculateGlobalOffsets(
-          rewriter, loc, originalOffsets, localOffset, distUnitBaseAddr);
+      SmallVector<OpFoldResult> globalOffsets =
+          calculateGlobalOffsets(rewriter, loc, originalOffsets, localOffset,
+                                 distUnitBaseAddr, distUnitShape);
 
       auto newCreateNdOp = rewriter.create<xegpu::CreateNdDescOp>(
           loc, newTdescTy, op.getSource(), globalOffsets, op.getMixedSizes(),
@@ -258,11 +281,10 @@ struct WgToSgDpasOp : public OpConversionPattern<xegpu::DpasOp> {
     if (!originalLayout)
       return failure();
 
-    size_t i = 0;
     SmallVector<Value> newDpasOps;
+    size_t i = 0;
     for (auto aVec : adaptor.getLhs()) {
       for (auto bVec : adaptor.getRhs()) {
-
         llvm::SmallVector<Value> operands({aVec, bVec});
         Value tmpC;
         if (op.getAcc()) {
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
index b7a9cd89030db..7e89ada934071 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
@@ -14,9 +14,14 @@ gpu.module @test_1_1_assignment {
   // CHECK: %[[REM:.*]] = affine.apply #map1()[%[[SGID]]]
   // CHECK: %[[MUL1:.*]] = index.mul %[[DIV]], %[[C12]]
   // CHECK: %[[MUL2:.*]] = index.mul %[[REM]], %[[C8]]
+  // CHECK: %[[C24:.*]] = arith.constant 24 : index
+  // CHECK: %[[MOD:.*]] = index.remu %[[MUL1]], %[[C24]]
   // CHECK: %[[C0:.*]] = arith.constant 0 : index
-  // CHECK: %[[ADD1:.*]] = index.add %[[MUL1]], %[[C0]]
-  // CHECK: %[[ADD2:.*]] = index.add %[[MUL2]], %[[C0]]
+  // CHECK: %[[ADD1:.*]] = index.add %[[MOD]], %[[C0]]
+  // CHECK: %[[C32:.*]] = arith.constant 32 : index
+  // CHECK: %[[MOD1:.*]] = index.remu %[[MUL2]], %[[C32]]
+  // CHECK: %[[C0_1:.*]] = arith.constant 0 : index
+  // CHECK: %[[ADD2:.*]] = index.add %[[MOD1]], %[[C0_1]]
   // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][%[[ADD1]], %[[ADD2]]] : memref<24x32xf32>
   // CHECK-SAME: -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
   // CHECK: gpu.return
@@ -108,6 +113,40 @@ gpu.func @test_dpas(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
     gpu.return
   }
 
+
+  // CHECK-LABEL: test_dpas_no_sg_data
+  // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
+  // CHECK-SAME: %[[ARG_1:.*]]: memref<32x24xf32>
+  gpu.func @test_dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
+    // CHECK: %[[TDESC_A:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<24x32xf32>
+    // CHECK-SAME: -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
+    // CHECK: %[[LOAD_A:.*]] = xegpu.load_nd %[[TDESC_A]]
+    // CHECK-SAME: : !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
+    // CHECK-SAME: -> vector<12x8xf32>
+    // CHECK: %[[TDESC_B:.*]] = xegpu.create_nd_tdesc %[[ARG_1]][{{%.*}}, {{%.*}}] : memref<32x24xf32>
+    // CHECK-SAME: -> !xegpu.tensor_desc<8x12xf32, #xegpu.layout<lane_layout = [8, 2], lane_data = [1, 1]>>
+    // CHECK: %[[LOAD_B:.*]] = xegpu.load_nd %[[TDESC_B]]
+    // CHECK-SAME: : !xegpu.tensor_desc<8x12xf32, #xegpu.layout<lane_layout = [8, 2], lane_data = [1, 1]>>
+    // CHECK-SAME: -> vector<8x12xf32>
+    // CHECK: %[[DPAS:.*]] = xegpu.dpas %[[LOAD_A]], %[[LOAD_B]]
+    // CHECK-SAME: {layout_result_0 =  #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>}
+    // CHECK-SAME: : vector<12x8xf32>, vector<8x12xf32> -> vector<12x12xf32>
+    %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32>
+      -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], lane_layout = [2, 8], lane_data = [1, 1]>>
+    %load_a =  xegpu.load_nd %tdesc_a
+      : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], lane_layout = [2, 8], lane_data = [1, 1]>>
+      -> vector<24x32xf32>
+    %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<32x24xf32>
+      -> !xegpu.tensor_desc<32x24xf32, #xegpu.layout<sg_layout = [4, 2], lane_layout = [8, 2], lane_data = [1, 1]>>
+    %load_b =  xegpu.load_nd %tdesc_b
+      : !xegpu.tensor_desc<32x24xf32, #xegpu.layout<sg_layout = [4, 2], lane_layout = [8, 2], lane_data = [1, 1]>>
+      -> vector<32x24xf32>
+    %dpas = xegpu.dpas %load_a, %load_b
+      {layout =  #xegpu.layout<sg_layout = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
+      : vector<24x32xf32>, vector<32x24xf32> -> vector<24x24xf32>
+    gpu.return
+  }
+
   // CHECK-LABEL: test_prefetch_nd_tdesc
   // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
   gpu.func @test_prefetch_nd_tdesc(%src: memref<24x32xf32>) {

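Two behavioral changes land here. First, the per-subgroup offset is wrapped
modulo the distribution-unit shape before being added to the original offset
(the index.remu visible in the updated CHECK lines), which keeps round-robin
assignments inside the unit; index.remu is an unsigned remainder, equivalent
to % for the non-negative offsets involved. Second, sg_data becomes optional
and defaults to an even split of the workgroup shape across sg_layout, as the
new test_dpas_no_sg_data test exercises. A sketch of both, with plain
integers standing in for MLIR Values (names are illustrative):

  #include <cassert>
  #include <cstddef>
  #include <cstdint>
  #include <vector>

  // Patched offset formula: wrap the subgroup-relative offset into the
  // distribution unit, then add the original tensor offset.
  int64_t globalOffset(int64_t origOffset, int64_t localOffset,
                       int64_t distUnitBase, int64_t distUnitDim) {
    return origOffset + (localOffset + distUnitBase) % distUnitDim;
  }

  // sg_data fallback: each subgroup covers wgShape / sgLayout per dimension.
  std::vector<int64_t> defaultSgShape(const std::vector<int64_t> &wgShape,
                                      const std::vector<int64_t> &sgLayout) {
    assert(wgShape.size() == sgLayout.size() &&
           "sgLayout and wgShape must have the same rank");
    std::vector<int64_t> sgShape(wgShape.size());
    for (size_t i = 0; i < wgShape.size(); ++i) {
      assert(sgLayout[i] != 0 && "sgLayout elements must be non-zero");
      sgShape[i] = wgShape[i] / sgLayout[i];
    }
    return sgShape;
  }

With wgShape = [24, 32] and sg_layout = [2, 4], the fallback yields
sg_data = [12, 8], which is why test_dpas_no_sg_data expects 12x8 descriptors.
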
>From 1d5f75a46b046a25ca99f2fe531060754ceeda58 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Tue, 20 May 2025 14:35:15 +0000
Subject: [PATCH 15/18] Use getValueOrCreateConstantIndexOp

---
 .../Transforms/XeGPUWgToSgDistribute.cpp      | 19 +++++--------------
 .../test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir |  6 ++++--
 2 files changed, 9 insertions(+), 16 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 53b76e1c580a7..0f752296351b0 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -8,6 +8,7 @@
 #include "mlir/Dialect/XeGPU/Transforms/Passes.h"
 
 #include "mlir/Dialect/Affine/Utils.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/Index/IR/IndexDialect.h"
 #include "mlir/Dialect/Index/IR/IndexOps.h"
@@ -86,17 +87,6 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
     assert(localOffset.size() == distUnitBaseAddr.size() &&
            "localOffset and distUnitBaseAddr must have the same rank");
 
-    // Convert originalOffsets to Value
-    auto getValueFromOpFoldResult = [&](OpFoldResult ofr) -> Value {
-      if (auto val = ofr.dyn_cast<Value>())
-        return val;
-      if (auto attr = ofr.dyn_cast<Attribute>()) {
-        int64_t staticOffset = cast<IntegerAttr>(attr).getInt();
-        return rewriter.create<arith::ConstantIndexOp>(loc, staticOffset);
-      }
-      llvm_unreachable("Unsupported OpFoldResult kind");
-    };
-
     SmallVector<OpFoldResult> globalOffsets(originalOffsets.begin(),
                                             originalOffsets.end());
     size_t rank = localOffset.size();
@@ -110,7 +100,8 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
           rewriter.create<arith::ConstantIndexOp>(loc, distUnitShape[i]);
       Value offsetMod =
           rewriter.createOrFold<index::RemUOp>(loc, offset, modValue);
-      Value origOffset = getValueFromOpFoldResult(originalOffsets[dimIdx]);
+      Value origOffset = getValueOrCreateConstantIndexOp(
+          rewriter, loc, originalOffsets[dimIdx]);
       Value globalOffset =
           rewriter.createOrFold<index::AddOp>(loc, origOffset, offsetMod);
       globalOffsets[dimIdx] = globalOffset;
@@ -135,8 +126,8 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
       sgLayout = llvm::to_vector_of<int64_t>(sgLayoutAttr.asArrayRef());
     } else {
       // sgLayout must be present for workgroup-level distribution.
-      op.emitError("sgLayout attribute is required in layout");
-      return failure();
+      return rewriter.notifyMatchFailure(
+          op, "sgLayout attribute is required in layout");
     }
 
     SmallVector<int64_t> sgShape;
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
index 23fdffc220ecb..bee026eb2084d 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
@@ -15,7 +15,8 @@ gpu.module @test_round_robin_assignment {
   // CHECK-LABEL: test_load_nd_tdesc
   // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
   gpu.func @test_load_nd_tdesc(%src: memref<24x32xf32>) {
-      %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+      %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32>
+        -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
       // CHECK-COUNT-12: xegpu.load_nd %{{.*}}
       // CHECK-SAME-COUNT-12: : !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
       // CHECK-SAME-COUNT-12: -> vector<2x2xf32>
@@ -45,7 +46,8 @@ gpu.module @test_round_robin_assignment {
   // CHECK-LABEL: test_update_nd
   // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
   gpu.func @test_update_nd(%src: memref<24x32xf32>){
-    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->  !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32>
+      ->  !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
     // CHECK-COUNT-12: xegpu.update_nd_offset %{{.*}}, [0, 16]
     // CHECK-SAME-COUNT-12: : !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
     // CHECK-NOT: xegpu.update_nd_offset

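For context, getValueOrCreateConstantIndexOp (declared in
mlir/Dialect/Arith/Utils/Utils.h, which the patch now includes) performs the
same dispatch the deleted lambda did: if the OpFoldResult already wraps a
Value it is returned as-is, otherwise the static attribute is materialized as
an arith.constant of index type. A sketch of the equivalent logic, mirroring
the removed lambda (the helper name asIndexValue is illustrative):

  #include "mlir/Dialect/Arith/IR/Arith.h"
  #include "mlir/IR/Builders.h"

  using namespace mlir;

  static Value asIndexValue(OpBuilder &b, Location loc, OpFoldResult ofr) {
    // Already an SSA value: nothing to materialize.
    if (auto val = ofr.dyn_cast<Value>())
      return val;
    // Static offset: create an index constant at the given location.
    int64_t staticOffset =
        cast<IntegerAttr>(ofr.dyn_cast<Attribute>()).getInt();
    return b.create<arith::ConstantIndexOp>(loc, staticOffset);
  }

Reusing the shared utility avoids duplicating this OpFoldResult handling
across patterns.
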
>From ed124fc60d2d95c1c3b1ace0bc5080f1ebc6fd0f Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Tue, 20 May 2025 15:23:45 +0000
Subject: [PATCH 16/18] Remove braces for single-line if

---
 .../XeGPU/Transforms/XeGPUWgToSgDistribute.cpp        | 11 +++++------
 1 file changed, 5 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 0f752296351b0..c16ad7cb3c3fc 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -121,19 +121,18 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
       return failure();
     Type elemTy = tdescTy.getElementType();
     ArrayRef<int64_t> wgShape = tdescTy.getShape();
+    // sgLayout must be present for workgroup-level distribution.
     SmallVector<int64_t> sgLayout;
-    if (auto sgLayoutAttr = layout.getSgLayout()) {
+    if (auto sgLayoutAttr = layout.getSgLayout())
       sgLayout = llvm::to_vector_of<int64_t>(sgLayoutAttr.asArrayRef());
-    } else {
-      // sgLayout must be present for workgroup-level distribution.
+    else
       return rewriter.notifyMatchFailure(
           op, "sgLayout attribute is required in layout");
-    }
 
     SmallVector<int64_t> sgShape;
-    if (auto sgDataAttr = layout.getSgData()) {
+    if (auto sgDataAttr = layout.getSgData())
       sgShape = llvm::to_vector_of<int64_t>(sgDataAttr.asArrayRef());
-    } else {
+    else {
       assert(wgShape.size() == sgLayout.size() &&
              "sgLayout and wgShape must have the same rank");
       sgShape.reserve(wgShape.size());

>From 30f444ea866c199949800d74352312e83e0fe1e8 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Tue, 20 May 2025 15:32:15 +0000
Subject: [PATCH 17/18] Add braces for uniformity

---
 mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index c16ad7cb3c3fc..3bf76af674ba0 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -130,9 +130,9 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
           op, "sgLayout attribute is required in layout");
 
     SmallVector<int64_t> sgShape;
-    if (auto sgDataAttr = layout.getSgData())
+    if (auto sgDataAttr = layout.getSgData()) {
       sgShape = llvm::to_vector_of<int64_t>(sgDataAttr.asArrayRef());
-    else {
+    } else {
       assert(wgShape.size() == sgLayout.size() &&
              "sgLayout and wgShape must have the same rank");
       sgShape.reserve(wgShape.size());

>From b9ceb8afe39b0ec5ad51dadde48d038a80129277 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Tue, 20 May 2025 21:21:25 +0000
Subject: [PATCH 18/18] Add IndexDialect dependency

---
 mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt | 1 +
 1 file changed, 1 insertion(+)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
index 837303b04e9d7..7d9b5584b0b2b 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
@@ -13,6 +13,7 @@ add_mlir_dialect_library(MLIRXeGPUTransforms
   LINK_LIBS PUBLIC
   MLIRAffineUtils
   MLIRIR
+  MLIRIndexDialect
   MLIRMemRefDialect
   MLIRXeGPUDialect
   MLIRPass


