[Mlir-commits] [mlir] [MLIR][XeGPU] Support order attribute and add pattern for vector.transpose in WgToSg Pass (PR #165307)

Nishant Patel llvmlistbot at llvm.org
Mon Oct 27 12:54:17 PDT 2025


https://github.com/nbpatel created https://github.com/llvm/llvm-project/pull/165307

None

>From 2124bbffea2b0e01d4d6440651e846bd1a9c2067 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Tue, 23 Sep 2025 22:40:02 +0000
Subject: [PATCH 1/5] Add pattern for transpose

---
 mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp    | 55 ++++++++++++++++---
 .../Transforms/XeGPUWgToSgDistribute.cpp      | 49 ++++++++++++++---
 .../XeGPU/xegpu-wg-to-sg-unify-ops.mlir       | 11 ++++
 3 files changed, 101 insertions(+), 14 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 94c5509fd7c29..4cd9b5a9f451b 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -309,14 +309,55 @@ LayoutAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
       return failure();
   }
 
-  // delinearize Ids
-  auto maybeIds = delinearizeSubgroupId(builder, loc, linearId);
-  if (failed(maybeIds))
-    return failure();
-  SmallVector<Value> sgIds = *maybeIds;
+  // Get the order attribute, default to row-major if not present
+  SmallVector<int64_t> order;
+  if (getOrder()) {
+    order = llvm::map_to_vector(getOrder().asArrayRef(), [](int v) { return static_cast<int64_t>(v); });
+  } else {
+    auto range = llvm::reverse(llvm::seq<int64_t>(0, sgLayout.size()));
+    order = llvm::to_vector(range);
+  }
 
-  return genOffsetsComputingInsts(builder, loc, sgIds, sgLayout, sgShape,
-                                  shape);
+  // Check if order is default row-major (reverse identity)
+  bool isRowMajor = true;
+  for (size_t i = 0, n = order.size(); i < n; ++i)
+    if (order[i] != static_cast<int64_t>(n) - 1 - static_cast<int64_t>(i))
+      isRowMajor = false;
+
+  SmallVector<Value> sgIds;
+  if (isRowMajor) {
+    // Use original delinearization for row-major
+    auto maybeIds = affine::delinearizeIndex(
+        builder, loc, linearId,
+        llvm::map_to_vector(sgLayout, [&](int64_t d) -> Value {
+          return builder.createOrFold<arith::ConstantIndexOp>(loc, d);
+        }));
+    if (failed(maybeIds))
+      return failure();
+    sgIds = maybeIds.value();
+  } else {
+    // Permute sgLayout according to order for delinearization
+    SmallVector<int64_t> permutedLayout(order.size());
+    for (size_t i = 0; i < order.size(); ++i)
+      permutedLayout[i] = sgLayout[order[i]];
+
+    // Delinearize the linear subgroup id in the requested order
+    auto maybePermutedSgId = affine::delinearizeIndex(
+        builder, loc, linearId,
+        llvm::map_to_vector(permutedLayout, [&](int64_t d) -> Value {
+          return builder.createOrFold<arith::ConstantIndexOp>(loc, d);
+        }));
+    if (failed(maybePermutedSgId))
+      return failure();
+    SmallVector<Value> permutedSgId = maybePermutedSgId.value();
+
+    // Compute the inverse permutation to map back to physical order
+    sgIds.resize(order.size());
+    for (size_t i = 0; i < order.size(); ++i)
+      sgIds[order[i]] = permutedSgId[i];
+  }
+
+  return genOffsetsComputingInsts(builder, loc, sgIds, sgLayout, sgShape, shape);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index d7592fed6d186..0580a6d61b359 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -1027,6 +1027,44 @@ struct WgToSgVectorShapeCastOp
   }
 };
 
+// This pattern transforms vector.transpose ops to work at subgroup level.
+struct WgToSgVectorTransposeOp
+    : public OpConversionPattern<vector::TransposeOp> {
+  using OpConversionPattern<vector::TransposeOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::TransposeOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    VectorType resultType = dyn_cast<VectorType>(op.getResult().getType());
+    if (!resultType)
+      return failure();
+
+    ArrayRef<int64_t> wgShape = resultType.getShape();
+    xegpu::DistributeLayoutAttr layout =
+        xegpu::getDistributeLayoutAttr(op.getResult());
+    if (!layout || !layout.isForWorkgroup())
+      return failure();
+
+    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
+    VectorType newResultType =
+        VectorType::get(sgShape, resultType.getElementType());
+
+    SmallVector<Value> newTransposeOps;
+    for (auto src : adaptor.getVector()) {
+      auto newTranspose = rewriter.create<vector::TransposeOp>(
+          op.getLoc(), newResultType, src, op.getPermutation());
+      if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
+          !layout.getEffectiveInstDataAsInt().empty())
+        xegpu::setDistributeLayoutAttr(newTranspose->getResult(0),
+                                       layout.dropSgLayoutAndData());
+      newTransposeOps.push_back(newTranspose.getResult());
+    }
+
+    rewriter.replaceOpWithMultiple(op, {newTransposeOps});
+    return success();
+  }
+};
+
 } // namespace
 
 namespace mlir {
@@ -1040,7 +1078,8 @@ void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
            WgToSgElementwiseOp, WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp,
            WgToSgArithConstantOp, WgToSgLoadGatherOpWithOffset,
            WgToSgStoreScatterOpWithOffset, WgToSgLoadMatrixOp,
-           WgToSgStoreMatrixOp, WgToSgVectorStepOp, WgToSgVectorShapeCastOp>(
+           WgToSgStoreMatrixOp, WgToSgVectorStepOp, WgToSgVectorShapeCastOp,
+           WgToSgVectorTransposeOp>(
           patterns.getContext());
 }
 } // namespace xegpu
@@ -1168,7 +1207,8 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
         return isLegal(layout);
       });
 
-  target.addDynamicallyLegalOp<vector::ShapeCastOp, vector::StepOp>(
+  target.addDynamicallyLegalOp<vector::ShapeCastOp, vector::StepOp,
+                               vector::TransposeOp, vector::BroadcastOp>(
       [=](Operation *op) -> bool {
         // Check for either a SliceAttr or LayoutAttr on the result.
         auto layout = xegpu::getDistributeLayoutAttr(op->getResult(0));
@@ -1190,11 +1230,6 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
         return isLegal(layout);
       });
 
-  target.addDynamicallyLegalOp<vector::BroadcastOp>(
-      [=](vector::BroadcastOp op) -> bool {
-        return isLegal(xegpu::getDistributeLayoutAttr(op.getResult()));
-      });
-
   target.addDynamicallyLegalOp<xegpu::ConvertLayoutOp>(
       [=](xegpu::ConvertLayoutOp op) -> bool {
         return isLegal(op.getInputLayout()) && isLegal(op.getTargetLayout());
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
index 3478a9b91da5f..5cb8a2df54652 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -424,4 +424,15 @@ gpu.module @test_distribution {
     %broadcast = vector.broadcast %muli {layout_result_0 = #xegpu.layout<sg_layout = [4, 2, 6, 1], sg_data = [1, 1, 1, 32]>} : index to vector<4x2x6x32xindex>
     gpu.return
   }
+
+  // CHECK-LABEL: vector_transpose
+  gpu.func @vector_transpose(%src: memref<256x32xf32>) {
+      %tdesc = xegpu.create_nd_tdesc %src : memref<256x32xf32>
+          -> !xegpu.tensor_desc<256x32xf32, #xegpu.layout<sg_layout = [4, 8], sg_data = [64, 32], lane_layout = [1, 32], lane_data = [1, 1]>>
+      %load = xegpu.load_nd %tdesc[0, 0]
+          : !xegpu.tensor_desc<256x32xf32, #xegpu.layout<sg_layout = [4, 8], sg_data = [64, 32], lane_layout = [1, 32], lane_data = [1, 1]>>
+          -> vector<256x32xf32>
+      %trans = vector.transpose %load, [1, 0] {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], lane_layout = [1, 16], lane_data = [1, 1]>} : vector<256x32xf32> to vector<32x256xf32>
+      gpu.return
+  }
 }

>From 88ef6c9becfd9e2c7553ecbd8c43e9bc51073c23 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Thu, 25 Sep 2025 18:33:43 +0000
Subject: [PATCH 2/5] Add pattern for 2D transpose

---
 mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp    | 77 +++++++------------
 .../XeGPU/xegpu-wg-to-sg-unify-ops.mlir       | 15 ++--
 2 files changed, 34 insertions(+), 58 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 4cd9b5a9f451b..dc4a24bc4f6d6 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -280,15 +280,31 @@ LayoutAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc,
     return !order || isIdentityPermutation(llvm::to_vector_of<int64_t>(
                          llvm::reverse(order.asArrayRef())));
   };
-  if (!hasDefaultOrder())
-    return mlir::emitError(loc, "order attribute is currently not supported.");
 
   auto dims =
       llvm::map_to_vector(getEffectiveSgLayoutAsInt(), [&](int64_t d) -> Value {
         return builder.createOrFold<arith::ConstantIndexOp>(loc, d);
       });
 
-  return affine::delinearizeIndex(builder, loc, linearId, dims);
+  if (hasDefaultOrder())
+    return affine::delinearizeIndex(builder, loc, linearId, dims);
+  else if (getOrder() && getOrder().size() == 2 &&
+           getOrder().asArrayRef()[0] == 0 && getOrder().asArrayRef()[1] == 1) {
+    // If order is [0, 1], reverse the dims for delinearization, then reverse
+    // the result.
+    // This is a temporary solution for 2D sg_layout with order [0, 1].
+    // A complete solution requires generating more affine maps for
+    // delinearization based on the order attribute.
+    assert(dims.size() == 2 && "expected 2D sg_layout.");
+    SmallVector<Value> reversedDims = {dims[1], dims[0]};
+    auto maybeIds =
+        affine::delinearizeIndex(builder, loc, linearId, reversedDims);
+    if (failed(maybeIds))
+      return failure();
+    SmallVector<Value> ids = maybeIds.value();
+    std::reverse(ids.begin(), ids.end());
+    return ids;
+  }
 }
 
 /// Implements DistributeLayoutAttr::getOffsets to generate
@@ -309,55 +325,14 @@ LayoutAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
       return failure();
   }
 
-  // Get the order attribute, default to row-major if not present
-  SmallVector<int64_t> order;
-  if (getOrder()) {
-    order = llvm::map_to_vector(getOrder().asArrayRef(), [](int v) { return static_cast<int64_t>(v); });
-  } else {
-    auto range = llvm::reverse(llvm::seq<int64_t>(0, sgLayout.size()));
-    order = llvm::to_vector(range);
-  }
-
-  // Check if order is default row-major (reverse identity)
-  bool isRowMajor = true;
-  for (size_t i = 0, n = order.size(); i < n; ++i)
-    if (order[i] != static_cast<int64_t>(n) - 1 - static_cast<int64_t>(i))
-      isRowMajor = false;
-
-  SmallVector<Value> sgIds;
-  if (isRowMajor) {
-    // Use original delinearization for row-major
-    auto maybeIds = affine::delinearizeIndex(
-        builder, loc, linearId,
-        llvm::map_to_vector(sgLayout, [&](int64_t d) -> Value {
-          return builder.createOrFold<arith::ConstantIndexOp>(loc, d);
-        }));
-    if (failed(maybeIds))
-      return failure();
-    sgIds = maybeIds.value();
-  } else {
-    // Permute sgLayout according to order for delinearization
-    SmallVector<int64_t> permutedLayout(order.size());
-    for (size_t i = 0; i < order.size(); ++i)
-      permutedLayout[i] = sgLayout[order[i]];
-
-    // Delinearize the linear subgroup id in the requested order
-    auto maybePermutedSgId = affine::delinearizeIndex(
-        builder, loc, linearId,
-        llvm::map_to_vector(permutedLayout, [&](int64_t d) -> Value {
-          return builder.createOrFold<arith::ConstantIndexOp>(loc, d);
-        }));
-    if (failed(maybePermutedSgId))
-      return failure();
-    SmallVector<Value> permutedSgId = maybePermutedSgId.value();
-
-    // Compute the inverse permutation to map back to physical order
-    sgIds.resize(order.size());
-    for (size_t i = 0; i < order.size(); ++i)
-      sgIds[order[i]] = permutedSgId[i];
-  }
+  // delinearize Ids
+  auto maybeIds = delinearizeSubgroupId(builder, loc, linearId);
+  if (failed(maybeIds))
+    return failure();
+  SmallVector<Value> sgIds = *maybeIds;
 
-  return genOffsetsComputingInsts(builder, loc, sgIds, sgLayout, sgShape, shape);
+  return genOffsetsComputingInsts(builder, loc, sgIds, sgLayout, sgShape,
+                                  shape);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
index a99396d19f26a..2fb9f4ec972e1 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -467,12 +467,13 @@ gpu.module @test_distribution {
 
   // CHECK-LABEL: vector_transpose
   gpu.func @vector_transpose(%src: memref<256x32xf32>) {
-      %tdesc = xegpu.create_nd_tdesc %src : memref<256x32xf32>
-          -> !xegpu.tensor_desc<256x32xf32, #xegpu.layout<sg_layout = [4, 8], sg_data = [64, 32], lane_layout = [1, 32], lane_data = [1, 1]>>
-      %load = xegpu.load_nd %tdesc[0, 0]
-          : !xegpu.tensor_desc<256x32xf32, #xegpu.layout<sg_layout = [4, 8], sg_data = [64, 32], lane_layout = [1, 32], lane_data = [1, 1]>>
-          -> vector<256x32xf32>
-      %trans = vector.transpose %load, [1, 0] {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], lane_layout = [1, 16], lane_data = [1, 1]>} : vector<256x32xf32> to vector<32x256xf32>
-      gpu.return
+    %tdesc = xegpu.create_nd_tdesc %src : memref<256x32xf32>
+        -> !xegpu.tensor_desc<256x32xf32, #xegpu.layout<sg_layout = [4, 8], sg_data = [64, 32], lane_layout = [1, 32], lane_data = [1, 1], order =[0, 1]>>
+    %load = xegpu.load_nd %tdesc[0, 0]
+        : !xegpu.tensor_desc<256x32xf32, #xegpu.layout<sg_layout = [4, 8], sg_data = [64, 32], lane_layout = [1, 32], lane_data = [1, 1], order =[0, 1]>>
+        -> vector<256x32xf32>
+    //CHECK: vector.transpose {{.*}}, [1, 0] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1], order = [1, 0]>} : vector<64x32xf32> to vector<32x64xf32>
+    %trans = vector.transpose %load, [1, 0] {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], lane_layout = [1, 16], lane_data = [1, 1], order =[1, 0]>} : vector<256x32xf32> to vector<32x256xf32>
+    gpu.return
   }
 }

>From f69e709a7b85a3b3f607b8779999f13da2a682dd Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Tue, 21 Oct 2025 20:26:40 +0000
Subject: [PATCH 3/5] Support nD order

---
 mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp    |  95 +++++++----
 .../Transforms/XeGPUWgToSgDistribute.cpp      |  53 +++++-
 .../Dialect/XeGPU/xegpu-attr-interface.mlir   |  41 ++---
 .../test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir |  76 ++++-----
 .../XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir    |  65 ++++++--
 .../XeGPU/xegpu-wg-to-sg-unify-ops.mlir       | 155 +++++++++++-------
 mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir   | 145 ++++++++--------
 7 files changed, 382 insertions(+), 248 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 56d5d68e1c9c6..c56f7c2d78c21 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -270,42 +270,77 @@ LayoutAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
 FailureOr<SmallVector<Value>>
 LayoutAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc,
                                   Value linearId) {
-  // delinearizeSubgroupId is only available for
-  // workgroup-level layout attribute
   if (!isForWorkgroup())
     return failure();
 
-  // TODO: handle order attribute
-  auto hasDefaultOrder = [&]() {
-    DenseI32ArrayAttr order = getOrder();
-    return !order || isIdentityPermutation(llvm::to_vector_of<int64_t>(
-                         llvm::reverse(order.asArrayRef())));
-  };
+  SmallVector<int64_t> sgLayoutInt = getEffectiveSgLayoutAsInt();
+  DenseI32ArrayAttr orderAttr = getOrder();
+
+  // Handle order attribute
+  SmallVector<int64_t> order;
+  if (orderAttr && !orderAttr.empty()) {
+    order = llvm::to_vector(
+        llvm::map_range(orderAttr.asArrayRef(),
+                        [](int32_t idx) { return static_cast<int64_t>(idx); }));
+  } else {
+    // Default order: [1, 0] for 2D (row-major), [2, 1, 0] for 3D, etc.
+    order = llvm::to_vector(
+        llvm::reverse(llvm::seq<int64_t>(0, sgLayoutInt.size())));
+  }
 
-  auto dims =
-      llvm::map_to_vector(getEffectiveSgLayoutAsInt(), [&](int64_t d) -> Value {
-        return builder.createOrFold<arith::ConstantIndexOp>(loc, d);
-      });
+  // Validate that order has exactly one entry per sg_layout dimension
+  if (order.size() != sgLayoutInt.size()) {
+    return failure();
+  }
 
-  if (hasDefaultOrder())
-    return affine::delinearizeIndex(builder, loc, linearId, dims);
-  else if (getOrder() && getOrder().size() == 2 &&
-           getOrder().asArrayRef()[0] == 0 && getOrder().asArrayRef()[1] == 1) {
-    // If order is [0, 1], reverse the dims for delinearization, then reverse
-    // the result.
-    // This is a temporary solution for 2D sg_layout with order [0, 1].
-    // A complete solution requires generating more affine maps for
-    // delinearization based on the order attribute.
-    assert(dims.size() == 2 && "expected 2D sg_layout.");
-    SmallVector<Value> reversedDims = {dims[1], dims[0]};
-    auto maybeIds =
-        affine::delinearizeIndex(builder, loc, linearId, reversedDims);
-    if (failed(maybeIds))
-      return failure();
-    SmallVector<Value> ids = maybeIds.value();
-    std::reverse(ids.begin(), ids.end());
-    return ids;
+  SmallVector<Value> result(sgLayoutInt.size());
+  Value remaining = linearId;
+
+  // Process dimensions in the order they appear in the order array.
+  // The first dimension in `order` is the fastest-changing one.
+  //
+  // Example walkthrough for linearId=22, sgLayout=[2,4,4], order=[2,1,0]:
+  //
+  // Initial: remaining=22, result=[?,?,?]
+  //
+  // i=0 (process columns, dimIdx=2, dimSize=4):
+  //   result[2] = 22 % 4 = 2  (column coordinate)
+  //   remaining = 22 / 4 = 5  (5 complete groups of 4 columns processed)
+  //
+  // i=1 (process rows, dimIdx=1, dimSize=4):
+  //   result[1] = 5 % 4 = 1   (row coordinate)
+  //   remaining = 5 / 4 = 1   (1 complete group of 4 rows processed)
+  //
+  // i=2 (process layers, dimIdx=0, dimSize=2):
+  //   result[0] = 1 % 2 = 1   (layer coordinate)
+  //   (no remaining update - last iteration)
+  //
+  // Final result: [1,1,2] = Layer 1, Row 1, Column 2
+  for (size_t i = 0; i < order.size(); ++i) {
+    int64_t dimIdx = order[i];
+    int64_t dimSize = sgLayoutInt[dimIdx];
+
+    Value dimSizeVal =
+        builder.createOrFold<arith::ConstantIndexOp>(loc, dimSize);
+
+    // Extract the coordinate for this dimension with a modulo operation.
+    // This gives us "how far within this dimension" we are, e.g.,
+    // remaining=22, dimSize=4: 22 % 4 = 2 (we're at position 2 within this
+    // dimension).
+    result[dimIdx] =
+        builder.createOrFold<index::RemUOp>(loc, remaining, dimSizeVal);
+
+    // Update remaining for the next dimension by removing what we've already
+    // processed. Division tells us "how many complete groups of this dimension
+    // we've gone through", e.g., remaining=22, dimSize=4: 22 / 4 = 5 (we've
+    // completed 5 groups of 4). Skip this for the last iteration since there's
+    // no next dimension to process.
+    if (i < order.size() - 1) {
+      remaining =
+          builder.createOrFold<index::DivUOp>(loc, remaining, dimSizeVal);
+    }
   }
+  return result;
 }
 
 /// Implements DistributeLayoutAttr::getOffsets to generate
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 6177098c903fe..d97c66f5b134a 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -1233,14 +1233,60 @@ struct WgToSgVectorTransposeOp
     if (!layout || !layout.isForWorkgroup())
       return failure();
 
+    // Get the source layout for validation
+    xegpu::DistributeLayoutAttr sourceLayout =
+        xegpu::getDistributeLayoutAttr(op.getVector());
+    if (!sourceLayout || !sourceLayout.isForWorkgroup())
+      return failure();
+
+    // Validate that result layout is transpose of source layout
+    SmallVector<int64_t> sourceSgLayout =
+        sourceLayout.getEffectiveSgLayoutAsInt();
+    SmallVector<int64_t> sourceSgData = sourceLayout.getEffectiveSgDataAsInt();
+    SmallVector<int64_t> resultSgLayout = layout.getEffectiveSgLayoutAsInt();
+    SmallVector<int64_t> resultSgData = layout.getEffectiveSgDataAsInt();
+
+    ArrayRef<int64_t> permutation = op.getPermutation();
+
+    // Check that sgLayout and sgData are properly transposed
+    if (sourceSgLayout.size() != resultSgLayout.size() ||
+        sourceSgData.size() != resultSgData.size() ||
+        sourceSgLayout.size() != permutation.size()) {
+      return rewriter.notifyMatchFailure(
+          op, "Source and result layouts must have same rank as permutation");
+    }
+
+    // Validate sgLayout transpose
+    for (size_t i = 0; i < permutation.size(); ++i) {
+      int64_t srcDim = permutation[i];
+      if (srcDim < 0 || srcDim >= static_cast<int64_t>(sourceSgLayout.size())) {
+        return rewriter.notifyMatchFailure(op, "Invalid permutation index");
+      }
+      if (resultSgLayout[i] != sourceSgLayout[srcDim]) {
+        return rewriter.notifyMatchFailure(
+            op, "Result sgLayout is not transpose of source sgLayout according "
+                "to permutation");
+      }
+    }
+
+    // Validate sgData transpose
+    for (size_t i = 0; i < permutation.size(); ++i) {
+      int64_t srcDim = permutation[i];
+      if (resultSgData[i] != sourceSgData[srcDim]) {
+        return rewriter.notifyMatchFailure(
+            op, "Result sgData is not transpose of source sgData according to "
+                "permutation");
+      }
+    }
+
     SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
     VectorType newResultType =
         VectorType::get(sgShape, resultType.getElementType());
 
     SmallVector<Value> newTransposeOps;
     for (auto src : adaptor.getVector()) {
-      auto newTranspose = rewriter.create<vector::TransposeOp>(
-          op.getLoc(), newResultType, src, op.getPermutation());
+      auto newTranspose = vector::TransposeOp::create(
+          rewriter, op.getLoc(), newResultType, src, op.getPermutation());
       if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
           !layout.getEffectiveInstDataAsInt().empty())
         xegpu::setDistributeLayoutAttr(newTranspose->getResult(0),
@@ -1267,7 +1313,8 @@ void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
            WgToSgArithConstantOp, WgToSgLoadGatherOpWithOffset,
            WgToSgStoreScatterOpWithOffset, WgToSgLoadMatrixOp,
            WgToSgStoreMatrixOp, WgToSgVectorStepOp, WgToSgVectorShapeCastOp,
-           WgToSgMultiDimReductionOp, WgToSgVectorTransposeOp>(patterns.getContext());
+           WgToSgMultiDimReductionOp, WgToSgVectorTransposeOp>(
+          patterns.getContext());
 }
 } // namespace xegpu
 } // namespace mlir
diff --git a/mlir/test/Dialect/XeGPU/xegpu-attr-interface.mlir b/mlir/test/Dialect/XeGPU/xegpu-attr-interface.mlir
index b73bc69393dab..7d785437948ea 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-attr-interface.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-attr-interface.mlir
@@ -1,33 +1,34 @@
 // RUN: mlir-opt --test-xegpu-layout-interface --cse -split-input-file %s | FileCheck %s
 
-//CHECk: #map = affine_map<()[s0] -> (s0 floordiv 8)>
 gpu.module @test {
   gpu.func @slice_attr() -> vector<128xindex> {
-    //CHECK: [[sgId:%.+]] = gpu.subgroup_id : index
-    //CHECK: [[IDY:%.+]] = affine.apply #map()[[[sgId]]]
-    //CHECK: [[c32:%.+]] = arith.constant 32 : index
-    //CHECK: [[LOCALY:%.+]] = index.mul [[IDY]], [[c32]]
-    //CHECK: [[c128:%.+]] = arith.constant 128 : index
-    //CHECK: [[MODY:%.+]] = index.remu [[LOCALY]], [[c128]]
-    //CHECK: [[BASE:%.+]] = vector.step : vector<32xindex>
-    //CHECK: [[CAST:%.+]] = vector.broadcast [[MODY]] : index to vector<32xindex>
-    //CHECK: [[ADD:%.+]] = arith.addi [[BASE]], [[CAST]] : vector<32xindex>
+    // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
+    // CHECK-DAG: %[[DIVU:.*]] = index.divu %[[SGID]], %[[C8:.*]]
+    // CHECK-DAG: %[[REMU:.*]] = index.remu %[[DIVU]], %[[C4:.*]]
+    // CHECK-DAG: %[[MUL:.*]] = index.mul %[[REMU]], %[[C32:.*]]
+    // CHECK-DAG: %[[MOD:.*]] = index.remu %[[MUL]], %[[C128:.*]]
+    // CHECK-DAG: %[[BASE:.*]] = vector.step : vector<32xindex>
+    // CHECK-DAG: %[[CAST:.*]] = vector.broadcast %[[MOD]] : index to vector<32xindex>
+    // CHECK-DAG: %[[ADD:.*]] = arith.addi %[[BASE]], %[[CAST]] : vector<32xindex>
+    // CHECK-DAG: %[[RESULT:.*]] = builtin.unrealized_conversion_cast %[[ADD]] : vector<32xindex> to vector<128xindex>
     %step = vector.step {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [4, 8], sg_data = [32, 32]>, dims = [1]>}: vector<128xindex>
     gpu.return %step : vector<128xindex>
   }
 
   gpu.func @nested_slice_attr() -> vector<128xindex> {
-    //CHECK: [[sgId:%.+]] = gpu.subgroup_id : index
-    //CHECK: [[IDY:%.+]] = affine.apply #map()[[[sgId]]]
-    //CHECK: [[c32:%.+]] = arith.constant 32 : index
-    //CHECK: [[LOCALY:%.+]] = index.mul [[IDY]], [[c32]]
-    //CHECK: [[c128:%.+]] = arith.constant 128 : index
-    //CHECK: [[MODY:%.+]] = index.remu [[LOCALY]], [[c128]]
-    //CHECK: [[BASE:%.+]] = vector.step : vector<32xindex>
-    //CHECK: [[CAST:%.+]] = vector.broadcast [[MODY]] : index to vector<32xindex>
-    //CHECK: [[ADD:%.+]] = arith.addi [[BASE]], [[CAST]] : vector<32xindex>
+    // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
+    // CHECK-DAG: %[[DIVU1:.*]] = index.divu %[[SGID]], %[[C1:.*]]
+    // CHECK-DAG: %[[DIVU2:.*]] = index.divu %[[DIVU1]], %[[C8:.*]]
+    // CHECK-DAG: %[[REMU:.*]] = index.remu %[[DIVU2]], %[[C4:.*]]
+    // CHECK-DAG: %[[MUL:.*]] = index.mul %[[REMU]], %[[C32:.*]]
+    // CHECK-DAG: %[[MOD:.*]] = index.remu %[[MUL]], %[[C128:.*]]
+    // CHECK-DAG: %[[BASE:.*]] = vector.step : vector<32xindex>
+    // CHECK-DAG: %[[CAST:.*]] = vector.broadcast %[[MOD]] : index to vector<32xindex>
+    // CHECK-DAG: %[[ADD:.*]] = arith.addi %[[BASE]], %[[CAST]] : vector<32xindex>
+    // CHECK-DAG: %[[RESULT:.*]] = builtin.unrealized_conversion_cast %[[ADD]] : vector<32xindex> to vector<128xindex>
     %0 = vector.step {layout_result_0 = #xegpu.slice<#xegpu.slice<#xegpu.layout<sg_layout = [4, 8, 1], sg_data = [32, 32, 1]>, dims = [2]>, dims = [1]>} : vector<128xindex>
     gpu.return %0 : vector<128xindex>
   }
 
-}
\ No newline at end of file
+}
+
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
index d2d250cbe0f66..965eb08bcb506 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
@@ -1,14 +1,17 @@
 // RUN: mlir-opt --xegpu-wg-to-sg-distribute -split-input-file %s | FileCheck %s
 
-#map = affine_map<()[s0] -> (s0 floordiv 4)>
-#map1 = affine_map<()[s0] -> (s0 mod 4)>
-
 gpu.module @test_round_robin_assignment {
   // CHECK-LABEL: create_nd_tdesc
   // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
   gpu.func @create_nd_tdesc(%src: memref<256x128xf32>) {
+      // CHECK: %[[SGID:.*]] = gpu.subgroup_id : index
+      // CHECK: %[[C4:.*]] = arith.constant 4 : index
+      // CHECK: %[[IDX:.*]] = index.remu %[[SGID]], %[[C4]]
+      // CHECK: %[[IDY_DIV:.*]] = index.divu %[[SGID]], %[[C4]]
+      // CHECK: %[[C8:.*]] = arith.constant 8 : index
+      // CHECK: %[[IDY:.*]] = index.remu %[[IDY_DIV]], %[[C8]]
       // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_0]][%{{.*}}, %{{.*}}] : memref<256x128xf32>
-      // CHECK-SAME: -> !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+      // CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
       // CHECK-NOT: xegpu.create_nd_tdesc
       %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
         -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
@@ -16,22 +19,23 @@ gpu.module @test_round_robin_assignment {
     }
 
   // CHECK-LABEL: create_nd_tdesc_with_shared_data
-  // CHECK-SAME: [[ARG_0:%.*]]: memref<256x128xf32>
+  // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
   gpu.func @create_nd_tdesc_with_shared_data(%src: memref<256x128xf32>) {
-    //CHECK: [[sgId:%.+]] = gpu.subgroup_id : index
-    //CHECK: [[IdY:%.+]] = affine.apply #map()[[[sgId]]]
-    //CHECK: [[IdX:%.+]] = affine.apply #map1()[[[sgId]]]
-    //CHECK: [[C16:%.+]] = arith.constant 16 : index
-    //CHECK: [[LY:%.+]] = index.mul [[IdY]], [[C16]]
-    //CHECK: [[C64:%.+]] = arith.constant 64 : index
-    //CHECK: [[LX:%.+]] = index.mul [[IdX]], [[C64]]
-    //CHECK: [[C0:%.+]] = arith.constant 0 : index
-    //CHECK: [[C0_1:%.+]] = arith.constant 0 : index
-    //CHECK: [[C128:%.+]] = arith.constant 128 : index
-    //CHECK: [[offY:%.+]] = index.remu [[LY]], [[C128]]
-    //CHECK: [[C64_2:%.+]] = arith.constant 64 : index
-    //CHECK: [[offX:%.+]] = index.remu [[LX]], [[C64_2]]
-    //CHECK: xegpu.create_nd_tdesc [[ARG_0]][[[offY]], [[offX]]] : memref<256x128xf32> -> !xegpu.tensor_desc<16x64xf32>
+    // CHECK: %[[SGID:.*]] = gpu.subgroup_id : index
+    // CHECK: %[[C4:.*]] = arith.constant 4 : index
+    // CHECK: %[[IDX:.*]] = index.remu %[[SGID]], %[[C4]]
+    // CHECK: %[[IDY_DIV:.*]] = index.divu %[[SGID]], %[[C4]]
+    // CHECK: %[[C8:.*]] = arith.constant 8 : index
+    // CHECK: %[[IDY:.*]] = index.remu %[[IDY_DIV]], %[[C8]]
+    // CHECK: %[[C16:.*]] = arith.constant 16 : index
+    // CHECK: %[[LY:.*]] = index.mul %[[IDY]], %[[C16]]
+    // CHECK: %[[C64:.*]] = arith.constant 64 : index
+    // CHECK: %[[LX:.*]] = index.mul %[[IDX]], %[[C64]]
+    // CHECK: %[[C128:.*]] = arith.constant 128 : index
+    // CHECK: %[[OFFY:.*]] = index.remu %[[LY]], %[[C128]]
+    // CHECK: %[[C64_1:.*]] = arith.constant 64 : index
+    // CHECK: %[[OFFX:.*]] = index.remu %[[LX]], %[[C64_1]]
+    // CHECK: xegpu.create_nd_tdesc %[[ARG_0]][%[[OFFY]], %[[OFFX]]] : memref<256x128xf32> -> !xegpu.tensor_desc<16x64xf32>
     %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
       -> !xegpu.tensor_desc<128x64xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 64]>>
     gpu.return
@@ -43,7 +47,7 @@ gpu.module @test_round_robin_assignment {
       %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
         -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
       // CHECK-COUNT-4: xegpu.load_nd %{{.*}}
-      // CHECK-SAME-COUNT-4: : !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+      // CHECK-SAME-COUNT-4: : !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
       // CHECK-SAME-COUNT-4: -> vector<16x16xf32>
       // CHECK-NOT: xegpu.load_nd
       %load =  xegpu.load_nd %tdesc
@@ -59,7 +63,7 @@ gpu.module @test_round_robin_assignment {
         -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
       // CHECK-COUNT-4: xegpu.store_nd %{{.*}}, %{{.*}}
       // CHECK-SAME-COUNT-4: : vector<16x16xf32>, !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
-      // CHECK-NOT : xegpu.store_nd
+      // CHECK-NOT: xegpu.store_nd
       %load = xegpu.load_nd %tdesc
         : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
         -> vector<256x128xf32>
@@ -74,7 +78,7 @@ gpu.module @test_round_robin_assignment {
     %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
       ->  !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
     // CHECK-COUNT-4: xegpu.update_nd_offset %{{.*}}, [0, 16]
-    // CHECK-SAME-COUNT-4: : !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>>
+    // CHECK-SAME-COUNT-4: : !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
     // CHECK-NOT: xegpu.update_nd_offset
     %update = xegpu.update_nd_offset %tdesc, [0, 16]
       : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
@@ -86,12 +90,10 @@ gpu.module @test_round_robin_assignment {
   gpu.func @dpas(%a: memref<256x128xf16>, %b: memref<128x256xf16>) {
     // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_0]][%{{.*}}, %{{.*}}] : memref<256x128xf16>
     // CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
-    // CHECK-NOT: xegpu.create_nd_tdesc
     // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_1]][%{{.*}}, %{{.*}}] : memref<128x256xf16>
-    // CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [4, 8], lane_data = [1, 1]>>
-    // CHECK-NOT: xegpu.create_nd_tdesc
+    // CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>>
     // CHECK-COUNT-16: xegpu.dpas %{{.*}}, %{{.*}}
-    // CHECK-SAME-COUNT-16: {layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+    // CHECK-SAME-COUNT-16: {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
     // CHECK-SAME-COUNT-16: : vector<16x16xf16>, vector<16x16xf16> -> vector<16x16xf32>
     // CHECK-NOT: xegpu.dpas
     %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<256x128xf16>
@@ -114,7 +116,7 @@ gpu.module @test_round_robin_assignment {
   // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
   gpu.func @prefetch_nd_tdesc(%src: memref<256x128xf32>) {
     // CHECK-COUNT-4: xegpu.prefetch_nd %{{.*}}
-    // CHECK-SAME-COUNT-4: !xegpu.tensor_desc<256x128xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    // CHECK-SAME-COUNT-4: !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
     // CHECK-NOT: xegpu.prefetch_nd
     %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
       -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
@@ -171,10 +173,10 @@ gpu.module @test_round_robin_assignment {
     %0 = xegpu.create_nd_tdesc %arg0[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
     %1 = xegpu.load_nd %0  : !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>> -> vector<256xf32>
     %2 = xegpu.create_nd_tdesc %arg1[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
-    //CHECK: scf.while ({{.*}}) : (vector<16xf32>, vector<16xf32>, i32) -> (vector<16xf32>, vector<16xf32>, i32)
+    // CHECK: scf.while ({{.*}}) : (vector<16xf32>, vector<16xf32>, i32) -> (vector<16xf32>, vector<16xf32>, i32)
     %3:2 = scf.while (%arg2 = %1, %arg3 = %c0_i32) : (vector<256xf32>, i32) -> (vector<256xf32>, i32) {
       %4 = arith.cmpi slt, %arg3, %c10_i32 : i32
-      //CHECK: scf.condition{{.*}} : vector<16xf32>, vector<16xf32>, i32
+      // CHECK: scf.condition{{.*}} : vector<16xf32>, vector<16xf32>, i32
       scf.condition(%4) %arg2, %arg3 : vector<256xf32>, i32
     } do {
     // CHECK: ([[arg2:%.+]]: vector<16xf32>, [[arg3:%.+]]: vector<16xf32>, [[arg4:%.+]]: i32)
@@ -195,16 +197,16 @@ gpu.module @test_round_robin_assignment {
     %2 = xegpu.create_nd_tdesc %arg1[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
     %3 = arith.cmpi eq, %0, %c10 : index
     // CHECK-LABEL: scf.if
-    //  CHECK-SAME: (vector<16xf32>, vector<16xf32>)
+    // CHECK-SAME: (vector<16xf32>, vector<16xf32>)
     %4 = scf.if %3 -> (vector<256xf32>) {
       %5 = xegpu.load_nd %1  : !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>> -> vector<256xf32>
       // CHECK-LABEL: scf.yield
-      //  CHECK-SAME: vector<16xf32>, vector<16xf32>
+      // CHECK-SAME: vector<16xf32>, vector<16xf32>
       scf.yield %5 : vector<256xf32>
     } else {
       %5 = xegpu.load_nd %2  : !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>> -> vector<256xf32>
       // CHECK-LABEL: scf.yield
-      //  CHECK-SAME: vector<16xf32>, vector<16xf32>
+      // CHECK-SAME: vector<16xf32>, vector<16xf32>
       scf.yield %5 : vector<256xf32>
     } {layout_result_0 = #xegpu.layout<sg_layout = [8], sg_data = [16]>}
     xegpu.store_nd %4, %1  : vector<256xf32>, !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
@@ -220,16 +222,16 @@ gpu.module @test_round_robin_assignment {
 
     %0 = arith.cmpi eq, %id, %c10 : index
     // CHECK-LABEL: scf.if
-    //  CHECK-SAME: (!xegpu.tensor_desc<16xf32>, !xegpu.tensor_desc<16xf32>)
+    // CHECK-SAME: (!xegpu.tensor_desc<16xf32>, !xegpu.tensor_desc<16xf32>)
     %1 = scf.if %0 -> (!xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>) {
       %2 = xegpu.create_nd_tdesc %arg0[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
       // CHECK-LABEL: scf.yield
-      //  CHECK-SAME: !xegpu.tensor_desc<16xf32>, !xegpu.tensor_desc<16xf32>
+      // CHECK-SAME: !xegpu.tensor_desc<16xf32>, !xegpu.tensor_desc<16xf32>
       scf.yield %2 : !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
     } else {
       %3 = xegpu.create_nd_tdesc %arg1[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
       // CHECK-LABEL: scf.yield
-      //  CHECK-SAME: !xegpu.tensor_desc<16xf32>, !xegpu.tensor_desc<16xf32>
+      // CHECK-SAME: !xegpu.tensor_desc<16xf32>, !xegpu.tensor_desc<16xf32>
       scf.yield %3 : !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
     }
     xegpu.store_nd %d, %1 : vector<256xf32>, !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
@@ -238,8 +240,8 @@ gpu.module @test_round_robin_assignment {
 
   gpu.func @convert_layout_optimal(%arg0: memref<32x64xf32>) {
     %0 = xegpu.create_nd_tdesc %arg0[0, 0] : memref<32x64xf32> -> !xegpu.tensor_desc<32x64xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [16, 16], inst_data = [16, 16]>>
-    //CHECK-2: xegpu.load_nd {{.*}}  : !xegpu.tensor_desc<16x16xf32, #xegpu.layout<inst_data = [16, 16]>> -> vector<16x16xf32>
-    //CHECK-2: xegpu.convert_layout {{.*}} <{input_layout = #xegpu.layout<inst_data = [16, 16]>, target_layout = #xegpu.layout<inst_data = [8, 16]>}> : vector<16x16xf32>
+    // CHECK-COUNT-2: xegpu.load_nd {{.*}}  : !xegpu.tensor_desc<16x16xf32, #xegpu.layout<inst_data = [16, 16]>> -> vector<16x16xf32>
+    // CHECK-COUNT-2: xegpu.convert_layout {{.*}} <{input_layout = #xegpu.layout<inst_data = [16, 16]>, target_layout = #xegpu.layout<inst_data = [8, 16]>}> : vector<16x16xf32>
     %1 = xegpu.load_nd %0  : !xegpu.tensor_desc<32x64xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [16, 16], inst_data = [16, 16]>> -> vector<32x64xf32>
     %2 = xegpu.convert_layout %1 <{input_layout = #xegpu.layout<sg_layout = [2, 2], sg_data = [16, 16], inst_data = [16, 16]>,
                                    target_layout = #xegpu.layout<sg_layout = [2, 2], sg_data = [16, 16], inst_data = [8, 16]>}> : vector<32x64xf32>
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
index 86a021b66949c..9eb59fdacdb6d 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
@@ -14,6 +14,15 @@ gpu.module @test_distribution {
 
   // CHECK-LABEL: load_nd_tdesc_with_offset
   gpu.func @load_nd_tdesc_with_offset(%src: memref<256x128xf32>) {
+    // CHECK-COUNT-4: xegpu.create_nd_tdesc {{%.*}} : memref<256x128xf32>
+    // CHECK-SAME: -> !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    // CHECK: %[[SGID:.*]] = gpu.subgroup_id : index
+    // CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+    // CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
+    // CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index
+    // CHECK-DAG: %[[C64:.*]] = arith.constant 64 : index
+    // CHECK-DAG: %[[C128:.*]] = arith.constant 128 : index
+    // CHECK-DAG: %[[C256:.*]] = arith.constant 256 : index
     // CHECK-COUNT-4: xegpu.load_nd {{%.*}}[{{%.*}}, {{%.*}}]
     // CHECK-SAME-COUNT-4: : !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
     // CHECK-SAME-COUNT-4: -> vector<16x16xf32>
@@ -28,8 +37,15 @@ gpu.module @test_distribution {
 
   // CHECK-LABEL: store_nd_with_offset
   gpu.func @store_nd_with_offset(%src: memref<256x128xf32>) {
-    // CHECK-COUNT-4: xegpu.store_nd %{{.*}}, {{%.*}}[{{%.*}}, {{%.*}}]
+    // CHECK-COUNT-4: xegpu.create_nd_tdesc {{%.*}} : memref<256x128xf32>
+    // CHECK-SAME: -> !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    // CHECK: %[[SGID:.*]] = gpu.subgroup_id : index
+    // CHECK-COUNT-4: xegpu.load_nd {{%.*}}[{{%.*}}, {{%.*}}]
     // CHECK-SAME-COUNT-4: : !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    // CHECK-SAME-COUNT-4: -> vector<16x16xf32>
+    // CHECK: %[[SGID2:.*]] = gpu.subgroup_id : index
+    // CHECK-COUNT-4: xegpu.store_nd %{{.*}}, {{%.*}}[{{%.*}}, {{%.*}}]
+    // CHECK-SAME-COUNT-4: : vector<16x16xf32>, !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
     // CHECK-NOT: xegpu.store_nd
     %tdesc = xegpu.create_nd_tdesc %src: memref<256x128xf32>
       -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
@@ -44,8 +60,11 @@ gpu.module @test_distribution {
   // CHECK-LABEL: prefetch_nd_tdesc_with_offset
   // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
   gpu.func @prefetch_nd_tdesc_with_offset(%src: memref<256x128xf32>) {
+    // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_0]] : memref<256x128xf32>
+    // CHECK-SAME: -> !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    // CHECK: %[[SGID:.*]] = gpu.subgroup_id : index
     // CHECK-COUNT-4: xegpu.prefetch_nd {{%.*}}[{{%.*}}, {{%.*}}]
-    // CHECK-SAME-COUNT-4: !xegpu.tensor_desc<256x128xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    // CHECK-SAME-COUNT-4: !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
     // CHECK-NOT: xegpu.prefetch_nd
     %tdesc = xegpu.create_nd_tdesc %src : memref<256x128xf32>
       -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
@@ -59,12 +78,18 @@ gpu.module @test_distribution {
   gpu.func @dpas(%a: memref<256x128xf16>, %b: memref<128x256xf16>) {
     // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_0]] : memref<256x128xf16>
     // CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
-    // CHECK-NOT: xegpu.create_nd_tdesc
+    // CHECK: %[[SGID:.*]] = gpu.subgroup_id : index
+    // CHECK-COUNT-4: xegpu.load_nd {{%.*}}[{{%.*}}, {{%.*}}]
+    // CHECK-SAME-COUNT-4: : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    // CHECK-SAME-COUNT-4: -> vector<16x16xf16>
     // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_1]] : memref<128x256xf16>
-    // CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [4, 8], lane_data = [1, 1]>>
-    // CHECK-NOT: xegpu.create_nd_tdesc
+    // CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>>
+    // CHECK: %[[SGID2:.*]] = gpu.subgroup_id : index
+    // CHECK-COUNT-4: xegpu.load_nd {{%.*}}[{{%.*}}, {{%.*}}]
+    // CHECK-SAME-COUNT-4: : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>>
+    // CHECK-SAME-COUNT-4: -> vector<16x16xf16>
     // CHECK-COUNT-16: xegpu.dpas %{{.*}}, %{{.*}}
-    // CHECK-SAME-COUNT-16: {layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+    // CHECK-SAME-COUNT-16: {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
     // CHECK-SAME-COUNT-16: : vector<16x16xf16>, vector<16x16xf16> -> vector<16x16xf32>
     // CHECK-NOT: xegpu.dpas
     %tdesc_a = xegpu.create_nd_tdesc %a : memref<256x128xf16>
@@ -92,6 +117,11 @@ gpu.module @test_distribution {
     %load =  xegpu.load_nd %tdesc[0, 0]
       : !xegpu.tensor_desc<256x64xf32, #xegpu.layout<sg_layout = [8, 1], sg_data = [16, 64]>>
       -> vector<256x64xf32>
+    // CHECK-COUNT-2: xegpu.create_nd_tdesc {{%.*}} : memref<256x64xf32>
+    // CHECK-SAME: -> !xegpu.tensor_desc<16x64xf32>
+    // CHECK: %[[SGID:.*]] = gpu.subgroup_id : index
+    // CHECK-COUNT-2: xegpu.load_nd {{%.*}}[{{%.*}}, {{%.*}}]
+    // CHECK-SAME-COUNT-2: : !xegpu.tensor_desc<16x64xf32> -> vector<16x64xf32>
     // CHECK-COUNT-2: vector.multi_reduction <add>, {{.*}}, %[[CST]] [1] : vector<16x64xf32> to vector<16xf32>
     // CHECK-NOT: vector.multi_reduction
     %reduce = vector.multi_reduction <add>, %load, %cst {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 1], sg_data = [16, 64]>, dims = [1]>} [1]
@@ -102,23 +132,24 @@ gpu.module @test_distribution {
   gpu.func @non_splat_constant() {
     // CHECK-DAG: %[[BASECST:.*]] = arith.constant dense<{{.*}}> : vector<2x1xindex>
     // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
-    // CHECK-DAG: %[[MAP4:.*]] = affine.apply #map4()[%[[SGID]]]
-    // CHECK-DAG: %[[MAP5:.*]] = affine.apply #map5()[%[[SGID]]]
-    // CHECK-DAG: %[[MUL:.*]] = index.mul %[[MAP4]], %[[C2:.*]]
-    // CHECK-DAG: %[[REMU1:.*]] = index.remu %[[MUL]], %[[C32:.*]]
-    // CHECK-DAG: %[[REMU2:.*]] = index.remu %[[MAP5]], %[[C1:.*]]
+    // CHECK-DAG: %[[REMU1:.*]] = index.remu %[[SGID]], %[[C1:.*]]
+    // CHECK-DAG: %[[DIVU:.*]] = index.divu %[[SGID]], %[[C1:.*]]
+    // CHECK-DAG: %[[REMU2:.*]] = index.remu %[[DIVU]], %[[C8:.*]]
+    // CHECK-DAG: %[[MUL:.*]] = index.mul %[[REMU2]], %[[C2:.*]]
+    // CHECK-DAG: %[[REMU3:.*]] = index.remu %[[MUL]], %[[C32:.*]]
+    // CHECK-DAG: %[[REMU4:.*]] = index.remu %[[REMU1]], %[[C1:.*]]
     // CHECK-DAG: %[[ADD16:.*]] = arith.addi %[[MUL]], %[[C16:.*]] : index
-    // CHECK-DAG: %[[REMU3:.*]] = index.remu %[[ADD16]], %[[C32:.*]]
-    // CHECK-DAG: %[[REMU4:.*]] = index.remu %[[MAP5]], %[[C1:.*]]
-    // CHECK-DAG: %[[STRIDE1:.*]] = arith.muli %[[REMU1]], %[[C16:.*]] : index
+    // CHECK-DAG: %[[REMU5:.*]] = index.remu %[[ADD16]], %[[C32:.*]]
+    // CHECK-DAG: %[[REMU6:.*]] = index.remu %[[REMU1]], %[[C1:.*]]
+    // CHECK-DAG: %[[STRIDE1:.*]] = arith.muli %[[REMU3]], %[[C16:.*]] : index
     // CHECK-DAG: %[[ADDSTRIDES:.*]] = arith.addi %[[C0:.*]], %[[STRIDE1]] : index
-    // CHECK-DAG: %[[STRIDE2:.*]] = arith.muli %[[REMU2]], %[[C0:.*]] : index
+    // CHECK-DAG: %[[STRIDE2:.*]] = arith.muli %[[REMU4]], %[[C0:.*]] : index
     // CHECK-DAG: %[[ADDSTRIDES1:.*]] = arith.addi %[[ADDSTRIDES]], %[[STRIDE2]] : index
     // CHECK-DAG: %[[BCAST1:.*]] = vector.broadcast %[[ADDSTRIDES1]] : index to vector<2x1xindex>
     // CHECK-DAG: %[[RESULT1:.*]] = arith.addi %[[BASECST]], %[[BCAST1]] : vector<2x1xindex>
-    // CHECK-DAG: %[[STRIDE3:.*]] = arith.muli %[[REMU3]], %[[C16:.*]] : index
+    // CHECK-DAG: %[[STRIDE3:.*]] = arith.muli %[[REMU5]], %[[C16:.*]] : index
     // CHECK-DAG: %[[ADDSTRIDES2:.*]] = arith.addi %[[C0:.*]], %[[STRIDE3]] : index
-    // CHECK-DAG: %[[STRIDE4:.*]] = arith.muli %[[REMU4]], %[[C0:.*]] : index
+    // CHECK-DAG: %[[STRIDE4:.*]] = arith.muli %[[REMU6]], %[[C0:.*]] : index
     // CHECK-DAG: %[[ADDSTRIDES3:.*]] = arith.addi %[[ADDSTRIDES2]], %[[STRIDE4]] : index
     // CHECK-DAG: %[[BCAST2:.*]] = vector.broadcast %[[ADDSTRIDES3]] : index to vector<2x1xindex>
     // CHECK-DAG: %[[RESULT2:.*]] = arith.addi %[[BASECST]], %[[BCAST2]] : vector<2x1xindex>
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
index a744cf7c60999..5b67b79fab8f3 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -1,8 +1,5 @@
 // RUN: mlir-opt --xegpu-wg-to-sg-distribute -split-input-file %s | FileCheck %s
 
-//CHECK: #map = affine_map<()[s0] -> (s0 floordiv 4)>
-//CHECK: #map1 = affine_map<()[s0] -> (s0 mod 4)>
-//CHECK: #map2 = affine_map<()[s0] -> (s0 floordiv 8)>
 gpu.module @test_distribution {
   // CHECK-LABEL: create_nd_tdesc_no_offset
   // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
@@ -29,9 +26,20 @@ gpu.module @test_distribution {
   // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
   gpu.func @load_nd_tdesc_with_offset(%src: memref<256x128xf32>) {
     //CHECK: [[SGID:%.+]] = gpu.subgroup_id : index
-    //CHECK: [[SGIDY:%.+]] = affine.apply #map()[[[SGID]]]
-    //CHECK: [[SGIDX:%.+]] = affine.apply #map1()[[[SGID]]]
-    //CHECK: %[[LOAD:.*]] = xegpu.load_nd {{%.*}}[{{%.*}}, {{%.*}}] : !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<32x32xf32>
+    //CHECK: [[C4:%.+]] = arith.constant 4 : index
+    //CHECK: [[SGIDX:%.+]] = index.remu [[SGID]], [[C4]]
+    //CHECK: [[SGIDY_TMP:%.+]] = index.divu [[SGID]], [[C4]]
+    //CHECK: [[C8:%.+]] = arith.constant 8 : index
+    //CHECK: [[SGIDY:%.+]] = index.remu [[SGIDY_TMP]], [[C8]]
+    //CHECK: [[C32:%.+]] = arith.constant 32 : index
+    //CHECK: [[LOCALY:%.+]] = index.mul [[SGIDY]], [[C32]]
+    //CHECK: [[C32_0:%.+]] = arith.constant 32 : index
+    //CHECK: [[LOCALX:%.+]] = index.mul [[SGIDX]], [[C32_0]]
+    //CHECK: [[C256:%.+]] = arith.constant 256 : index
+    //CHECK: [[OFFY:%.+]] = index.remu [[LOCALY]], [[C256]]
+    //CHECK: [[C128:%.+]] = arith.constant 128 : index
+    //CHECK: [[OFFX:%.+]] = index.remu [[LOCALX]], [[C128]]
+    //CHECK: %[[LOAD:.*]] = xegpu.load_nd {{%.*}}[[[OFFY]], [[OFFX]]] : !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<32x32xf32>
     %tdesc = xegpu.create_nd_tdesc %src: memref<256x128xf32>
       -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
     %load =  xegpu.load_nd %tdesc[0, 0]
@@ -44,8 +52,20 @@ gpu.module @test_distribution {
   // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
   gpu.func @store_nd_with_offsets(%src: memref<256x128xf32>) {
     //CHECK: [[SGID:%.+]] = gpu.subgroup_id : index
-    //CHECK: [[SGIDY:%.+]] = affine.apply #map()[[[SGID]]]
-    //CHECK: [[SGIDX:%.+]] = affine.apply #map1()[[[SGID]]]
+    //CHECK: [[C4:%.+]] = arith.constant 4 : index
+    //CHECK: [[SGIDX:%.+]] = index.remu [[SGID]], [[C4]]
+    //CHECK: [[SGIDY_TMP:%.+]] = index.divu [[SGID]], [[C4]]
+    //CHECK: [[C8:%.+]] = arith.constant 8 : index
+    //CHECK: [[SGIDY:%.+]] = index.remu [[SGIDY_TMP]], [[C8]]
+    //CHECK: [[C32:%.+]] = arith.constant 32 : index
+    //CHECK: [[LOCALY:%.+]] = index.mul [[SGIDY]], [[C32]]
+    //CHECK: [[C32_0:%.+]] = arith.constant 32 : index
+    //CHECK: [[LOCALX:%.+]] = index.mul [[SGIDX]], [[C32_0]]
+    //CHECK: [[C256:%.+]] = arith.constant 256 : index
+    //CHECK: [[OFFY:%.+]] = index.remu [[LOCALY]], [[C256]]
+    //CHECK: [[C128:%.+]] = arith.constant 128 : index
+    //CHECK: [[OFFX:%.+]] = index.remu [[LOCALX]], [[C128]]
+    //CHECK: %[[LOAD:.*]] = xegpu.load_nd {{%.*}}[[[OFFY]], [[OFFX]]] : !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<32x32xf32>
     //CHECK: xegpu.store_nd %{{.*}}, {{%.*}}[{{%.*}}, {{%.*}}]  : vector<32x32xf32>, !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
     %tdesc = xegpu.create_nd_tdesc %src: memref<256x128xf32>
       -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
@@ -61,9 +81,20 @@ gpu.module @test_distribution {
   // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
   gpu.func @prefetch_nd_tdesc_with_offset(%src: memref<256x128xf32>) {
     //CHECK: [[SGID:%.+]] = gpu.subgroup_id : index
-    //CHECK: [[SGIDY:%.+]] = affine.apply #map()[[[SGID]]]
-    //CHECK: [[SGIDX:%.+]] = affine.apply #map1()[[[SGID]]]
-    //CHECK: xegpu.prefetch_nd %{{.*}}[{{%.*}}, {{%.*}}] : !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    //CHECK: [[C4:%.+]] = arith.constant 4 : index
+    //CHECK: [[SGIDX:%.+]] = index.remu [[SGID]], [[C4]]
+    //CHECK: [[SGIDY_TMP:%.+]] = index.divu [[SGID]], [[C4]]
+    //CHECK: [[C8:%.+]] = arith.constant 8 : index
+    //CHECK: [[SGIDY:%.+]] = index.remu [[SGIDY_TMP]], [[C8]]
+    //CHECK: [[C32:%.+]] = arith.constant 32 : index
+    //CHECK: [[LOCALY:%.+]] = index.mul [[SGIDY]], [[C32]]
+    //CHECK: [[C32_0:%.+]] = arith.constant 32 : index
+    //CHECK: [[LOCALX:%.+]] = index.mul [[SGIDX]], [[C32_0]]
+    //CHECK: [[C256:%.+]] = arith.constant 256 : index
+    //CHECK: [[OFFY:%.+]] = index.remu [[LOCALY]], [[C256]]
+    //CHECK: [[C128:%.+]] = arith.constant 128 : index
+    //CHECK: [[OFFX:%.+]] = index.remu [[LOCALX]], [[C128]]
+    //CHECK: xegpu.prefetch_nd %{{.*}}[[[OFFY]], [[OFFX]]] : !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
     %cst0 = arith.constant 0 : index
     %tdesc = xegpu.create_nd_tdesc %src : memref<256x128xf32>
       -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
@@ -319,17 +350,15 @@ gpu.module @test_distribution {
   gpu.func @distribute_load_matrix(%arg0: memref<32768xi8, 3>) {
     //CHECK: [[mdesc:%.+]] = xegpu.create_mem_desc [[arg0]] : memref<32768xi8, 3> -> !xegpu.mem_desc<64x128xf32>
     //CHECK: [[sgid:%.+]] = gpu.subgroup_id : index
-    //CHECK: [[c2:%.+]] = arith.constant 2 : index
     //CHECK: [[c4:%.+]] = arith.constant 4 : index
-    //CHECK: [[c4_0:%.+]] = arith.constant 4 : index
-    //CHECK: [[id_y:%.+]] = affine.apply #map()[[[sgid]]]
-    //CHECK: [[id_x:%.+]] = affine.apply #map1()[[[sgid]]]
+    //CHECK: [[sgidx:%.+]] = index.remu [[sgid]], [[c4]]
+    //CHECK: [[sgidy_tmp:%.+]] = index.divu [[sgid]], [[c4]]
+    //CHECK: [[c2:%.+]] = arith.constant 2 : index
+    //CHECK: [[sgidy:%.+]] = index.remu [[sgidy_tmp]], [[c2]]
     //CHECK: [[c32:%.+]] = arith.constant 32 : index
-    //CHECK: [[l_off_y:%.+]] = index.mul [[id_y]], [[c32]]
-    //CHECK: [[c32_1:%.+]] = arith.constant 32 : index
-    //CHECK: [[l_off_x:%.+]] = index.mul [[id_x]], [[c32_1]]
-    //CHECK: [[c0:%.+]] = arith.constant 0 : index
-    //CHECK: [[c0_1:%.+]] = arith.constant 0 : index
+    //CHECK: [[l_off_y:%.+]] = index.mul [[sgidy]], [[c32]]
+    //CHECK: [[c32_0:%.+]] = arith.constant 32 : index
+    //CHECK: [[l_off_x:%.+]] = index.mul [[sgidx]], [[c32_0]]
     //CHECK: [[c64:%.+]] = arith.constant 64 : index
     //CHECK: [[off_y:%.+]] = index.remu [[l_off_y]], [[c64]]
     //CHECK: [[c128:%.+]] = arith.constant 128 : index
@@ -346,17 +375,15 @@ gpu.module @test_distribution {
     //CHECK: [[cst:%.+]] = arith.constant dense<1.000000e+00> : vector<32x32xf32>
     //CHECK: [[mdesc:%.+]] = xegpu.create_mem_desc [[arg0]] : memref<32768xi8, 3> -> !xegpu.mem_desc<64x128xf32>
     //CHECK: [[sgid:%.+]] = gpu.subgroup_id : index
-    //CHECK: [[c2:%.+]] = arith.constant 2 : index
     //CHECK: [[c4:%.+]] = arith.constant 4 : index
-    //CHECK: [[c4_0:%.+]] = arith.constant 4 : index
-    //CHECK: [[id_y:%.+]] = affine.apply #map()[[[sgid]]]
-    //CHECK: [[id_x:%.+]] = affine.apply #map1()[[[sgid]]]
+    //CHECK: [[sgidx:%.+]] = index.remu [[sgid]], [[c4]]
+    //CHECK: [[sgidy_tmp:%.+]] = index.divu [[sgid]], [[c4]]
+    //CHECK: [[c2:%.+]] = arith.constant 2 : index
+    //CHECK: [[sgidy:%.+]] = index.remu [[sgidy_tmp]], [[c2]]
     //CHECK: [[c32:%.+]] = arith.constant 32 : index
-    //CHECK: [[l_off_y:%.+]] = index.mul [[id_y]], [[c32]]
-    //CHECK: [[c32_1:%.+]] = arith.constant 32 : index
-    //CHECK: [[l_off_x:%.+]] = index.mul [[id_x]], [[c32_1]]
-    //CHECK: [[c0:%.+]] = arith.constant 0 : index
-    //CHECK: [[c0_2:%.+]] = arith.constant 0 : index
+    //CHECK: [[l_off_y:%.+]] = index.mul [[sgidy]], [[c32]]
+    //CHECK: [[c32_0:%.+]] = arith.constant 32 : index
+    //CHECK: [[l_off_x:%.+]] = index.mul [[sgidx]], [[c32_0]]
     //CHECK: [[c64:%.+]] = arith.constant 64 : index
     //CHECK: [[off_y:%.+]] = index.remu [[l_off_y]], [[c64]]
     //CHECK: [[c128:%.+]] = arith.constant 128 : index
@@ -411,14 +438,17 @@ gpu.module @test_distribution {
   // CHECK-LABEL: vector_step_op
   gpu.func @vector_step_op_slice_attr() {
     //CHECK: [[sgId:%.+]] = gpu.subgroup_id : index
-    //CHECK-DAG: [[IDY:%.+]] = affine.apply #map2()[[[sgId]]]
-    //CHECK-DAG: [[c32:%.+]] = arith.constant 32 : index
-    //CHECK-DAG: [[LY:%.+]] = index.mul [[IDY]], [[c32]]
-    //CHECK-DAG: [[c0:%.+]] = arith.constant 0 : index
-    //CHECK-DAG: [[c128:%.+]] = arith.constant 128 : index
-    //CHECK-DAG: [[MODY:%.+]] = index.remu [[LY]], [[c128]]
-    //CHECK-DAG: [[BASE:%.+]] = vector.step : vector<32xindex>
-    //CHECK-DAG: [[CAST:%.+]] = vector.broadcast [[MODY]] : index to vector<32xindex>
+    //CHECK: [[c8:%.+]] = arith.constant 8 : index
+    //CHECK: [[sgidx:%.+]] = index.remu [[sgId]], [[c8]]
+    //CHECK: [[sgidy_tmp:%.+]] = index.divu [[sgId]], [[c8]]
+    //CHECK: [[c4:%.+]] = arith.constant 4 : index
+    //CHECK: [[sgidy:%.+]] = index.remu [[sgidy_tmp]], [[c4]]
+    //CHECK: [[c32:%.+]] = arith.constant 32 : index
+    //CHECK: [[LY:%.+]] = index.mul [[sgidy]], [[c32]]
+    //CHECK: [[c128:%.+]] = arith.constant 128 : index
+    //CHECK: [[MODY:%.+]] = index.remu [[LY]], [[c128]]
+    //CHECK: [[BASE:%.+]] = vector.step : vector<32xindex>
+    //CHECK: [[CAST:%.+]] = vector.broadcast [[MODY]] : index to vector<32xindex>
     //CHECK: [[ADD:%.+]] = arith.addi [[BASE]], [[CAST]] : vector<32xindex>
     %step = vector.step {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [4, 8], sg_data = [32, 32]>, dims = [1]>}: vector<128xindex>
     gpu.return
@@ -426,14 +456,14 @@ gpu.module @test_distribution {
 
   gpu.func @vector_step_op_layout_attr() {
     //CHECK: [[sgId:%.+]] = gpu.subgroup_id : index
-    //CHECK-DAG: [[c16:%.+]] = arith.constant 16 : index
-    //CHECK-DAG: [[c8:%.+]] = arith.constant 8 : index
-    //CHECK-DAG: [[LOCALY:%.+]] = index.mul [[sgId]], [[c8]]
-    //CHECK-DAG: [[c0:%.+]] = arith.constant 0 : index
-    //CHECK-DAG: [[c128:%.+]] = arith.constant 128 : index
-    //CHECK-DAG: [[MODY:%.+]] = index.remu [[LOCALY]], [[c128]]
-    //CHECK-DAG: [[BASE:%.+]] = vector.step : vector<8xindex>
-    //CHECK-DAG: [[CAST:%.+]] = vector.broadcast [[MODY]] : index to vector<8xindex>
+    //CHECK: [[c16:%.+]] = arith.constant 16 : index
+    //CHECK: [[sgidx:%.+]] = index.remu [[sgId]], [[c16]]
+    //CHECK: [[c8:%.+]] = arith.constant 8 : index
+    //CHECK: [[LOCALY:%.+]] = index.mul [[sgidx]], [[c8]]
+    //CHECK: [[c128:%.+]] = arith.constant 128 : index
+    //CHECK: [[MODY:%.+]] = index.remu [[LOCALY]], [[c128]]
+    //CHECK: [[BASE:%.+]] = vector.step : vector<8xindex>
+    //CHECK: [[CAST:%.+]] = vector.broadcast [[MODY]] : index to vector<8xindex>
     //CHECK: [[ADD:%.+]] = arith.addi [[BASE]], [[CAST]] : vector<8xindex>
     %step = vector.step {layout_result_0 = #xegpu.layout<sg_layout = [16], sg_data = [8]>}: vector<128xindex>
     gpu.return
@@ -480,10 +510,11 @@ gpu.module @test_distribution {
   gpu.func @non_splat_constant_2D() {
     // CHECK-DAG: %[[CST:.*]] = arith.constant dense<0> : vector<1x1xindex>
     // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
-    // CHECK-DAG: affine.apply #map4()[%[[SGID]]]
-    // CHECK-DAG: affine.apply #map5()[%[[SGID]]]
-    // CHECK-DAG: %[[IDY:.*]] = index.remu %{{.*}}, %[[C32:.*]]
-    // CHECK-DAG: %[[IDX:.*]] = index.remu %{{.*}}, %[[C1:.*]]
+    // CHECK-DAG: %[[SGIDX:.*]] = index.remu %[[SGID]], %{{.*}}
+    // CHECK-DAG: %[[SGIDY_TMP:.*]] = index.divu %[[SGID]], %{{.*}}
+    // CHECK-DAG: %[[SGIDY:.*]] = index.remu %[[SGIDY_TMP]], %{{.*}}
+    // CHECK-DAG: %[[IDY:.*]] = index.remu %[[SGIDY]], %{{.*}}
+    // CHECK-DAG: %[[IDX:.*]] = index.remu %[[SGIDX]], %{{.*}}
     // CHECK-DAG: %[[STRIDECOL:.*]] = arith.muli %[[IDY]], %[[C16:.*]] : index
     // CHECK-DAG: %[[ADD:.*]] = arith.addi %[[C0:.*]], %[[STRIDECOL]] : index
     // CHECK-DAG: %[[STRIDEROW:.*]] = arith.muli %[[IDX]], %[[C0:.*]] : index
@@ -496,20 +527,19 @@ gpu.module @test_distribution {
 
   // CHECK-LABEL: non_splat_constant_2D_non_unit_dim
   gpu.func @non_splat_constant_2D_non_unit_dim() {
-    // CHECK-DAG: %[[BASECST:.*]] = arith.constant dense<{{.*}} : vector<2x2xindex>
+    // CHECK-DAG: %[[BASECST:.*]] = arith.constant dense<{{\[}}{{\[}}0, 16{{\]}}, {{\[}}8, 24{{\]}}{{\]}}> : vector<2x2xindex>
     // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
-    // CHECK-DAG: %[[IDY:.*]] = affine.apply #map()[%[[SGID]]]
-    // CHECK-DAG: %[[IDX:.*]] = affine.apply #map1()[%[[SGID]]]
-    // CHECK-DAG: %[[MULY:.*]] = index.mul %[[IDY]], %[[C2:.*]]
-    // CHECK-DAG: %[[C2_2:.*]] = arith.constant 2 : index
-    // CHECK-DAG: %[[MULX:.*]] = index.mul %[[IDX]], %[[C2:.*]]
+    // CHECK-DAG: %[[SGIDX:.*]] = index.remu %[[SGID]], %{{.*}}
+    // CHECK-DAG: %[[SGIDY_TMP:.*]] = index.divu %[[SGID]], %{{.*}}
+    // CHECK-DAG: %[[SGIDY:.*]] = index.remu %[[SGIDY_TMP]], %{{.*}}
+    // CHECK-DAG: %[[MULY:.*]] = index.mul %[[SGIDY]], %[[C2:.*]]
+    // CHECK-DAG: %[[MULX:.*]] = index.mul %[[SGIDX]], %{{.*}}
     // CHECK-DAG: %[[REMU_Y:.*]] = index.remu %[[MULY]], %[[C8:.*]]
-    // CHECK-DAG: %[[C8_2:.*]] = arith.constant 8 : index
-    // CHECK-DAG: %[[REMU_X:.*]] = index.remu %[[MULX]], %[[C8:.*]]
-    // CHECK-DAG: %[[MUL5:.*]] = arith.muli %[[REMU_Y]], %[[C8:.*]] : index
+    // CHECK-DAG: %[[REMU_X:.*]] = index.remu %[[MULX]], %{{.*}}
+    // CHECK-DAG: %[[MUL5:.*]] = arith.muli %[[REMU_Y]], %{{.*}} : index
     // CHECK-DAG: %[[ADD:.*]] = arith.addi %[[C0:.*]], %[[MUL5]] : index
     // CHECK-DAG: %[[MUL6:.*]] = arith.muli %[[REMU_X]], %[[C16:.*]] : index
-    // CHECK-DAG: %[[ADDSTRIDES:.*]] = arith.addi  %[[ADD]], %[[MUL6]] : index
+    // CHECK-DAG: %[[ADDSTRIDES:.*]] = arith.addi %[[ADD]], %[[MUL6]] : index
     // CHECK-DAG: %[[BCAST:.*]] = vector.broadcast %[[ADDSTRIDES]] : index to vector<2x2xindex>
     // CHECK-DAG: %[[ADDCST:.*]] = arith.addi %[[BASECST]], %[[BCAST]] : vector<2x2xindex>
     %cst_8x8 = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2]>} dense<[
@@ -529,13 +559,14 @@ gpu.module @test_distribution {
   gpu.func @non_splat_constant() {
     // CHECK-DAG: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex>
     // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
-    // CHECK-DAG: %[[REMU:.*]] = index.remu %[[SGID]], %[[C32:.*]]
-    // CHECK-DAG: %[[MUL:.*]] = arith.muli %[[REMU]], %[[C16:.*]] : index
+    // CHECK-DAG: %[[REMU:.*]] = index.remu %[[SGID]], %{{.*}}
+    // CHECK-DAG: %[[REMU2:.*]] = index.remu %[[REMU]], %{{.*}}
+    // CHECK-DAG: %[[MUL:.*]] = arith.muli %[[REMU2]], %[[C16:.*]] : index
     // CHECK-DAG: %[[ADDSTRIDES:.*]] = arith.addi %[[C0:.*]], %[[MUL]] : index
     // CHECK-DAG: %[[BCAST:.*]] = vector.broadcast %[[ADDSTRIDES]] : index to vector<1xindex>
     // CHECK-DAG: %[[ADD:.*]] = arith.addi %[[CST]], %[[BCAST]] : vector<1xindex>
     %cst = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [1]>} dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240, 256, 272, 288, 304, 320, 336, 352, 368, 384, 400, 416, 432, 448, 464, 480, 496]> : vector<32xindex>
-    // CHECK: arith.constant dense<{{\[}}[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]{{\]}}> : vector<1x16xindex>
+    // CHECK: arith.constant dense<{{\[}}{{\[}}0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15{{\]}}{{\]}}> : vector<1x16xindex>
     %cst_1 = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32, 1], sg_data = [1, 16]>} dense<[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]]> : vector<1x16xindex>
     gpu.return
   }
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
index e83229e3a3995..d47f120129261 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
@@ -1,47 +1,35 @@
 // RUN: mlir-opt --xegpu-wg-to-sg-distribute -split-input-file %s | FileCheck %s
 
-//CHECK: #map = affine_map<()[s0] -> (s0 floordiv 4)>
-//CHECK: #map1 = affine_map<()[s0] -> (s0 mod 4)>
 gpu.module @test_1_1_assignment {
   // CHECK-LABEL: create_nd_tdesc
-  // CHECK-SAME: [[ARG_0:%.*]]: memref<256x128xf32>
+  // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
   gpu.func @create_nd_tdesc(%src: memref<256x128xf32>) {
-    //CHECK: [[SGID:%.+]] = gpu.subgroup_id : index
-    //CHECK: [[SGIDY:%.+]] = affine.apply #map()[[[SGID]]]
-    //CHECK: [[SGIDX:%.+]] = affine.apply #map1()[[[SGID]]]
-    //CHECK: [[C32:%.+]] = arith.constant 32 : index
-    //CHECK: [[LY:%.+]] = index.mul [[SGIDY]], [[C32]]
-    //CHECK: [[LX:%.+]] = index.mul [[SGIDX]], [[C32]]
-    //CHECK: [[C0:%.+]] = arith.constant 0 : index
-    //CHECK: [[C0_1:%.+]] = arith.constant 0 : index
-    //CHECK: [[C256:%.+]] = arith.constant 256 : index
-    //CHECK: [[Y:%.+]] = index.remu [[LY]], [[C256]]
-    //CHECK: [[C128:%.+]] = arith.constant 128 : index
-    //CHECK: [[X:%.+]] = index.remu [[LX]], [[C128]]
-    //CHECK: [[TDESC:%.+]] = xegpu.create_nd_tdesc [[ARG_0]][[[Y]], [[X]]] : memref<256x128xf32> -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
+    // CHECK-DAG: %[[REMUX:.*]] = index.remu %[[SGID]], %[[C4:.*]]
+    // CHECK-DAG: %[[DIVU:.*]] = index.divu %[[SGID]], %[[C4:.*]]
+    // CHECK-DAG: %[[REMUY:.*]] = index.remu %[[DIVU]], %[[C8:.*]]
+    // CHECK-DAG: %[[MULY:.*]] = index.mul %[[REMUY]], %[[C32:.*]]
+    // CHECK-DAG: %[[MULX:.*]] = index.mul %[[REMUX]], %[[C32:.*]]
+    // CHECK-DAG: %[[MODY:.*]] = index.remu %[[MULY]], %[[C256:.*]]
+    // CHECK-DAG: %[[MODX:.*]] = index.remu %[[MULX]], %[[C128:.*]]
+    // CHECK-DAG: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][%[[MODY]], %[[MODX]]] : memref<256x128xf32> -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
     %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
       -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
     gpu.return
   }
 
   // CHECK-LABEL: create_nd_tdesc_from_higher_rank_memref
-  // CHECK-SAME: [[ARG_0:%.*]]: memref<3x256x128xf32>
+  // CHECK-SAME: %[[ARG_0:.*]]: memref<3x256x128xf32>
   gpu.func @create_nd_tdesc_from_higher_rank_memref(%src: memref<3x256x128xf32>) {
-    //CHECK: [[SGID:%.+]] = gpu.subgroup_id : index
-    //CHECK: [[SGIDY:%.+]] = affine.apply #map()[[[SGID]]]
-    //CHECK: [[SGIDX:%.+]] = affine.apply #map1()[[[SGID]]]
-    //CHECK: [[C32:%.+]] = arith.constant 32 : index
-    //CHECK: [[LY:%.+]] = index.mul [[SGIDY]], [[C32]]
-    //CHECK: [[LX:%.+]] = index.mul [[SGIDX]], [[C32]]
-    //CHECK: [[C0:%.+]] = arith.constant 0 : index
-    //CHECK: [[C0_2:%.+]] = arith.constant 0 : index
-    //CHECK: [[C256:%.+]] = arith.constant 256 : index
-    //CHECK: [[MODY:%.+]] = index.remu [[LY]], [[C256]]
-    //CHECK: [[C128:%.+]] = arith.constant 128 : index
-    //CHECK: [[MODX:%.+]] = index.remu [[LX]], [[C128]]
-    //CHECK: [[C0_3:%.+]] = arith.constant 0 : index
-    //CHECK: [[C0_4:%.+]] = arith.constant 0 : index
-    //CHECK: [[TDESC:%.+]] = xegpu.create_nd_tdesc [[ARG_0]][1, [[MODY]], [[MODX]]] : memref<3x256x128xf32> -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
+    // CHECK-DAG: %[[REMUX:.*]] = index.remu %[[SGID]], %[[C4:.*]]
+    // CHECK-DAG: %[[DIVU:.*]] = index.divu %[[SGID]], %[[C4:.*]]
+    // CHECK-DAG: %[[REMUY:.*]] = index.remu %[[DIVU]], %[[C8:.*]]
+    // CHECK-DAG: %[[MULY:.*]] = index.mul %[[REMUY]], %[[C32:.*]]
+    // CHECK-DAG: %[[MULX:.*]] = index.mul %[[REMUX]], %[[C32:.*]]
+    // CHECK-DAG: %[[MODY:.*]] = index.remu %[[MULY]], %[[C256:.*]]
+    // CHECK-DAG: %[[MODX:.*]] = index.remu %[[MULX]], %[[C128:.*]]
+    // CHECK-DAG: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][1, %[[MODY]], %[[MODX]]] : memref<3x256x128xf32> -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
     %tdesc = xegpu.create_nd_tdesc %src[1, 0, 0] : memref<3x256x128xf32>
       -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
     gpu.return
@@ -50,9 +38,9 @@ gpu.module @test_1_1_assignment {
   // CHECK-LABEL: load_nd_tdesc
   // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
   gpu.func @load_nd_tdesc(%src: memref<256x128xf32>) {
-    // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<256x128xf32>
+    // CHECK-DAG: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<256x128xf32>
     // CHECK-SAME: -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
-    // CHECK: %[[LOAD:.*]] = xegpu.load_nd %[[TDESC]]
+    // CHECK-DAG: %[[LOAD:.*]] = xegpu.load_nd %[[TDESC]]
     // CHECK-SAME: : !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
     // CHECK-SAME: -> vector<32x32xf32>
     %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
@@ -66,12 +54,12 @@ gpu.module @test_1_1_assignment {
   // CHECK-LABEL: store_nd
   // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
   gpu.func @store_nd(%src: memref<256x128xf32>) {
-    // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<256x128xf32>
+    // CHECK-DAG: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<256x128xf32>
     // CHECK-SAME: -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
-    // CHECK: %[[LOAD:.*]] = xegpu.load_nd %[[TDESC]]
+    // CHECK-DAG: %[[LOAD:.*]] = xegpu.load_nd %[[TDESC]]
     // CHECK-SAME: : !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
     // CHECK-SAME: -> vector<32x32xf32>
-    // CHECK: xegpu.store_nd %[[LOAD]], %[[TDESC]]
+    // CHECK-DAG: xegpu.store_nd %[[LOAD]], %[[TDESC]]
     // CHECK-SAME: : vector<32x32xf32>, !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
     %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
       -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
@@ -81,25 +69,25 @@ gpu.module @test_1_1_assignment {
     xegpu.store_nd %load, %tdesc
       : vector<256x128xf32>, !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
     gpu.return
-}
+  }
 
-// CHECK-LABEL: update_nd
-// CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
-gpu.func @update_nd(%src: memref<256x128xf32>){
-  // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<256x128xf32>
-  // CHECK-SAME: -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
-  // CHECK: %[[UPDATE:.*]] = xegpu.update_nd_offset %[[TDESC]], [0, 16]
-  // CHECK-SAME: : !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
-  %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
-    -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
-  %update = xegpu.update_nd_offset %tdesc, [0, 16]
-    : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
-  gpu.return
-}
+  // CHECK-LABEL: update_nd
+  // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
+  gpu.func @update_nd(%src: memref<256x128xf32>){
+    // CHECK-DAG: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<256x128xf32>
+    // CHECK-SAME: -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    // CHECK-DAG: %[[UPDATE:.*]] = xegpu.update_nd_offset %[[TDESC]], [0, 16]
+    // CHECK-SAME: : !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
+      -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
+    %update = xegpu.update_nd_offset %tdesc, [0, 16]
+      : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
+    gpu.return
+  }
 
-// CHECK-LABEL: dpas
-gpu.func @dpas(%a: memref<128x128xf16>, %b: memref<128x128xf16>) {
-    // CHECK: %[[DPAS:.*]] = xegpu.dpas %{{.*}}, %{{.*}} {layout_result_0 =  #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16x128xf16>, vector<128x16xf16> -> vector<16x16xf32>
+  // CHECK-LABEL: dpas
+  gpu.func @dpas(%a: memref<128x128xf16>, %b: memref<128x128xf16>) {
+    // CHECK-DAG: %[[DPAS:.*]] = xegpu.dpas %{{.*}}, %{{.*}} {layout_result_0 =  #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16x128xf16>, vector<128x16xf16> -> vector<16x16xf32>
     %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<128x128xf16>
       -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128], lane_layout = [1, 16], lane_data = [1, 1]>>
     %load_a =  xegpu.load_nd %tdesc_a
@@ -116,10 +104,9 @@ gpu.func @dpas(%a: memref<128x128xf16>, %b: memref<128x128xf16>) {
     gpu.return
   }
 
-
-// CHECK-LABEL: dpas_no_sg_data
-gpu.func @dpas_no_sg_data(%a: memref<128x128xf16>, %b: memref<128x128xf16>) {
-    // CHECK: %[[DPAS:.*]] = xegpu.dpas %{{.*}}, %{{.*}} {layout_result_0 =  #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1], order = [1, 0]>} : vector<16x16xf16>, vector<16x16xf16> -> vector<16x16xf32>
+  // CHECK-LABEL: dpas_no_sg_data
+  gpu.func @dpas_no_sg_data(%a: memref<128x128xf16>, %b: memref<128x128xf16>) {
+    // CHECK-DAG: %[[DPAS:.*]] = xegpu.dpas %{{.*}}, %{{.*}} {layout_result_0 =  #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1], order = [1, 0]>} : vector<16x16xf16>, vector<16x16xf16> -> vector<16x16xf32>
     %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<128x128xf16>
       -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [1, 1],
       order = [1, 0]>>
@@ -143,9 +130,9 @@ gpu.func @dpas_no_sg_data(%a: memref<128x128xf16>, %b: memref<128x128xf16>) {
   // CHECK-LABEL: prefetch_nd_tdesc
   // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
   gpu.func @prefetch_nd_tdesc(%src: memref<256x128xf32>) {
-    // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<256x128xf32>
+    // CHECK-DAG: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<256x128xf32>
     // CHECK-SAME: -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
-    // CHECK: xegpu.prefetch_nd %[[TDESC]]
+    // CHECK-DAG: xegpu.prefetch_nd %[[TDESC]]
     // CHECK-SAME: : !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
     %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
       -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
@@ -166,13 +153,13 @@ gpu.func @dpas_no_sg_data(%a: memref<128x128xf16>, %b: memref<128x128xf16>) {
   // CHECK-LABEL: broadcast_dim1
   // CHECK-SAME: %[[ARG_0:.*]]: memref<256x1xf32>
   gpu.func @broadcast_dim1(%src: memref<256x1xf32>) {
+    // CHECK-DAG: vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [8, 1], lane_data = [1, 1]>}
+    // CHECK-SAME: : vector<32x1xf32> to vector<32x32xf32>
     %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x1xf32>
       -> !xegpu.tensor_desc<256x1xf32, #xegpu.layout<sg_layout = [8, 1], sg_data = [32, 1], lane_layout = [8, 1], lane_data = [1, 1]>>
     %load =  xegpu.load_nd %tdesc
       : !xegpu.tensor_desc<256x1xf32, #xegpu.layout<sg_layout = [8, 1], sg_data = [32, 1], lane_layout = [8, 1], lane_data = [1, 1]>>
       -> vector<256x1xf32>
-    // CHECK: vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [8, 1], lane_data = [1, 1]>}
-    // CHECK-SAME: : vector<32x1xf32> to vector<32x32xf32>
     %broadcast = vector.broadcast %load
       {layout_result_0 = #xegpu.layout<sg_layout = [8, 1], sg_data = [32, 32], lane_layout = [8, 1], lane_data = [1, 1]>}
       : vector<256x1xf32> to vector<256x32xf32>
@@ -182,13 +169,13 @@ gpu.func @dpas_no_sg_data(%a: memref<128x128xf16>, %b: memref<128x128xf16>) {
   // CHECK-LABEL: broadcast_dim0
   // CHECK-SAME: %[[ARG_0:.*]]: memref<1x128xf32>
   gpu.func @broadcast_dim0(%src: memref<1x128xf32>) {
+    // CHECK-DAG: vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+    // CHECK-SAME: : vector<1x32xf32> to vector<32x32xf32>
     %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<1x128xf32>
       -> !xegpu.tensor_desc<1x128xf32, #xegpu.layout<sg_layout = [1, 4], sg_data = [1, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
     %load =  xegpu.load_nd %tdesc
       : !xegpu.tensor_desc<1x128xf32, #xegpu.layout<sg_layout = [1, 4], sg_data = [1, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
       -> vector<1x128xf32>
-    // CHECK: vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
-    // CHECK-SAME: : vector<1x32xf32> to vector<32x32xf32>
     %broadcast = vector.broadcast %load
       {layout_result_0 = #xegpu.layout<sg_layout = [1, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>}
       : vector<1x128xf32> to vector<32x128xf32>
@@ -196,9 +183,9 @@ gpu.func @dpas_no_sg_data(%a: memref<128x128xf16>, %b: memref<128x128xf16>) {
   }
 
   gpu.func @scf_for(%arg0: memref<1024x1024xf16>, %arg1: memref<1024x1024xf16>, %arg2: memref<1024x1024xf32>) {
-    //CHECK: [[c0:%.+]] = arith.constant 0 : index
-    //CHECK: [[c128:%.+]] = arith.constant 128 : index
-    //CHECK: [[c1024:%.+]] = arith.constant 1024 : index
+    // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+    // CHECK-DAG: %[[C128:.*]] = arith.constant 128 : index
+    // CHECK-DAG: %[[C1024:.*]] = arith.constant 1024 : index
     %c0 = arith.constant 0 : index
     %c128 = arith.constant 128 : index
     %c1024 = arith.constant 1024 : index
@@ -211,15 +198,15 @@ gpu.func @dpas_no_sg_data(%a: memref<128x128xf16>, %b: memref<128x128xf16>) {
     %4 = xegpu.create_nd_tdesc %arg0[%0, %c0] : memref<1024x1024xf16> -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128]>>
     %5 = xegpu.create_nd_tdesc %arg1[%c0, %1] : memref<1024x1024xf16> -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16]>>
 
-    //      CHECK: [[scf:%.+]]:3 = scf.for [[arg3:%.+]] = [[c0]] to [[c1024]] step [[c128]]
-    // CHECK-SAME: iter_args([[arg4:%.+]] = {{.*}}, [[arg5:%.+]] = {{.*}}, [[arg6:%.+]] = {{.*}}) ->
+    // CHECK: %[[SCF:.*]]:3 = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C1024]] step %[[C128]]
+    // CHECK-SAME: iter_args(%[[ARG4:.*]] = {{.*}}, %[[ARG5:.*]] = {{.*}}, %[[ARG6:.*]] = {{.*}}) ->
     // CHECK-SAME: (!xegpu.tensor_desc<16x128xf16>, !xegpu.tensor_desc<128x16xf16>, vector<16x16xf32>)
-    //      CHECK: [[a:%.+]] = xegpu.load_nd [[arg4]] : !xegpu.tensor_desc<16x128xf16> -> vector<16x128xf16>
-    //      CHECK: [[b:%.+]] = xegpu.load_nd [[arg5]] : !xegpu.tensor_desc<128x16xf16> -> vector<128x16xf16>
-    //      CHECK: [[c:%.+]] = xegpu.dpas [[a]], [[b]], [[arg6]] : vector<16x128xf16>, vector<128x16xf16>, vector<16x16xf32> -> vector<16x16xf32>
-    //      CHECK: [[at:%.+]] = xegpu.update_nd_offset [[arg4]], [[[c0]], [[c128]]] : !xegpu.tensor_desc<16x128xf16>
-    //      CHECK: [[bt:%.+]] = xegpu.update_nd_offset [[arg5]], [[[c128]], [[c0]]] : !xegpu.tensor_desc<128x16xf16>
-    //      CHECK: scf.yield [[at]], [[bt]], [[c]] : !xegpu.tensor_desc<16x128xf16>, !xegpu.tensor_desc<128x16xf16>, vector<16x16xf32>
+    // CHECK: %[[A:.*]] = xegpu.load_nd %[[ARG4]] : !xegpu.tensor_desc<16x128xf16> -> vector<16x128xf16>
+    // CHECK: %[[B:.*]] = xegpu.load_nd %[[ARG5]] : !xegpu.tensor_desc<128x16xf16> -> vector<128x16xf16>
+    // CHECK: %[[C:.*]] = xegpu.dpas %[[A]], %[[B]], %[[ARG6]] : vector<16x128xf16>, vector<128x16xf16>, vector<16x16xf32> -> vector<16x16xf32>
+    // CHECK: %[[AT:.*]] = xegpu.update_nd_offset %[[ARG4]], [%[[C0]], %[[C128]]] : !xegpu.tensor_desc<16x128xf16>
+    // CHECK: %[[BT:.*]] = xegpu.update_nd_offset %[[ARG5]], [%[[C128]], %[[C0]]] : !xegpu.tensor_desc<128x16xf16>
+    // CHECK: scf.yield %[[AT]], %[[BT]], %[[C]] : !xegpu.tensor_desc<16x128xf16>, !xegpu.tensor_desc<128x16xf16>, vector<16x16xf32>
     %6:3 = scf.for %arg3 = %c0 to %c1024 step %c128 iter_args(%arg4 = %4, %arg5 = %5, %arg6 = %3)
         -> (!xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128]>>,
             !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16]>>, vector<128x128xf32>) {
@@ -252,7 +239,7 @@ gpu.func @dpas_no_sg_data(%a: memref<128x128xf16>, %b: memref<128x128xf16>) {
       // CHECK: scf.condition{{.*}} : vector<16xf32>, i32
       scf.condition(%4) %arg2, %arg3 : vector<256xf32>, i32
     } do {
-    // CHECK: ([[arg2:%.+]]: vector<16xf32>, [[arg3:%.+]]: i32)
+    // CHECK: (%[[ARG2:.*]]: vector<16xf32>, %[[ARG3:.*]]: i32)
     ^bb0(%arg2: vector<256xf32>, %arg3: i32):
       xegpu.store_nd %arg2, %2  : vector<256xf32>, !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [16], sg_data = [16]>>
       %4 = arith.addi %arg3, %c1_i32 : i32
@@ -344,9 +331,9 @@ gpu.func @dpas_no_sg_data(%a: memref<128x128xf16>, %b: memref<128x128xf16>) {
     %cond4 = arith.cmpi slt, %sg_id, %c31 : index
     %cond5 = arith.andi %cond3, %cond4 : i1
     scf.if %cond5 {
-      // CHECK: %[[SGID:.*]] = gpu.subgroup_id : index
-      // CHECK: %[[C2:.*]] = arith.constant 2 : index
-      // CHECK: %[[SUB:.*]] = index.sub %{{.*}}, %[[C2]]
+        // CHECK: %[[SGID:.*]] = gpu.subgroup_id : index
+        // CHECK: %[[C2:.*]] = arith.constant 2 : index
+        // CHECK: %[[SUB:.*]] = index.sub %{{.*}}, %[[C2]]
       %tdesc = xegpu.create_nd_tdesc %src2[0, 0] : memref<128x64xf32>
         -> !xegpu.tensor_desc<128x64xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [32, 16], lane_layout = [8, 4], lane_data = [1, 1]>>
       %load =  xegpu.load_nd %tdesc

>From fb29a2feabe16933190b9c3ea41dcb438fdcb102 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Wed, 22 Oct 2025 03:42:34 +0000
Subject: [PATCH 4/5] Add round-robin (1:N) vector.transpose test case

---
 .../Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir | 14 ++++++++++++++
 1 file changed, 14 insertions(+)

diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
index 9eb59fdacdb6d..63d1c3e3abb11 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
@@ -156,4 +156,18 @@ gpu.module @test_distribution {
     %cst_2 = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8, 1], sg_data = [2, 1]>} dense<[[0], [16], [32], [48], [64], [80], [96], [112], [128], [144], [160], [176], [192], [208], [224], [240], [256], [272], [288], [304], [320], [336], [352], [368], [384], [400], [416], [432], [448], [464], [480], [496]]> : vector<32x1xindex>
     gpu.return
   }
+
+  // CHECK-LABEL: vector_transpose
+  gpu.func @vector_transpose(%src: memref<256x128xf32>) {
+    %tdesc = xegpu.create_nd_tdesc %src : memref<256x128xf32>
+        -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 16], lane_layout = [1, 16], lane_data = [1, 1], order = [0, 1]>>
+    %load = xegpu.load_nd %tdesc[0, 0]
+        : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 16], lane_layout = [1, 16], lane_data = [1, 1], order = [0, 1]>>
+        -> vector<256x128xf32>
+    // CHECK-COUNT-2: vector.transpose {{.*}}, [1, 0] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1], order = [1, 0]>} : vector<32x16xf32> to vector<16x32xf32>
+    // CHECK-NOT: vector.transpose
+    %trans = vector.transpose %load, [1, 0] {layout_result_0 = #xegpu.layout<sg_layout = [4, 8], sg_data = [16, 32], lane_layout = [1, 16], lane_data = [1, 1], order = [1, 0]>} : vector<256x128xf32> to vector<128x256xf32>
+    gpu.return
+  }
 }
+

>From 205fdfd0266c9b71d7ca0532f0340177884ed7e1 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Wed, 22 Oct 2025 18:50:11 +0000
Subject: [PATCH 5/5] Clean up tests and simplify transpose layout validation

---
 mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp    |  1 -
 .../Transforms/XeGPUWgToSgDistribute.cpp      | 68 +++++++++----------
 .../Dialect/XeGPU/xegpu-attr-interface.mlir   |  2 -
 .../XeGPU/xegpu-wg-to-sg-elemwise.mlir        |  6 +-
 .../test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir | 36 +++-------
 .../XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir    | 57 +++-------------
 .../XeGPU/xegpu-wg-to-sg-unify-ops.mlir       | 64 +++++------------
 mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir   | 30 ++++----
 8 files changed, 83 insertions(+), 181 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index d3383658fc31d..dfd4093905875 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -288,7 +288,6 @@ LayoutAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc,
         llvm::reverse(llvm::seq<int64_t>(0, sgLayoutInt.size())));
   }
 
-  // Validate order
   if (order.size() != sgLayoutInt.size()) {
     return failure();
   }
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index a144b5e43615e..88d3ed743628e 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -1235,60 +1235,63 @@ struct WgToSgVectorTransposeOp
     if (!layout || !layout.isForWorkgroup())
       return failure();
 
-    // Get the source layout for validation
     xegpu::DistributeLayoutAttr sourceLayout =
         xegpu::getDistributeLayoutAttr(op.getVector());
     if (!sourceLayout || !sourceLayout.isForWorkgroup())
       return failure();
 
-    // Validate that result layout is transpose of source layout
     SmallVector<int64_t> sourceSgLayout =
         sourceLayout.getEffectiveSgLayoutAsInt();
     SmallVector<int64_t> sourceSgData = sourceLayout.getEffectiveSgDataAsInt();
     SmallVector<int64_t> resultSgLayout = layout.getEffectiveSgLayoutAsInt();
     SmallVector<int64_t> resultSgData = layout.getEffectiveSgDataAsInt();
+    DenseI32ArrayAttr sourceOrder = sourceLayout.getOrder();
+    DenseI32ArrayAttr resultOrder = layout.getOrder();
 
-    ArrayRef<int64_t> permutation = op.getPermutation();
-
-    // Check that sgLayout and sgData are properly transposed
-    if (sourceSgLayout.size() != resultSgLayout.size() ||
-        sourceSgData.size() != resultSgData.size() ||
-        sourceSgLayout.size() != permutation.size()) {
+    if (!sourceOrder || !resultOrder) {
       return rewriter.notifyMatchFailure(
-          op, "Source and result layouts must have same rank as permutation");
+          op, "Both source and result must have order attributes");
     }
 
-    // Validate sgLayout transpose
-    for (size_t i = 0; i < permutation.size(); ++i) {
-      int64_t srcDim = permutation[i];
-      if (srcDim < 0 || srcDim >= static_cast<int64_t>(sourceSgLayout.size())) {
-        return rewriter.notifyMatchFailure(op, "Invalid permutation index");
-      }
-      if (resultSgLayout[i] != sourceSgLayout[srcDim]) {
-        return rewriter.notifyMatchFailure(
-            op, "Result sgLayout is not transpose of source sgLayout according "
-                "to permutation");
-      }
+    SmallVector<int64_t> sourceOrderVec = llvm::to_vector(
+        llvm::map_range(sourceOrder.asArrayRef(),
+                        [](int32_t idx) { return static_cast<int64_t>(idx); }));
+    SmallVector<int64_t> resultOrderVec = llvm::to_vector(
+        llvm::map_range(resultOrder.asArrayRef(),
+                        [](int32_t idx) { return static_cast<int64_t>(idx); }));
+
+    ArrayRef<int64_t> permutation = op.getPermutation();
+    size_t expectedSize = permutation.size();
+    if (sourceSgLayout.size() != expectedSize ||
+        sourceSgData.size() != expectedSize ||
+        resultSgLayout.size() != expectedSize ||
+        resultSgData.size() != expectedSize ||
+        sourceOrderVec.size() != expectedSize ||
+        resultOrderVec.size() != expectedSize) {
+      return rewriter.notifyMatchFailure(
+          op, "All layouts and permutation must have the same rank");
     }
 
-    // Validate sgData transpose
+    // Check that sgLayout, sgData & order are properly transposed for operand
+    // and result
     for (size_t i = 0; i < permutation.size(); ++i) {
       int64_t srcDim = permutation[i];
-      if (resultSgData[i] != sourceSgData[srcDim]) {
+      if (resultSgLayout[i] != sourceSgLayout[srcDim] ||
+          resultSgData[i] != sourceSgData[srcDim] ||
+          resultOrderVec[i] != sourceOrderVec[srcDim]) {
         return rewriter.notifyMatchFailure(
-            op, "Result sgData is not transpose of source sgData according to "
-                "permutation");
+            op, "Result layout is not a valid transpose of source layout "
+                "according to permutation");
       }
     }
 
     SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
     VectorType newResultType =
         VectorType::get(sgShape, resultType.getElementType());
-
     SmallVector<Value> newTransposeOps;
     for (auto src : adaptor.getVector()) {
       auto newTranspose = vector::TransposeOp::create(
-          rewriter, op.getLoc(), newResultType, src, op.getPermutation());
+          rewriter, op.getLoc(), newResultType, src, permutation);
       if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
           !layout.getEffectiveInstDataAsInt().empty())
         xegpu::setDistributeLayoutAttr(newTranspose->getResult(0),
@@ -1444,7 +1447,8 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
       });
 
   target.addDynamicallyLegalOp<vector::ShapeCastOp, vector::StepOp,
-                               vector::TransposeOp, vector::BroadcastOp>(
+                               vector::TransposeOp, vector::BroadcastOp,
+                               vector::MultiDimReductionOp>(
       [=](Operation *op) -> bool {
         // Check for either a SliceAttr or LayoutAttr on the result.
         auto layout = xegpu::getDistributeLayoutAttr(op->getResult(0));
@@ -1463,16 +1467,6 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
         return isLegal(layout);
       });
 
-  target.addDynamicallyLegalOp<vector::BroadcastOp>(
-      [=](vector::BroadcastOp op) -> bool {
-        return isLegal(xegpu::getDistributeLayoutAttr(op.getResult()));
-      });
-
-  target.addDynamicallyLegalOp<vector::MultiDimReductionOp>(
-      [=](vector::MultiDimReductionOp op) -> bool {
-        return isLegal(xegpu::getDistributeLayoutAttr(op.getResult()));
-      });
-
   target.addDynamicallyLegalOp<xegpu::ConvertLayoutOp>(
       [=](xegpu::ConvertLayoutOp op) -> bool {
         return isLegal(op.getInputLayout()) && isLegal(op.getTargetLayout());
diff --git a/mlir/test/Dialect/XeGPU/xegpu-attr-interface.mlir b/mlir/test/Dialect/XeGPU/xegpu-attr-interface.mlir
index 7d785437948ea..02c5f71d5c83d 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-attr-interface.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-attr-interface.mlir
@@ -10,7 +10,6 @@ gpu.module @test {
     // CHECK-DAG: %[[BASE:.*]] = vector.step : vector<32xindex>
     // CHECK-DAG: %[[CAST:.*]] = vector.broadcast %[[MOD]] : index to vector<32xindex>
     // CHECK-DAG: %[[ADD:.*]] = arith.addi %[[BASE]], %[[CAST]] : vector<32xindex>
-    // CHECK-DAG: %[[RESULT:.*]] = builtin.unrealized_conversion_cast %[[ADD]] : vector<32xindex> to vector<128xindex>
     %step = vector.step {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [4, 8], sg_data = [32, 32]>, dims = [1]>}: vector<128xindex>
     gpu.return %step : vector<128xindex>
   }
@@ -25,7 +24,6 @@ gpu.module @test {
     // CHECK-DAG: %[[BASE:.*]] = vector.step : vector<32xindex>
     // CHECK-DAG: %[[CAST:.*]] = vector.broadcast %[[MOD]] : index to vector<32xindex>
     // CHECK-DAG: %[[ADD:.*]] = arith.addi %[[BASE]], %[[CAST]] : vector<32xindex>
-    // CHECK-DAG: %[[RESULT:.*]] = builtin.unrealized_conversion_cast %[[ADD]] : vector<32xindex> to vector<128xindex>
     %0 = vector.step {layout_result_0 = #xegpu.slice<#xegpu.slice<#xegpu.layout<sg_layout = [4, 8, 1], sg_data = [32, 32, 1]>, dims = [2]>, dims = [1]>} : vector<128xindex>
     gpu.return %0 : vector<128xindex>
   }
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-elemwise.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-elemwise.mlir
index 09df1e4da43e2..9580769d37313 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-elemwise.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-elemwise.mlir
@@ -166,14 +166,12 @@ gpu.module @test_elementwise_ops {
     %load_b = xegpu.load_nd %tdesc_b
       : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
       -> vector<24x32xf32>
-    // CHECK-COUNT-12: arith.negf {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>}
-    // CHECK-SAME-COUNT-12: : vector<2x2xf32>
+    // CHECK-COUNT-12: arith.negf {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>
     // CHECK-NOT: arith.negf
     %negf = arith.negf %load_a
       {layout_result_0 = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
       : vector<24x32xf32>
-    // CHECK-COUNT-12: math.powf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>}
-    // CHECK-SAME-COUNT-12: : vector<2x2xf32>
+    // CHECK-COUNT-12: math.powf {{.*}}, {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>} : vector<2x2xf32>
     // CHECK-NOT: math.powf
     %powf = math.powf %load_a, %load_b
       {layout_result_0 = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
index 965eb08bcb506..01134d8eaabec 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
@@ -4,14 +4,7 @@ gpu.module @test_round_robin_assignment {
   // CHECK-LABEL: create_nd_tdesc
   // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
   gpu.func @create_nd_tdesc(%src: memref<256x128xf32>) {
-      // CHECK: %[[SGID:.*]] = gpu.subgroup_id : index
-      // CHECK: %[[C4:.*]] = arith.constant 4 : index
-      // CHECK: %[[IDX:.*]] = index.remu %[[SGID]], %[[C4]]
-      // CHECK: %[[IDY_DIV:.*]] = index.divu %[[SGID]], %[[C4]]
-      // CHECK: %[[C8:.*]] = arith.constant 8 : index
-      // CHECK: %[[IDY:.*]] = index.remu %[[IDY_DIV]], %[[C8]]
-      // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_0]][%{{.*}}, %{{.*}}] : memref<256x128xf32>
-      // CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+      // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_0]][%{{.*}}, %{{.*}}] : memref<256x128xf32> -> !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
       // CHECK-NOT: xegpu.create_nd_tdesc
       %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
         -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
@@ -46,9 +39,7 @@ gpu.module @test_round_robin_assignment {
   gpu.func @load_nd_tdesc(%src: memref<256x128xf32>) {
       %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
         -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
-      // CHECK-COUNT-4: xegpu.load_nd %{{.*}}
-      // CHECK-SAME-COUNT-4: : !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
-      // CHECK-SAME-COUNT-4: -> vector<16x16xf32>
+      // CHECK-COUNT-4: xegpu.load_nd %{{.*}} : !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<16x16xf32>
       // CHECK-NOT: xegpu.load_nd
       %load =  xegpu.load_nd %tdesc
         : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
@@ -61,8 +52,7 @@ gpu.module @test_round_robin_assignment {
   gpu.func @store_nd(%src: memref<256x128xf32>) {
       %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
         -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
-      // CHECK-COUNT-4: xegpu.store_nd %{{.*}}, %{{.*}}
-      // CHECK-SAME-COUNT-4: : vector<16x16xf32>, !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+      // CHECK-COUNT-4: xegpu.store_nd %{{.*}}, %{{.*}} : vector<16x16xf32>, !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
       // CHECK-NOT: xegpu.store_nd
       %load = xegpu.load_nd %tdesc
         : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
@@ -77,8 +67,7 @@ gpu.module @test_round_robin_assignment {
   gpu.func @update_nd(%src: memref<256x128xf32>){
     %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
       ->  !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
-    // CHECK-COUNT-4: xegpu.update_nd_offset %{{.*}}, [0, 16]
-    // CHECK-SAME-COUNT-4: : !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    // CHECK-COUNT-4: xegpu.update_nd_offset %{{.*}}, [0, 16] : !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
     // CHECK-NOT: xegpu.update_nd_offset
     %update = xegpu.update_nd_offset %tdesc, [0, 16]
       : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
@@ -88,13 +77,9 @@ gpu.module @test_round_robin_assignment {
   // CHECK-LABEL: dpas
   // CHECK-SAME: (%[[ARG_0:.*]]: memref<256x128xf16>, %[[ARG_1:.*]]: memref<128x256xf16>)
   gpu.func @dpas(%a: memref<256x128xf16>, %b: memref<128x256xf16>) {
-    // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_0]][%{{.*}}, %{{.*}}] : memref<256x128xf16>
-    // CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
-    // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_1]][%{{.*}}, %{{.*}}] : memref<128x256xf16>
-    // CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>>
-    // CHECK-COUNT-16: xegpu.dpas %{{.*}}, %{{.*}}
-    // CHECK-SAME-COUNT-16: {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
-    // CHECK-SAME-COUNT-16: : vector<16x16xf16>, vector<16x16xf16> -> vector<16x16xf32>
+    // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_0]][%{{.*}}, %{{.*}}] : memref<256x128xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_1]][%{{.*}}, %{{.*}}] : memref<128x256xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>>
+    // CHECK-COUNT-16: xegpu.dpas %{{.*}}, %{{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16x16xf16>, vector<16x16xf16> -> vector<16x16xf32>
     // CHECK-NOT: xegpu.dpas
     %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<256x128xf16>
       -> !xegpu.tensor_desc<256x128xf16, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
@@ -115,8 +100,7 @@ gpu.module @test_round_robin_assignment {
   // CHECK-LABEL: prefetch_nd_tdesc
   // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
   gpu.func @prefetch_nd_tdesc(%src: memref<256x128xf32>) {
-    // CHECK-COUNT-4: xegpu.prefetch_nd %{{.*}}
-    // CHECK-SAME-COUNT-4: !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    // CHECK-COUNT-4: xegpu.prefetch_nd %{{.*}} : !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
     // CHECK-NOT: xegpu.prefetch_nd
     %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
       -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
@@ -133,9 +117,7 @@ gpu.module @test_round_robin_assignment {
     %load =  xegpu.load_nd %tdesc
       : !xegpu.tensor_desc<128x1xf32, #xegpu.layout<sg_layout = [4, 1], sg_data = [16, 1], lane_layout = [8, 1], lane_data = [1, 1]>>
       -> vector<128x1xf32>
-    // CHECK-COUNT-2: vector.broadcast {{.*}}
-    // CHECK-SAME-COUNT-2: {layout_result_0 = #xegpu.layout<lane_layout = [8, 1], lane_data = [1, 1]>}
-    // CHECK-SAME-COUNT-2: : vector<16x1xf32> to vector<16x32xf32>
+    // CHECK-COUNT-2: vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [8, 1], lane_data = [1, 1]>} : vector<16x1xf32> to vector<16x32xf32>
     // CHECK-NOT: vector.broadcast
     %broadcast = vector.broadcast %load
       {layout_result_0 = #xegpu.layout<sg_layout = [4, 1], sg_data = [16, 32], lane_layout = [8, 1], lane_data = [1, 1]>}
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
index 63d1c3e3abb11..18a21146a3d75 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
@@ -14,22 +14,11 @@ gpu.module @test_distribution {
 
   // CHECK-LABEL: load_nd_tdesc_with_offset
   gpu.func @load_nd_tdesc_with_offset(%src: memref<256x128xf32>) {
-    // CHECK-COUNT-4: xegpu.create_nd_tdesc {{%.*}} : memref<256x128xf32>
-    // CHECK-SAME: -> !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
-    // CHECK: %[[SGID:.*]] = gpu.subgroup_id : index
-    // CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
-    // CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
-    // CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index
-    // CHECK-DAG: %[[C64:.*]] = arith.constant 64 : index
-    // CHECK-DAG: %[[C128:.*]] = arith.constant 128 : index
-    // CHECK-DAG: %[[C256:.*]] = arith.constant 256 : index
-    // CHECK-COUNT-4: xegpu.load_nd {{%.*}}[{{%.*}}, {{%.*}}]
-    // CHECK-SAME-COUNT-4: : !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
-    // CHECK-SAME-COUNT-4: -> vector<16x16xf32>
+    // CHECK-COUNT-4: xegpu.load_nd {{%.*}}[{{%.*}}, {{%.*}}] : !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<16x16xf32>
     // CHECK-NOT: xegpu.load_nd
     %tdesc = xegpu.create_nd_tdesc %src: memref<256x128xf32>
       -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
-    %load =  xegpu.load_nd %tdesc[0, 0]
+    %load = xegpu.load_nd %tdesc[0, 0]
       : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
       -> vector<256x128xf32>
     gpu.return
@@ -37,15 +26,7 @@ gpu.module @test_distribution {
 
   // CHECK-LABEL: store_nd_with_offset
   gpu.func @store_nd_with_offset(%src: memref<256x128xf32>) {
-    // CHECK-COUNT-4: xegpu.create_nd_tdesc {{%.*}} : memref<256x128xf32>
-    // CHECK-SAME: -> !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
-    // CHECK: %[[SGID:.*]] = gpu.subgroup_id : index
-    // CHECK-COUNT-4: xegpu.load_nd {{%.*}}[{{%.*}}, {{%.*}}]
-    // CHECK-SAME-COUNT-4: : !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
-    // CHECK-SAME-COUNT-4: -> vector<16x16xf32>
-    // CHECK: %[[SGID2:.*]] = gpu.subgroup_id : index
-    // CHECK-COUNT-4: xegpu.store_nd %{{.*}}, {{%.*}}[{{%.*}}, {{%.*}}]
-    // CHECK-SAME-COUNT-4: : vector<16x16xf32>, !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    // CHECK-COUNT-4: xegpu.store_nd %{{.*}}, {{%.*}}[{{%.*}}, {{%.*}}] : vector<16x16xf32>, !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
     // CHECK-NOT: xegpu.store_nd
     %tdesc = xegpu.create_nd_tdesc %src: memref<256x128xf32>
       -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
@@ -58,13 +39,8 @@ gpu.module @test_distribution {
   }
 
   // CHECK-LABEL: prefetch_nd_tdesc_with_offset
-  // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
   gpu.func @prefetch_nd_tdesc_with_offset(%src: memref<256x128xf32>) {
-    // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_0]] : memref<256x128xf32>
-    // CHECK-SAME: -> !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
-    // CHECK: %[[SGID:.*]] = gpu.subgroup_id : index
-    // CHECK-COUNT-4: xegpu.prefetch_nd {{%.*}}[{{%.*}}, {{%.*}}]
-    // CHECK-SAME-COUNT-4: !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    // CHECK-COUNT-4: xegpu.prefetch_nd {{%.*}}[{{%.*}}, {{%.*}}] : !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
     // CHECK-NOT: xegpu.prefetch_nd
     %tdesc = xegpu.create_nd_tdesc %src : memref<256x128xf32>
       -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
@@ -76,21 +52,11 @@ gpu.module @test_distribution {
   // CHECK-LABEL: dpas
   // CHECK-SAME: (%[[ARG_0:.*]]: memref<256x128xf16>, %[[ARG_1:.*]]: memref<128x256xf16>)
   gpu.func @dpas(%a: memref<256x128xf16>, %b: memref<128x256xf16>) {
-    // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_0]] : memref<256x128xf16>
-    // CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
-    // CHECK: %[[SGID:.*]] = gpu.subgroup_id : index
-    // CHECK-COUNT-4: xegpu.load_nd {{%.*}}[{{%.*}}, {{%.*}}]
-    // CHECK-SAME-COUNT-4: : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
-    // CHECK-SAME-COUNT-4: -> vector<16x16xf16>
-    // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_1]] : memref<128x256xf16>
-    // CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>>
-    // CHECK: %[[SGID2:.*]] = gpu.subgroup_id : index
-    // CHECK-COUNT-4: xegpu.load_nd {{%.*}}[{{%.*}}, {{%.*}}]
-    // CHECK-SAME-COUNT-4: : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>>
-    // CHECK-SAME-COUNT-4: -> vector<16x16xf16>
-    // CHECK-COUNT-16: xegpu.dpas %{{.*}}, %{{.*}}
-    // CHECK-SAME-COUNT-16: {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
-    // CHECK-SAME-COUNT-16: : vector<16x16xf16>, vector<16x16xf16> -> vector<16x16xf32>
+    // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_0]] : memref<256x128xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    // CHECK-COUNT-4: xegpu.load_nd {{%.*}}[{{%.*}}, {{%.*}}] : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<16x16xf16>
+    // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_1]] : memref<128x256xf16> -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>>
+    // CHECK-COUNT-4: xegpu.load_nd {{%.*}}[{{%.*}}, {{%.*}}] : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>> -> vector<16x16xf16>
+    // CHECK-COUNT-16: xegpu.dpas %{{.*}}, %{{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16x16xf16>, vector<16x16xf16> -> vector<16x16xf32>
     // CHECK-NOT: xegpu.dpas
     %tdesc_a = xegpu.create_nd_tdesc %a : memref<256x128xf16>
       -> !xegpu.tensor_desc<256x128xf16, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
@@ -117,11 +83,6 @@ gpu.module @test_distribution {
     %load =  xegpu.load_nd %tdesc[0, 0]
       : !xegpu.tensor_desc<256x64xf32, #xegpu.layout<sg_layout = [8, 1], sg_data = [16, 64]>>
       -> vector<256x64xf32>
-    // CHECK-COUNT-2: xegpu.create_nd_tdesc {{%.*}} : memref<256x64xf32>
-    // CHECK-SAME: -> !xegpu.tensor_desc<16x64xf32>
-    // CHECK: %[[SGID:.*]] = gpu.subgroup_id : index
-    // CHECK-COUNT-2: xegpu.load_nd {{%.*}}[{{%.*}}, {{%.*}}]
-    // CHECK-SAME-COUNT-2: : !xegpu.tensor_desc<16x64xf32> -> vector<16x64xf32>
     // CHECK-COUNT-2: vector.multi_reduction <add>, {{.*}}, %[[CST]] [1] : vector<16x64xf32> to vector<16xf32>
     // CHECK-NOT: vector.multi_reduction
     %reduce = vector.multi_reduction <add>, %load, %cst {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 1], sg_data = [16, 64]>, dims = [1]>} [1]
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
index 5b67b79fab8f3..ad02f0c9cb4e7 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -23,24 +23,23 @@ gpu.module @test_distribution {
   }
 
   // CHECK-LABEL: load_nd_tdesc_with_offset
-  // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
   gpu.func @load_nd_tdesc_with_offset(%src: memref<256x128xf32>) {
-    //CHECK: [[SGID:%.+]] = gpu.subgroup_id : index
-    //CHECK: [[C4:%.+]] = arith.constant 4 : index
-    //CHECK: [[SGIDX:%.+]] = index.remu [[SGID]], [[C4]]
-    //CHECK: [[SGIDY_TMP:%.+]] = index.divu [[SGID]], [[C4]]
-    //CHECK: [[C8:%.+]] = arith.constant 8 : index
-    //CHECK: [[SGIDY:%.+]] = index.remu [[SGIDY_TMP]], [[C8]]
-    //CHECK: [[C32:%.+]] = arith.constant 32 : index
-    //CHECK: [[LOCALY:%.+]] = index.mul [[SGIDY]], [[C32]]
-    //CHECK: [[C32_0:%.+]] = arith.constant 32 : index
-    //CHECK: [[LOCALX:%.+]] = index.mul [[SGIDX]], [[C32_0]]
-    //CHECK: [[C256:%.+]] = arith.constant 256 : index
-    //CHECK: [[OFFY:%.+]] = index.remu [[LOCALY]], [[C256]]
-    //CHECK: [[C128:%.+]] = arith.constant 128 : index
-    //CHECK: [[OFFX:%.+]] = index.remu [[LOCALX]], [[C128]]
-    //CHECK: %[[LOAD:.*]] = xegpu.load_nd {{%.*}}[[[OFFY]], [[OFFX]]] : !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<32x32xf32>
-    %tdesc = xegpu.create_nd_tdesc %src: memref<256x128xf32>
+    //CHECK: %[[SGID:.*]] = gpu.subgroup_id : index
+    //CHECK: %[[C4:.*]] = arith.constant 4 : index
+    //CHECK: %[[SGIDX:.*]] = index.remu %[[SGID]], %[[C4]]
+    //CHECK: %[[SGIDY_TMP:.*]] = index.divu %[[SGID]], %[[C4]]
+    //CHECK: %[[C8:.*]] = arith.constant 8 : index
+    //CHECK: %[[SGIDY:.*]] = index.remu %[[SGIDY_TMP]], %[[C8]]
+    //CHECK: %[[C32:.*]] = arith.constant 32 : index
+    //CHECK: %[[L_OFF_Y:.*]] = index.mul %[[SGIDY]], %[[C32]]
+    //CHECK: %[[L_OFF_X:.*]] = index.mul %[[SGIDX]], %[[C32]]
+    //CHECK: %[[C256:.*]] = arith.constant 256 : index
+    //CHECK: %[[OFF_Y:.*]] = index.remu %[[L_OFF_Y]], %[[C256]]
+    //CHECK: %[[C128:.*]] = arith.constant 128 : index
+    //CHECK: %[[OFF_X:.*]] = index.remu %[[L_OFF_X]], %[[C128]]
+    //CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %{{.*}}[%[[OFF_Y]], %[[OFF_X]]] : memref<256x128xf32> -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    //CHECK: %[[LOAD:.*]] = xegpu.load_nd %[[TDESC]][{{%.*}}, {{%.*}}] : !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<32x32xf32>
+    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
       -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
     %load =  xegpu.load_nd %tdesc[0, 0]
       : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
@@ -51,21 +50,6 @@ gpu.module @test_distribution {
   // CHECK-LABEL: store_nd_with_offsets
   // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
   gpu.func @store_nd_with_offsets(%src: memref<256x128xf32>) {
-    //CHECK: [[SGID:%.+]] = gpu.subgroup_id : index
-    //CHECK: [[C4:%.+]] = arith.constant 4 : index
-    //CHECK: [[SGIDX:%.+]] = index.remu [[SGID]], [[C4]]
-    //CHECK: [[SGIDY_TMP:%.+]] = index.divu [[SGID]], [[C4]]
-    //CHECK: [[C8:%.+]] = arith.constant 8 : index
-    //CHECK: [[SGIDY:%.+]] = index.remu [[SGIDY_TMP]], [[C8]]
-    //CHECK: [[C32:%.+]] = arith.constant 32 : index
-    //CHECK: [[LOCALY:%.+]] = index.mul [[SGIDY]], [[C32]]
-    //CHECK: [[C32_0:%.+]] = arith.constant 32 : index
-    //CHECK: [[LOCALX:%.+]] = index.mul [[SGIDX]], [[C32_0]]
-    //CHECK: [[C256:%.+]] = arith.constant 256 : index
-    //CHECK: [[OFFY:%.+]] = index.remu [[LOCALY]], [[C256]]
-    //CHECK: [[C128:%.+]] = arith.constant 128 : index
-    //CHECK: [[OFFX:%.+]] = index.remu [[LOCALX]], [[C128]]
-    //CHECK: %[[LOAD:.*]] = xegpu.load_nd {{%.*}}[[[OFFY]], [[OFFX]]] : !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<32x32xf32>
     //CHECK: xegpu.store_nd %{{.*}}, {{%.*}}[{{%.*}}, {{%.*}}]  : vector<32x32xf32>, !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
     %tdesc = xegpu.create_nd_tdesc %src: memref<256x128xf32>
       -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
@@ -80,21 +64,7 @@ gpu.module @test_distribution {
   // CHECK-LABEL: prefetch_nd_tdesc_with_offset
   // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
   gpu.func @prefetch_nd_tdesc_with_offset(%src: memref<256x128xf32>) {
-    //CHECK: [[SGID:%.+]] = gpu.subgroup_id : index
-    //CHECK: [[C4:%.+]] = arith.constant 4 : index
-    //CHECK: [[SGIDX:%.+]] = index.remu [[SGID]], [[C4]]
-    //CHECK: [[SGIDY_TMP:%.+]] = index.divu [[SGID]], [[C4]]
-    //CHECK: [[C8:%.+]] = arith.constant 8 : index
-    //CHECK: [[SGIDY:%.+]] = index.remu [[SGIDY_TMP]], [[C8]]
-    //CHECK: [[C32:%.+]] = arith.constant 32 : index
-    //CHECK: [[LOCALY:%.+]] = index.mul [[SGIDY]], [[C32]]
-    //CHECK: [[C32_0:%.+]] = arith.constant 32 : index
-    //CHECK: [[LOCALX:%.+]] = index.mul [[SGIDX]], [[C32_0]]
-    //CHECK: [[C256:%.+]] = arith.constant 256 : index
-    //CHECK: [[OFFY:%.+]] = index.remu [[LOCALY]], [[C256]]
-    //CHECK: [[C128:%.+]] = arith.constant 128 : index
-    //CHECK: [[OFFX:%.+]] = index.remu [[LOCALX]], [[C128]]
-    //CHECK: xegpu.prefetch_nd %{{.*}}[[[OFFY]], [[OFFX]]] : !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    //CHECK: xegpu.prefetch_nd %{{.*}}[{{%.*}}, {{%.*}}] : !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
     %cst0 = arith.constant 0 : index
     %tdesc = xegpu.create_nd_tdesc %src : memref<256x128xf32>
       -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
index d47f120129261..5ce3d1d0fb5d6 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
@@ -38,9 +38,9 @@ gpu.module @test_1_1_assignment {
   // CHECK-LABEL: load_nd_tdesc
   // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
   gpu.func @load_nd_tdesc(%src: memref<256x128xf32>) {
-    // CHECK-DAG: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<256x128xf32>
+    // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<256x128xf32>
     // CHECK-SAME: -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
-    // CHECK-DAG: %[[LOAD:.*]] = xegpu.load_nd %[[TDESC]]
+    // CHECK: %[[LOAD:.*]] = xegpu.load_nd %[[TDESC]]
     // CHECK-SAME: : !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
     // CHECK-SAME: -> vector<32x32xf32>
     %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
@@ -54,12 +54,12 @@ gpu.module @test_1_1_assignment {
   // CHECK-LABEL: store_nd
   // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
   gpu.func @store_nd(%src: memref<256x128xf32>) {
-    // CHECK-DAG: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<256x128xf32>
+    // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<256x128xf32>
     // CHECK-SAME: -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
-    // CHECK-DAG: %[[LOAD:.*]] = xegpu.load_nd %[[TDESC]]
+    // CHECK: %[[LOAD:.*]] = xegpu.load_nd %[[TDESC]]
     // CHECK-SAME: : !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
     // CHECK-SAME: -> vector<32x32xf32>
-    // CHECK-DAG: xegpu.store_nd %[[LOAD]], %[[TDESC]]
+    // CHECK: xegpu.store_nd %[[LOAD]], %[[TDESC]]
     // CHECK-SAME: : vector<32x32xf32>, !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
     %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
       -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
@@ -74,9 +74,9 @@ gpu.module @test_1_1_assignment {
   // CHECK-LABEL: update_nd
   // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
   gpu.func @update_nd(%src: memref<256x128xf32>){
-    // CHECK-DAG: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<256x128xf32>
+    // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<256x128xf32>
     // CHECK-SAME: -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
-    // CHECK-DAG: %[[UPDATE:.*]] = xegpu.update_nd_offset %[[TDESC]], [0, 16]
+    // CHECK: %[[UPDATE:.*]] = xegpu.update_nd_offset %[[TDESC]], [0, 16]
     // CHECK-SAME: : !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
     %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
       -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
@@ -87,7 +87,6 @@ gpu.module @test_1_1_assignment {
 
   // CHECK-LABEL: dpas
   gpu.func @dpas(%a: memref<128x128xf16>, %b: memref<128x128xf16>) {
-    // CHECK-DAG: %[[DPAS:.*]] = xegpu.dpas %{{.*}}, %{{.*}} {layout_result_0 =  #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16x128xf16>, vector<128x16xf16> -> vector<16x16xf32>
     %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<128x128xf16>
       -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128], lane_layout = [1, 16], lane_data = [1, 1]>>
     %load_a =  xegpu.load_nd %tdesc_a
@@ -98,6 +97,7 @@ gpu.module @test_1_1_assignment {
     %load_b =  xegpu.load_nd %tdesc_b
       : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16], lane_layout = [1, 16], lane_data = [2, 1]>>
       -> vector<128x128xf16>
+    // CHECK: %[[DPAS:.*]] = xegpu.dpas %{{.*}}, %{{.*}} {layout_result_0 =  #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16x128xf16>, vector<128x16xf16> -> vector<16x16xf32>
     %dpas = xegpu.dpas %load_a, %load_b
       {layout_result_0 =  #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>}
       : vector<128x128xf16>, vector<128x128xf16> -> vector<128x128xf32>
@@ -106,7 +106,6 @@ gpu.module @test_1_1_assignment {
 
   // CHECK-LABEL: dpas_no_sg_data
   gpu.func @dpas_no_sg_data(%a: memref<128x128xf16>, %b: memref<128x128xf16>) {
-    // CHECK-DAG: %[[DPAS:.*]] = xegpu.dpas %{{.*}}, %{{.*}} {layout_result_0 =  #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1], order = [1, 0]>} : vector<16x16xf16>, vector<16x16xf16> -> vector<16x16xf32>
     %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<128x128xf16>
       -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [1, 1],
       order = [1, 0]>>
@@ -121,6 +120,7 @@ gpu.module @test_1_1_assignment {
       : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [2, 1],
       order = [1, 0]>>
       -> vector<128x128xf16>
+    // CHECK: %[[DPAS:.*]] = xegpu.dpas %{{.*}}, %{{.*}} {layout_result_0 =  #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1], order = [1, 0]>} : vector<16x16xf16>, vector<16x16xf16> -> vector<16x16xf32>
     %dpas = xegpu.dpas %load_a, %load_b
       {layout_result_0 =  #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [1, 1], order = [1, 0]>}
       : vector<128x128xf16>, vector<128x128xf16> -> vector<128x128xf32>
@@ -130,9 +130,9 @@ gpu.module @test_1_1_assignment {
   // CHECK-LABEL: prefetch_nd_tdesc
   // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
   gpu.func @prefetch_nd_tdesc(%src: memref<256x128xf32>) {
-    // CHECK-DAG: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<256x128xf32>
+    // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<256x128xf32>
     // CHECK-SAME: -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
-    // CHECK-DAG: xegpu.prefetch_nd %[[TDESC]]
+    // CHECK: xegpu.prefetch_nd %[[TDESC]]
     // CHECK-SAME: : !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
     %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
       -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
@@ -153,13 +153,13 @@ gpu.module @test_1_1_assignment {
   // CHECK-LABEL: broadcast_dim1
   // CHECK-SAME: %[[ARG_0:.*]]: memref<256x1xf32>
   gpu.func @broadcast_dim1(%src: memref<256x1xf32>) {
-    // CHECK-DAG: vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [8, 1], lane_data = [1, 1]>}
-    // CHECK-SAME: : vector<32x1xf32> to vector<32x32xf32>
     %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x1xf32>
       -> !xegpu.tensor_desc<256x1xf32, #xegpu.layout<sg_layout = [8, 1], sg_data = [32, 1], lane_layout = [8, 1], lane_data = [1, 1]>>
     %load =  xegpu.load_nd %tdesc
       : !xegpu.tensor_desc<256x1xf32, #xegpu.layout<sg_layout = [8, 1], sg_data = [32, 1], lane_layout = [8, 1], lane_data = [1, 1]>>
       -> vector<256x1xf32>
+    // CHECK: vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [8, 1], lane_data = [1, 1]>}
+    // CHECK-SAME: : vector<32x1xf32> to vector<32x32xf32>
     %broadcast = vector.broadcast %load
       {layout_result_0 = #xegpu.layout<sg_layout = [8, 1], sg_data = [32, 32], lane_layout = [8, 1], lane_data = [1, 1]>}
       : vector<256x1xf32> to vector<256x32xf32>
@@ -169,13 +169,13 @@ gpu.module @test_1_1_assignment {
   // CHECK-LABEL: broadcast_dim0
   // CHECK-SAME: %[[ARG_0:.*]]: memref<1x128xf32>
   gpu.func @broadcast_dim0(%src: memref<1x128xf32>) {
-    // CHECK-DAG: vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
-    // CHECK-SAME: : vector<1x32xf32> to vector<32x32xf32>
     %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<1x128xf32>
       -> !xegpu.tensor_desc<1x128xf32, #xegpu.layout<sg_layout = [1, 4], sg_data = [1, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
     %load =  xegpu.load_nd %tdesc
       : !xegpu.tensor_desc<1x128xf32, #xegpu.layout<sg_layout = [1, 4], sg_data = [1, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
       -> vector<1x128xf32>
+    // CHECK: vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+    // CHECK-SAME: : vector<1x32xf32> to vector<32x32xf32>
     %broadcast = vector.broadcast %load
       {layout_result_0 = #xegpu.layout<sg_layout = [1, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>}
       : vector<1x128xf32> to vector<32x128xf32>



More information about the Mlir-commits mailing list