[Mlir-commits] [mlir] [MLIR][XeGPU] Remove the transpose attribute from Gather/Scatter ops and clean up the documents (PR #145389)

Chao Chen llvmlistbot at llvm.org
Tue Jun 24 11:25:23 PDT 2025


https://github.com/chencha3 updated https://github.com/llvm/llvm-project/pull/145389

>From 010f5079b089ca7260a83b42f471d757e1b4bf70 Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Mon, 23 Jun 2025 18:54:53 +0000
Subject: [PATCH 1/6] drop transpose attribute from load_gather and
 store_scatter

---
 .../include/mlir/Dialect/XeGPU/IR/XeGPUOps.td | 34 ++++------
 .../mlir/Dialect/XeGPU/Utils/XeGPUUtils.h     |  2 +
 mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp        | 28 +-------
 .../XeGPU/Transforms/XeGPUPropagateLayout.cpp | 64 ++++++++++---------
 .../Dialect/XeGPU/Transforms/XeGPUUnroll.cpp  | 11 +---
 mlir/test/Dialect/XeGPU/ops.mlir              | 36 ++++-------
 mlir/test/Dialect/XeGPU/propagate-layout.mlir | 45 ++++++-------
 mlir/test/Dialect/XeGPU/xegpu-blocking.mlir   | 46 ++++++-------
 .../Dialect/XeGPU/xegpu-unroll-patterns.mlir  | 56 ++++++++--------
 9 files changed, 137 insertions(+), 185 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index e6c7efc47593f..ffc08e9b90b56 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -609,12 +609,8 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
   let description = [{ It (aka. load) load data per each work-item. The output
     describes the data being loaded at the subgroup level, so its size is
     consistent with the number of work-items in a subgroup. When the chunk size
-    is larger than 2, the output vector is a 2D vector, with dim-1 correspoding
-    to work-items, and dim-0 corresponding to the chunk size loaded by each work-item.
-    Specially, there is a transpose effect on the result (as compared to the TensorDesc)
-    due to the hardware implementation. Therefore, a transpose attribute is introduced
-    on purpose, making sure users are aware of this implicit transformation.
-
+    is larger than 2, the output vector is a 2D vector, with dim-0 corresponding
+    to work-items, and dim-1 corresponding to the chunk size loaded by each work-item.
     The mask operand masks out memory access so that it is safe to pass out-of-boundary
     addresses/offsets as long as they are masked. It applies to slots of SIMD lanes.
 
@@ -634,8 +630,7 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
 
   Example 2:
   ```mlir
-    %2 = xegpu.load %1, %0 {transpose,
-                            l1_hint = #xegpu.cache_hint<cached>,
+    %2 = xegpu.load %1, %0 {l1_hint = #xegpu.cache_hint<cached>,
                             l2_hint = #xegpu.cache_hint<uncached>,
                             l3_hint = #xegpu.cache_hint<uncached>}
           : !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<memory_space=global, chunk_size=8>>,
@@ -643,20 +638,18 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
   ```
   Example 3 (SIMT mode):
   ```mlir
-    %2 = xegpu.load %1, %0 {transpose,
-                            l1_hint = #xegpu.cache_hint<cached>,
+    %2 = xegpu.load %1, %0 {l1_hint = #xegpu.cache_hint<cached>,
                             l2_hint = #xegpu.cache_hint<uncached>,
                             l3_hint = #xegpu.cache_hint<uncached>}
           : !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<memory_space=global, chunk_size=8>,
             !xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>>
-            vector<16xi1> -> vector<8x1xf32>
+            vector<16xi1> -> vector<8xf32>
   ```
 
   }];
 
   let arguments = (ins XeGPU_TensorDesc: $TensorDesc,
                        XeGPU_MaskType: $mask,
-                       OptionalAttr<UnitAttr>: $transpose,
                        OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
                        OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
                        OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint);
@@ -714,19 +707,17 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [
 
   Example 2:
   ```mlir
-    xegpu.store %0, %1, %2 {transpose,
-                                 l1_hint = #xegpu.cache_hint<uncached>,
-                                 l2_hint = #xegpu.cache_hint<write_back>,
-                                 l3_hint = #xegpu.cache_hint<write_through>}
+    xegpu.store %0, %1, %2 {l1_hint = #xegpu.cache_hint<uncached>,
+                            l2_hint = #xegpu.cache_hint<write_back>,
+                            l3_hint = #xegpu.cache_hint<write_through>}
           : vector<8x16xf32>, !xegpu.tensor_desc<16x8xf32, #xegpu.scattered_tdesc_attr<chunk_size=8>>, vector<16xi1>
   ```
   Example 3 (SIMT mode):
   ```mlir
-    xegpu.store %0, %1, %2 {transpose,
-                                 l1_hint = #xegpu.cache_hint<uncached>,
-                                 l2_hint = #xegpu.cache_hint<write_back>,
-                                 l3_hint = #xegpu.cache_hint<write_through>}
-          : vector<8x1xf32>, !xegpu.tensor_desc<16x8xf32, #xegpu.scattered_tdesc_attr<chunk_size=8>,
+    xegpu.store %0, %1, %2 {l1_hint = #xegpu.cache_hint<uncached>,
+                            l2_hint = #xegpu.cache_hint<write_back>,
+                            l3_hint = #xegpu.cache_hint<write_through>}
+          : vector<8xf32>, !xegpu.tensor_desc<16x8xf32, #xegpu.scattered_tdesc_attr<chunk_size=8>,
             !xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>> vector<16xi1>
   ```
 
@@ -736,7 +727,6 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [
     XeGPU_ValueType: $value,
     XeGPU_TensorDesc: $TensorDesc,
     XeGPU_MaskType: $mask,
-    OptionalAttr<UnitAttr>: $transpose,
     OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
     OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
     OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint);
diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
index 772cf73649646..09311e6017d0c 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
@@ -35,6 +35,8 @@ constexpr unsigned packedSizeInBitsForDefault =
     16; // Minimum packing size per register for DPAS A.
 constexpr unsigned packedSizeInBitsForDpasB =
     32; // Minimum packing size per register for DPAS B.
+constexpr unsigned packedSizeInBitsForGatherScatter =
+    32; // Minimum packing size per register for Gather and Scatter ops.
 } // namespace targetinfo
 } // namespace xegpu
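
For illustration, assuming an f16 scattered descriptor, a 16-lane subgroup, and chunk_size = 8 (%src stands in for a memref<256xf16> value): the 32-bit packing size gives a packing factor of 32 / 16 = 2, so the default SIMT layout is lane_layout = [16, 1] with lane_data = [1, 2], as the propagate-layout tests below expect.

```mlir
// Illustrative sketch only; names such as %src and %offsets are assumptions.
%offsets = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112,
                                 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
%desc = xegpu.create_tdesc %src, %offsets : memref<256xf16>, vector<16xindex>
      -> !xegpu.tensor_desc<16x8xf16, #xegpu.scatter_tdesc_attr<chunk_size = 8 : i64>,
         #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>>
```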
 
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 0afc502c026f7..51b06d7c5ce3f 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -20,13 +20,6 @@
 namespace mlir {
 namespace xegpu {
 
-static void transpose(llvm::ArrayRef<int64_t> trans,
-                      SmallVector<int64_t> &shape) {
-  SmallVector<int64_t> old = shape;
-  for (size_t i = 0; i < trans.size(); i++)
-    shape[i] = old[trans[i]];
-}
-
 template <typename T>
 static std::string makeString(T array, bool breakline = false) {
   std::string buf;
@@ -75,8 +68,7 @@ static bool isWriteHintOrNone(const CachePolicyAttr &attr) {
 }
 
 static LogicalResult
-isValidGatherScatterParams(Type maskTy, VectorType valueTy,
-                           TensorDescType tdescTy, UnitAttr transposeAttr,
+isValidGatherScatterParams(Type maskTy, VectorType valueTy, TensorDescType tdescTy,
                            function_ref<InFlightDiagnostic()> emitError) {
 
   if (!tdescTy.isScattered())
@@ -102,17 +94,9 @@ isValidGatherScatterParams(Type maskTy, VectorType valueTy,
   if (valueTy.getRank() == 1 && valueTy.getNumElements() == chunkSize) {
     if (tdescTy.getLayoutAttr())
       return emitError() << "TensorDesc doesn't need LayoutAttr for SIMT code";
-    if (transposeAttr)
-      return emitError() << "doesn't need TransposeAttr for SIMT code";
     return success();
   }
 
-  if (tdescTy.getRank() == 2 && valueTy.getRank() == 2) {
-    if (!transposeAttr)
-      return emitError() << "rank-2 tensor has to be transposed.";
-    transpose({1, 0}, tdescShape);
-  }
-
   if (tdescShape != valueShape)
     return emitError() << "Value shape " << makeString(valueShape)
                        << " is neither a valid distribution for SIMT nor "
@@ -310,13 +294,9 @@ LogicalResult LoadNdOp::verify() {
 
   if (getTranspose()) {
     auto trans = getTranspose().value();
-
     // Make sure the transpose value is valid.
-    bool valid = llvm::all_of(
-        trans, [&](int t) { return t >= 0 && t < tdescTy.getRank(); });
-
-    if (valid)
-      transpose(trans, tdescShape);
+    if (llvm::all_of(trans, [&](size_t s) { return s < tdescShape.size(); }))
+      tdescShape = applyPermutation(tdescShape, trans);
     else
       mlir::emitWarning(getLoc()) << "Invalid transpose attr. It is ignored.";
   }
@@ -536,7 +516,6 @@ LogicalResult LoadGatherOp::verify() {
     return emitOpError("invalid l3_hint: ") << getL3HintAttr();
 
   return isValidGatherScatterParams(maskTy, valueTy, tdescTy,
-                                    getTransposeAttr(),
                                     [&]() { return emitOpError(); });
 }
 
@@ -558,7 +537,6 @@ LogicalResult StoreScatterOp::verify() {
     return emitOpError("invalid l3_hint: ") << getL3HintAttr();
 
   return isValidGatherScatterParams(maskTy, valueTy, tdescTy,
-                                    getTransposeAttr(),
                                     [&]() { return emitOpError(); });
 }
 
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index cc22d2bbd8c39..d713610ca9575 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -213,6 +213,31 @@ static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy) {
                     LaneData({1, packingFactor}));
 }
 
+/// Helper to get the default layout for a tensor descriptor type.
+static LayoutInfo getDefaultSIMTLayoutInfo(xegpu::TensorDescType tdescTy) {
+  // Expecting a 1D or 2D vector.
+  assert((tdescTy.getRank() == 1 || tdescTy.getRank() == 2) &&
+         "Expected 1D or 2D TensorDesc.");
+  // Expecting int or float element type.
+  assert(tdescTy.getElementType().isIntOrFloat() &&
+         "Expected int or float element type.");
+  // If the rank is 1, then return default layout for 1D vector.
+  if (tdescTy.getRank() == 1)
+    return getDefaultSIMTLayoutInfo(1);
+  // Packing factor is determined by the element type bitwidth.
+  unsigned bitwidth = tdescTy.getElementType().getIntOrFloatBitWidth();
+
+  if (tdescTy.isScattered()) {
+    int packingFactor = xegpu::targetinfo::packedSizeInBitsForGatherScatter / bitwidth;
+    return LayoutInfo(LaneLayout({xegpu::targetinfo::subgroupSize, 1}),
+                      LaneData({1, packingFactor}));
+  }
+
+  int packingFactor = (bitwidth < xegpu::targetinfo::packedSizeInBitsForDefault) ? xegpu::targetinfo::packedSizeInBitsForDefault / bitwidth: 1;
+  return LayoutInfo(LaneLayout({1, xegpu::targetinfo::subgroupSize}),
+                    LaneData({1, packingFactor}));
+}
+
 /// Helper Function to get the expected layouts for DPAS operands. `lane_data`
 /// is set according to the following criteria:
 /// * For A operand, the data must be packed in minimum
@@ -379,8 +404,7 @@ void LayoutInfoPropagation::visitPrefetchNdOp(
   // Here we assign the default layout to the tensor descriptor operand of
   // prefetch.
   auto tdescTy = prefetch.getTensorDescType();
-  auto prefetchLayout = getDefaultSIMTLayoutInfo(
-      VectorType::get(tdescTy.getShape(), tdescTy.getElementType()));
+  auto prefetchLayout = getDefaultSIMTLayoutInfo(tdescTy);
   // Propagate the layout to the source tensor descriptor.
   propagateIfChanged(operands[0], operands[0]->meet(prefetchLayout));
 }
@@ -516,24 +540,14 @@ void LayoutInfoPropagation::visitVectorBitcastOp(
 void LayoutInfoPropagation::visitLoadGatherOp(
     xegpu::LoadGatherOp load, ArrayRef<LayoutInfoLattice *> operands,
     ArrayRef<const LayoutInfoLattice *> results) {
-  LayoutInfo valueLayout = results[0]->getValue();
-  // Need the layout of the value to propagate to the tensor descriptor.
-  if (!valueLayout.isAssigned())
-    return;
+  // The layout is strictly determined by the tensor descriptor type.
+  LayoutInfo layout = getDefaultSIMTLayoutInfo(load.getTensorDescType());
 
-  LayoutInfo tensorDescLayout = valueLayout;
-  if (load.getTranspose()) {
-    // LoadGatherOp has the transpose effect. However, at the stage of this
-    // analyis this effect is not expected and should be abstracted away. Emit
-    // a warning.
-    load.emitWarning("Transpose effect is not expected for LoadGatherOp at "
-                     "LayoutInfoPropagation stage.");
-    tensorDescLayout = valueLayout.getTransposedLayout({1, 0});
-  }
   // Mask operand should have 1D default layout.
   LayoutInfo maskLayout = getDefaultSIMTLayoutInfo(1);
+
   // Propagate the new layout to the tensor descriptor operand.
-  propagateIfChanged(operands[0], operands[0]->meet(tensorDescLayout));
+  propagateIfChanged(operands[0], operands[0]->meet(layout));
   // Propagate the new layout to the mask operand.
   propagateIfChanged(operands[1], operands[1]->meet(maskLayout));
 }
@@ -567,21 +581,13 @@ void LayoutInfoPropagation::visitStoreScatterOp(
         "Expected the first dimension of 2D tensor descriptor to be equal to "
         "subgroup size.");
 
-  LayoutInfo valueLayout =
-      getDefaultSIMTLayoutInfo(storeScatter.getValueType());
-  LayoutInfo storeScatterLayout = valueLayout;
-  if (storeScatter.getTranspose()) {
-    // StoreScatteOp allows transpose effect. However, at the stage of this
-    // analyis this effect is not expected and should be abstracted away. Emit
-    // a warning.
-    storeScatter.emitWarning("Transpose effect is not expected for "
-                             "StoreScatterOp at LayoutInfoPropagation stage.");
-    storeScatterLayout = valueLayout.getTransposedLayout({1, 0});
-  }
+  LayoutInfo layout =
+      getDefaultSIMTLayoutInfo(storeScatter.getTensorDescType());
+
   // Propagate the value layout.
-  propagateIfChanged(operands[0], operands[0]->meet(valueLayout));
+  propagateIfChanged(operands[0], operands[0]->meet(layout));
   // Propagate the tensor descriptor layout.
-  propagateIfChanged(operands[1], operands[1]->meet(storeScatterLayout));
+  propagateIfChanged(operands[1], operands[1]->meet(layout));
   // Use default 1D layout for mask operand.
   LayoutInfo maskLayout = getDefaultSIMTLayoutInfo(1);
   propagateIfChanged(operands[2], operands[2]->meet(maskLayout));
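
For illustration, a rough sketch of what the propagation now assigns for a scattered store of f32 data (packing factor 32 / 32 = 1): the stored value and the tensor descriptor share the same layout, and the mask keeps the 1D default, matching the store_scatter test below. %val, %desc, and %mask are assumed values of the listed types.

```mlir
// Illustrative sketch only: both the value and the descriptor end up with
// lane_layout = [16, 1] and lane_data = [1, 1]; the mask uses the default 1D layout.
xegpu.store %val, %desc, %mask
    : vector<16x8xf32>,
      !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size = 8 : i64>,
        #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>>,
      vector<16xi1>
```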
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index 0457f8128b908..9ab3ea828a033 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -507,8 +507,6 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
         for (int64_t i = 0; i < numNewChunks; ++i)
           convertedMasks.push_back(mask);
       }
-      // This is to handle the transpose effect when chunkSize > 1.
-      std::swap((*targetShape)[0], (*targetShape)[1]);
       newValueTy = valueTy.cloneWith(*targetShape, elemTy);
     } else {
       convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
@@ -519,8 +517,7 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
     SmallVector<Value> newOps;
     for (auto [t, m] : llvm::zip(convertedTdescs, convertedMasks)) {
       auto newOp = rewriter.create<xegpu::LoadGatherOp>(
-          loc, newValueTy, t, m, op.getTransposeAttr(), op.getL1HintAttr(),
-          op.getL2HintAttr(), op.getL3HintAttr());
+          loc, newValueTy, t, m, op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr());
       newOps.push_back(newOp);
     }
 
@@ -598,9 +595,6 @@ struct UnrollStoreScatterOp : public UnrollPattern<xegpu::StoreScatterOp> {
           convertedMasks.push_back(mask);
         }
       }
-      // This is to handle the transpose effect when chunkSize > 1.
-      std::swap((*targetShape)[0], (*targetShape)[1]);
-
     } else {
       convertedMaskTypes = getUnrolledTypes(maskTy, *targetShape);
       convertedMasks =
@@ -617,8 +611,7 @@ struct UnrollStoreScatterOp : public UnrollPattern<xegpu::StoreScatterOp> {
       Value t = convertedTdescs[i];
       Value m = op.getMask() ? convertedMasks[i] : nullptr;
       rewriter.create<xegpu::StoreScatterOp>(
-          loc, v, t, m, op.getTransposeAttr(), op.getL1HintAttr(),
-          op.getL2HintAttr(), op.getL3HintAttr());
+          loc, v, t, m, op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr());
     }
 
     rewriter.eraseOp(op);
diff --git a/mlir/test/Dialect/XeGPU/ops.mlir b/mlir/test/Dialect/XeGPU/ops.mlir
index 054c4d12fdb28..5ceb548221758 100644
--- a/mlir/test/Dialect/XeGPU/ops.mlir
+++ b/mlir/test/Dialect/XeGPU/ops.mlir
@@ -199,8 +199,8 @@ gpu.func @simt_load_nd_7(%src: memref<24x32xf16>) {
 gpu.func @subgroup_load_nd_8(%src: memref<24x32xf32>) {
   // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<16x8xf32>
   %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<16x8xf32>
-  // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose = array<i64: 1, 0>}> : !xegpu.tensor_desc<16x8xf32> -> vector<8x16xf32>
-  %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose = array<i64: 1, 0>}> : !xegpu.tensor_desc<16x8xf32> -> vector<8x16xf32>
+  // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<16x8xf32> -> vector<16x8xf32>
+  %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<16x8xf32> -> vector<16x8xf32>
   gpu.return
 }
 
@@ -235,8 +235,6 @@ gpu.func @simt_store_nd(%src: memref<24x32xf16>) {
   gpu.return
 }
 
-
-
 // CHECK: func @subgroup_store_nd_2(%[[arg0:.*]]: memref<24x32xf16>) {
 gpu.func @subgroup_store_nd_2(%dst: memref<24x32xf16>) {
   // CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<32xf16>
@@ -248,7 +246,6 @@ gpu.func @subgroup_store_nd_2(%dst: memref<24x32xf16>) {
   gpu.return
 }
 
-
 // CHECK: func @simt_store_nd_2(%[[arg0:.*]]: memref<24x32xf16>) {
 gpu.func @simt_store_nd_2(%src: memref<24x32xf16>) {
   // CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<2xf16>
@@ -318,8 +315,8 @@ gpu.func @subgroup_load(%src: ui64) {
   %1 = arith.constant dense<1>: vector<4xi1>
   //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
   %2 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
-  //CHECK: %[[R1:.*]] = xegpu.load %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>, vector<4xi1> -> vector<2x4xf32>
-  %3 = xegpu.load %2, %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>, vector<4xi1> -> vector<2x4xf32>
+  //CHECK: %[[R1:.*]] = xegpu.load %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>, vector<4xi1> -> vector<4x2xf32>
+  %3 = xegpu.load %2, %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>, vector<4xi1> -> vector<4x2xf32>
   gpu.return
 }
 
@@ -370,8 +367,8 @@ gpu.func @subgroup_load_3(%src: ui64) {
   %1 = arith.constant dense<1>: vector<4xi1>
   //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x8xf16, #xegpu.scatter_tdesc_attr<chunk_size = 8 : i64>>
   %2 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x8xf16, #xegpu.scatter_tdesc_attr<chunk_size = 8>>
-  //CHECK: %[[R1:.*]] = xegpu.load %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : !xegpu.tensor_desc<4x8xf16, #xegpu.scatter_tdesc_attr<chunk_size = 8 : i64>>, vector<4xi1> -> vector<8x4xf16>
-  %3 = xegpu.load %2, %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : !xegpu.tensor_desc<4x8xf16, #xegpu.scatter_tdesc_attr<chunk_size = 8>>, vector<4xi1> -> vector<8x4xf16>
+  //CHECK: %[[R1:.*]] = xegpu.load %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<4x8xf16, #xegpu.scatter_tdesc_attr<chunk_size = 8 : i64>>, vector<4xi1> -> vector<4x8xf16>
+  %3 = xegpu.load %2, %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<4x8xf16, #xegpu.scatter_tdesc_attr<chunk_size = 8>>, vector<4xi1> -> vector<4x8xf16>
   gpu.return
 }
 
@@ -394,17 +391,15 @@ gpu.func @subgroup_store(%src: ui64) {
   %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
   //CHECK: %[[cst1:.*]] = arith.constant dense<true> : vector<4xi1>
   %1 = arith.constant dense<1>: vector<4xi1>
-  //CHECK: %[[cst2:.*]] = arith.constant dense<2.900000e+00> : vector<2x4xf32>
-  %2 = arith.constant dense<2.9>: vector<2x4xf32>
+  //CHECK: %[[cst2:.*]] = arith.constant dense<2.900000e+00> : vector<4x2xf32>
+  %2 = arith.constant dense<2.9>: vector<4x2xf32>
   //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
   %3 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
-  //CHECK: xegpu.store %[[cst2]], %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : vector<2x4xf32>, !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>, vector<4xi1>
-  xegpu.store %2, %3, %1 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : vector<2x4xf32>, !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>, vector<4xi1>
+  //CHECK: xegpu.store %[[cst2]], %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> : vector<4x2xf32>, !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>, vector<4xi1>
+  xegpu.store %2, %3, %1 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> : vector<4x2xf32>, !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>, vector<4xi1>
   gpu.return
 }
 
-
-
 // CHECK: gpu.func @simt_store(%[[arg0:.*]]: ui64) {
 gpu.func @simt_store(%src: ui64) {
   //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
@@ -426,17 +421,15 @@ gpu.func @subgroup_store_2(%src: ui64) {
   %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
   //CHECK: %[[cst1:.*]] = arith.constant dense<true> : vector<4xi1>
   %1 = arith.constant dense<1>: vector<4xi1>
-  //CHECK: %[[cst2:.*]] = arith.constant {{.*}} : vector<2x4xf16>
-  %2 = arith.constant dense<2.9>: vector<2x4xf16>
+  //CHECK: %[[cst2:.*]] = arith.constant {{.*}} : vector<4x2xf16>
+  %2 = arith.constant dense<2.9>: vector<4x2xf16>
   //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf16, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
   %3 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf16, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
-  //CHECK: xegpu.store %[[cst2]], %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : vector<2x4xf16>, !xegpu.tensor_desc<4x2xf16, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>, vector<4xi1>
-  xegpu.store %2, %3, %1 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : vector<2x4xf16>, !xegpu.tensor_desc<4x2xf16, #xegpu.scatter_tdesc_attr<chunk_size = 2>>, vector<4xi1>
+  //CHECK: xegpu.store %[[cst2]], %[[R0]], %[[cst1]] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> : vector<4x2xf16>, !xegpu.tensor_desc<4x2xf16, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>, vector<4xi1>
+  xegpu.store %2, %3, %1 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> : vector<4x2xf16>, !xegpu.tensor_desc<4x2xf16, #xegpu.scatter_tdesc_attr<chunk_size = 2>>, vector<4xi1>
   gpu.return
 }
 
-
-
 // CHECK: gpu.func @simt_store_2(%[[arg0:.*]]: ui64) {
 gpu.func @simt_store_2(%src: ui64) {
   //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
@@ -467,7 +460,6 @@ gpu.func @subgroup_store_3(%src: ui64) {
   gpu.return
 }
 
-
 // CHECK: gpu.func @simt_store_3(%[[arg0:.*]]: ui64) {
 gpu.func @simt_store_3(%src: ui64) {
   //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout.mlir b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
index 429081079de1e..80d309690c9b2 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
@@ -90,27 +90,18 @@ func.func @extf_truncf(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor
 }
 
 // -----
-// CHECK-LABEL: func.func @load_gather_with_transpose_effect(
-// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<8x16xf16>, %[[ARG1:[0-9a-zA-Z]+]]: memref<256xf16>, %[[ARG2:[0-9a-zA-Z]+]]: memref<8x16xf32>) {
-// CHECK: %[[CST:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
-// CHECK-SAME:  dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
-// CHECK-NEXT: %[[CST0:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<16xi1>
-// CHECK-NEXT: %[[T2:.*]] = xegpu.create_tdesc %[[ARG1]], %[[CST]] : memref<256xf16>, vector<16xindex> ->
-// CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16 : i64>, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>>
-// CHECK-NEXT: %{{.*}} = xegpu.load %[[T2]], %[[CST0]] <{transpose}> {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} :
-// CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16 : i64>, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>>, vector<16xi1> -> vector<16x16xf16>
-func.func @load_gather_with_transpose_effect(%arg0: memref<8x16xf16>, %arg1: memref<256xf16>, %arg2: memref<8x16xf32>) {
-  %c0 = arith.constant 0 : index
-  %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
-  %1 = xegpu.load_nd %0  : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
-  %cst = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
-  %cst_0 = arith.constant dense<true> : vector<16xi1>
-  %2 = xegpu.create_tdesc %arg1, %cst : memref<256xf16>, vector<16xindex> -> !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16 : i64>>
-  %3 = xegpu.load %2, %cst_0 <{transpose}> : !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16 : i64>>, vector<16xi1> -> vector<16x16xf16>
-  %4 = xegpu.dpas %1, %3 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
-  %5 = xegpu.create_nd_tdesc %arg2[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
-  xegpu.store_nd %4, %5  : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
-  return
+// CHECK-LABEL: func.func @load_gather_with_chunksize
+// CHECK-SAME: [[arg0:%.+]]: memref<256xf16>
+// CHECK: [[idx:%.+]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
+// CHECK: [[m:%.+]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<16xi1>
+// CHECK: [[desc:%.+]] = xegpu.create_tdesc [[arg0]], [[idx]] : memref<256xf16>, vector<16xindex> -> !xegpu.tensor_desc<16x8xf16, #xegpu.scatter_tdesc_attr<chunk_size = 8 : i64>, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>>
+// CHECK: xegpu.load [[desc]], [[m]]  : !xegpu.tensor_desc<16x8xf16, #xegpu.scatter_tdesc_attr<chunk_size = 8 : i64>, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>>, vector<16xi1> -> vector<16x8xf16>
+func.func @load_gather_with_chunksize(%arg0: memref<256xf16>) -> vector<16x8xf16> {
+  %index = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
+  %mask = arith.constant dense<true> : vector<16xi1>
+  %1 = xegpu.create_tdesc %arg0, %index : memref<256xf16>, vector<16xindex> -> !xegpu.tensor_desc<16x8xf16, #xegpu.scatter_tdesc_attr<chunk_size = 8 : i64>>
+  %2 = xegpu.load %1, %mask : !xegpu.tensor_desc<16x8xf16, #xegpu.scatter_tdesc_attr<chunk_size = 8 : i64>>, vector<16xi1> -> vector<16x8xf16>
+  return %2: vector<16x8xf16>
 }
 
 // -----
@@ -127,24 +118,24 @@ func.func @load_gather_1d(%arg0: memref<256xf32>, %arg1: !xegpu.tensor_desc<16xf
   %cst = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
   %cst_0 = arith.constant dense<true> : vector<16xi1>
   %0 = xegpu.create_tdesc %arg0, %cst : memref<256xf32>, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
-  %1 = xegpu.load %0, %cst_0  : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> -> vector<16xf32>
+  %1 = xegpu.load %0, %cst_0 : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> -> vector<16xf32>
   xegpu.store_nd %1, %arg1  : vector<16xf32>, !xegpu.tensor_desc<16xf32>
   return
 }
 
 // -----
-// CHECK-LABEL: func.func @store_scatter_with_transpose_effect(
+// CHECK-LABEL: func.func @store_scatter_with_chunksize(
 // CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<128xf32>) {
 // CHECK: %[[T0:.*]] = xegpu.create_tdesc %[[ARG0]], %{{.*}} : memref<128xf32>, vector<16xindex> ->
 // CHECK-SAME: !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size = 8 : i64>, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>>
-// CHECK-NEXT: xegpu.store %{{.*}}, %[[T0]], %{{.*}} <{transpose}> : vector<8x16xf32>, !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size = 8 : i64>,
+// CHECK-NEXT: xegpu.store %{{.*}}, %[[T0]], %{{.*}} : vector<16x8xf32>, !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size = 8 : i64>,
 // CHECK-SAME: #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>>, vector<16xi1>
-func.func @store_scatter_with_transpose_effect(%arg0: memref<128xf32>) {
-  %cst = arith.constant dense<1.000000e+00> : vector<8x16xf32>
+func.func @store_scatter_with_chunksize(%arg0: memref<128xf32>) {
+  %cst = arith.constant dense<1.000000e+00> : vector<16x8xf32>
   %cst_0 = arith.constant dense<true> : vector<16xi1>
   %cst_1 = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
   %0 = xegpu.create_tdesc %arg0, %cst_1 : memref<128xf32>, vector<16xindex> -> !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size = 8 : i64>>
-  xegpu.store %cst, %0, %cst_0 <{transpose}> : vector<8x16xf32>, !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size = 8 : i64>>, vector<16xi1>
+  xegpu.store %cst, %0, %cst_0 : vector<16x8xf32>, !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size = 8 : i64>>, vector<16xi1>
   return
 }
 
diff --git a/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir b/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
index f977ba3c11bcf..ac5fe89a67f9a 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
@@ -368,35 +368,35 @@ gpu.module @test_kernel {
     0,   8,  16,  24,  32,  40,  48,  56,
     64,  72,  80,  88,  96, 104, 112, 120,
     128, 136, 144, 152, 160, 168, 176, 184,
-    192, 200, 208, 216, 224, 232, 240, 248 
+    192, 200, 208, 216, 224, 232, 240, 248
     ]> : vector<32xindex>
 
     %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32xf32,  #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
     xegpu.prefetch %tdesc: !xegpu.tensor_desc<32xf32,  #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
-   
+
     %delta = arith.constant dense<[
     32,   32,  32,  32,  32,  32,  32,  32,
     32,   32,  32,  32,  32,  32,  32,  64,
     128, 128, 128, 128, 128, 128, 128, 128,
-    128, 128, 128, 128, 128, 128, 128, 256 
+    128, 128, 128, 128, 128, 128, 128, 256
     ]> : vector<32xindex>
     %new_tdesc = xegpu.update_offset %tdesc, %delta
-              : !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>, vector<32xindex>     
- 
+              : !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>, vector<32xindex>
+
     %c17 = arith.constant 17: index
     %mask = vector.create_mask %c17: vector<32xi1>
 
     %ld_vec = xegpu.load %new_tdesc, %mask: !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>, vector<32xi1> -> vector<32xf32>
 
     %st_vec = arith.addf %ld_vec, %ld_vec : vector<32xf32>
-    xegpu.store %st_vec, %tdesc, %mask: 
-                 vector<32xf32>, 
-                 !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>, 
+    xegpu.store %st_vec, %tdesc, %mask:
+                 vector<32xf32>,
+                 !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>,
                  vector<32xi1>
-  
+
     gpu.return
   }
-  
+
 }
 
 // -----
@@ -407,8 +407,8 @@ gpu.module @test_kernel   {
   // CHECK-COUNT-4: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
   // CHECK-COUNT-4: xegpu.prefetch {{.*}} : !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
    // CHECK-COUNT-4: xegpu.update_offset {{.*}} : !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>, vector<16xindex>
-   // CHECK-COUNT-4: xegpu.load  {{.*}} : !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>, vector<16xi1> -> vector<2x16xf32>
-  // CHECK-COUNT-4: xegpu.store  {{.*}} : vector<2x16xf32>, !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>, vector<16xi1>
+   // CHECK-COUNT-4: xegpu.load  {{.*}} : !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>, vector<16xi1> -> vector<16x2xf32>
+  // CHECK-COUNT-4: xegpu.store  {{.*}} : vector<16x2xf32>, !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>, vector<16xi1>
 
   gpu.func @test_prefetch_load_store_update_chunk(%src: ui64)  {
 
@@ -416,32 +416,32 @@ gpu.module @test_kernel   {
       0,   8,  16,  24,  32,  40,  48,  56,
       64,  72,  80,  88,  96, 104, 112, 120,
       128, 136, 144, 152, 160, 168, 176, 184,
-      192, 200, 208, 216, 224, 232, 240, 248 
+      192, 200, 208, 216, 224, 232, 240, 248
     ]> : vector<32xindex>
 
     %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32x4xf32,  #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>
     xegpu.prefetch %tdesc: !xegpu.tensor_desc<32x4xf32,  #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>
-   
+
     %delta = arith.constant dense<[
       32,   32,  32,  32,  32,  32,  32,  32,
       32,   32,  32,  32,  32,  32,  32,  64,
       128, 128, 128, 128, 128, 128, 128, 128,
-      128, 128, 128, 128, 128, 128, 128, 256 
+      128, 128, 128, 128, 128, 128, 128, 256
     ]> : vector<32xindex>
     %new_tdesc = xegpu.update_offset %tdesc, %delta
-              : !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>, vector<32xindex>     
- 
+              : !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>, vector<32xindex>
+
     %c17 = arith.constant 17: index
     %mask = vector.create_mask %c17: vector<32xi1>
 
-    %ld_vec = xegpu.load %new_tdesc, %mask <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}>: !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>, vector<32xi1> -> vector<4x32xf32>
+    %ld_vec = xegpu.load %new_tdesc, %mask <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>: !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>, vector<32xi1> -> vector<32x4xf32>
 
-    %st_vec = arith.addf %ld_vec, %ld_vec : vector<4x32xf32>
-    xegpu.store %st_vec, %tdesc, %mask <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}>: 
-                 vector<4x32xf32>, 
-                 !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>, 
+    %st_vec = arith.addf %ld_vec, %ld_vec : vector<32x4xf32>
+    xegpu.store %st_vec, %tdesc, %mask <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>:
+                 vector<32x4xf32>,
+                 !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>,
                  vector<32xi1>
-  
+
     gpu.return
   }
 }
diff --git a/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir b/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
index 41414d802f212..6999da5d222fe 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
@@ -169,7 +169,7 @@ gpu.module @test {
     0,   8,  16,  24,  32,  40,  48,  56,
     64,  72,  80,  88,  96, 104, 112, 120,
     128, 136, 144, 152, 160, 168, 176, 184,
-    192, 200, 208, 216, 224, 232, 240, 248 
+    192, 200, 208, 216, 224, 232, 240, 248
     ]> : vector<32xindex>
     %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32xf32,  #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
     gpu.return %tdesc : !xegpu.tensor_desc<32xf32,  #xegpu.scatter_tdesc_attr<>,  #xegpu.layout<inst_data = [16]>>
@@ -199,15 +199,15 @@ gpu.module @test {
     0,   8,  16,  24,  32,  40,  48,  56,
     64,  72,  80,  88,  96, 104, 112, 120,
     128, 136, 144, 152, 160, 168, 176, 184,
-    192, 200, 208, 216, 224, 232, 240, 248 
+    192, 200, 208, 216, 224, 232, 240, 248
     ]> : vector<32xindex>
-      
+
     %c17 = arith.constant 17: index
     %mask = vector.create_mask %c17: vector<32xi1>
     %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32xf32,  #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
     %ld = xegpu.load %tdesc, %mask: !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>, vector<32xi1> -> vector<32xf32>
-      
-    gpu.return %ld : vector<32xf32> 
+
+    gpu.return %ld : vector<32xf32>
   }
 
 //-----
@@ -222,7 +222,7 @@ gpu.module @test {
     0,   8,  16,  24,  32,  40,  48,  56,
     64,  72,  80,  88,  96, 104, 112, 120,
     128, 136, 144, 152, 160, 168, 176, 184,
-    192, 200, 208, 216, 224, 232, 240, 248 
+    192, 200, 208, 216, 224, 232, 240, 248
     ]> : vector<32xindex>
 
     %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32xf32,  #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
@@ -242,16 +242,16 @@ gpu.module @test {
     0,   8,  16,  24,  32,  40,  48,  56,
     64,  72,  80,  88,  96, 104, 112, 120,
     128, 136, 144, 152, 160, 168, 176, 184,
-    192, 200, 208, 216, 224, 232, 240, 248 
+    192, 200, 208, 216, 224, 232, 240, 248
     ]> : vector<32xindex>
-    
+
     %c17 = arith.constant 17: index
     %mask = vector.create_mask %c17: vector<32xi1>
 
     %st_vec = arith.constant dense<1023.0>: vector<32xf32>
     %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32xf32,  #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
     xegpu.store %st_vec, %tdesc, %mask: vector<32xf32>, !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>, vector<32xi1>
-    
+
     gpu.return
   }
 
@@ -280,7 +280,7 @@ gpu.module @test {
   }
 
 // CHECK-LABEL: create_tdesc_step_chunk3
-  // CHECK-SAME: [[arg0:%.+]]: ui64  
+  // CHECK-SAME: [[arg0:%.+]]: ui64
   // CHECK: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
   // CHECK: arith.addi %{{.*}}, %{{.*}} : vector<16xindex>
   // CHECK: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
@@ -300,45 +300,45 @@ gpu.module @test {
   // CHECK-LABEL: load_chunk
   // CHECK-SAME: [[arg0:%.+]]: ui64
   // CHECK-COUNT-4: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
-  // CHECK-COUNT-4: xegpu.load  {{.*}} <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>, vector<16xi1> -> vector<2x16xf32>
+  // CHECK-COUNT-4: xegpu.load  {{.*}} <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>, vector<16xi1> -> vector<16x2xf32>
 
-  gpu.func @load_chunk(%src: ui64) -> vector<4x32xf32> {
+  gpu.func @load_chunk(%src: ui64) -> vector<32x4xf32> {
     %cst = arith.constant dense<[
         0,   8,  16,  24,  32,  40,  48,  56,
         64,  72,  80,  88,  96, 104, 112, 120,
         128, 136, 144, 152, 160, 168, 176, 184,
-        192, 200, 208, 216, 224, 232, 240, 248 
+        192, 200, 208, 216, 224, 232, 240, 248
     ]> : vector<32xindex>
-    
+
     %c17 = arith.constant 17: index
     %mask = vector.create_mask %c17: vector<32xi1>
 
-    %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32x4xf32,  #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>> 
-    %ld = xegpu.load %tdesc, %mask <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}>: !xegpu.tensor_desc<32x4xf32,  #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>, vector<32xi1> -> vector<4x32xf32>
-    
-    gpu.return %ld : vector<4x32xf32> 
+    %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32x4xf32,  #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>
+    %ld = xegpu.load %tdesc, %mask <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>: !xegpu.tensor_desc<32x4xf32,  #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>, vector<32xi1> -> vector<32x4xf32>
+
+    gpu.return %ld : vector<32x4xf32>
    }
 
 //-----
   // CHECK-LABEL: store_chunk
   // CHECK-SAME: [[arg0:%.+]]: ui64
   // CHECK-COUNT-4: xegpu.create_tdesc [[arg0]], {{.*}} :  ui64, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
-  // CHECK-COUNT-4: xegpu.store  {{.*}} <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : vector<2x16xf32>, !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>, vector<16xi1>
+  // CHECK-COUNT-4: xegpu.store  {{.*}} <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : vector<16x2xf32>, !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>, vector<16xi1>
   gpu.func @store_chunk(%src: ui64) {
     %cst = arith.constant dense<[
       0,   8,  16,  24,  32,  40,  48,  56,
       64,  72,  80,  88,  96, 104, 112, 120,
       128, 136, 144, 152, 160, 168, 176, 184,
-      192, 200, 208, 216, 224, 232, 240, 248 
+      192, 200, 208, 216, 224, 232, 240, 248
     ]> : vector<32xindex>
-    
+
     %c17 = arith.constant 17: index
     %mask = vector.create_mask %c17: vector<32xi1>
 
-    %st_vec = arith.constant dense<1023.>: vector<4x32xf32>
+    %st_vec = arith.constant dense<1023.>: vector<32x4xf32>
     %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32x4xf32,  #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>
-    xegpu.store %st_vec, %tdesc, %mask <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}>: vector<4x32xf32>, !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16,2]>>, vector<32xi1>
-    
+    xegpu.store %st_vec, %tdesc, %mask <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>: vector<32x4xf32>, !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16,2]>>, vector<32xi1>
+
     gpu.return
   }
 
@@ -352,11 +352,11 @@ gpu.module @test {
       0,   8,  16,  24,  32,  40,  48,  56,
       64,  72,  80,  88,  96, 104, 112, 120,
       128, 136, 144, 152, 160, 168, 176, 184,
-      192, 200, 208, 216, 224, 232, 240, 248 
+      192, 200, 208, 216, 224, 232, 240, 248
       ]> : vector<32xindex>
     %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32x4xf32,  #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>
     xegpu.prefetch %tdesc: !xegpu.tensor_desc<32x4xf32,  #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>
-    
+
     gpu.return
   }
 
@@ -370,7 +370,7 @@ gpu.module @test {
       0,   8,  16,  24,  32,  40,  48,  56,
       64,  72,  80,  88,  96, 104, 112, 120,
       128, 136, 144, 152, 160, 168, 176, 184,
-      192, 200, 208, 216, 224, 232, 240, 248 
+      192, 200, 208, 216, 224, 232, 240, 248
     ]> : vector<32xindex>
     %delta = arith.constant dense<32>: vector<32xindex>
     %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32x4xf32,  #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>
@@ -379,6 +379,6 @@ gpu.module @test {
         : !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>, vector<32xindex>
 
     gpu.return %new_tdesc : !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>
-  }  
+  }
 }
 

>From d9e1cbd022cb146000d8d010f2aae23aa0d848fb Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Mon, 23 Jun 2025 18:55:56 +0000
Subject: [PATCH 2/6] format

---
 mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp                    | 3 ++-
 .../lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp | 8 ++++++--
 mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp         | 8 +++++---
 3 files changed, 13 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 51b06d7c5ce3f..f0fb03d4f1139 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -68,7 +68,8 @@ static bool isWriteHintOrNone(const CachePolicyAttr &attr) {
 }
 
 static LogicalResult
-isValidGatherScatterParams(Type maskTy, VectorType valueTy, TensorDescType tdescTy,
+isValidGatherScatterParams(Type maskTy, VectorType valueTy,
+                           TensorDescType tdescTy,
                            function_ref<InFlightDiagnostic()> emitError) {
 
   if (!tdescTy.isScattered())
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index d713610ca9575..60ccd823775a5 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -228,12 +228,16 @@ static LayoutInfo getDefaultSIMTLayoutInfo(xegpu::TensorDescType tdescTy) {
   unsigned bitwidth = tdescTy.getElementType().getIntOrFloatBitWidth();
 
   if (tdescTy.isScattered()) {
-    int packingFactor = xegpu::targetinfo::packedSizeInBitsForGatherScatter / bitwidth;
+    int packingFactor =
+        xegpu::targetinfo::packedSizeInBitsForGatherScatter / bitwidth;
     return LayoutInfo(LaneLayout({xegpu::targetinfo::subgroupSize, 1}),
                       LaneData({1, packingFactor}));
   }
 
-  int packingFactor = (bitwidth < xegpu::targetinfo::packedSizeInBitsForDefault) ? xegpu::targetinfo::packedSizeInBitsForDefault / bitwidth: 1;
+  int packingFactor =
+      (bitwidth < xegpu::targetinfo::packedSizeInBitsForDefault)
+          ? xegpu::targetinfo::packedSizeInBitsForDefault / bitwidth
+          : 1;
   return LayoutInfo(LaneLayout({1, xegpu::targetinfo::subgroupSize}),
                     LaneData({1, packingFactor}));
 }
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index 9ab3ea828a033..be39ee1f0b53f 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -517,7 +517,8 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
     SmallVector<Value> newOps;
     for (auto [t, m] : llvm::zip(convertedTdescs, convertedMasks)) {
       auto newOp = rewriter.create<xegpu::LoadGatherOp>(
-          loc, newValueTy, t, m, op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr());
+          loc, newValueTy, t, m, op.getL1HintAttr(), op.getL2HintAttr(),
+          op.getL3HintAttr());
       newOps.push_back(newOp);
     }
 
@@ -610,8 +611,9 @@ struct UnrollStoreScatterOp : public UnrollPattern<xegpu::StoreScatterOp> {
       Value v = convertedValues[i];
       Value t = convertedTdescs[i];
       Value m = op.getMask() ? convertedMasks[i] : nullptr;
-      rewriter.create<xegpu::StoreScatterOp>(
-          loc, v, t, m, op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr());
+      rewriter.create<xegpu::StoreScatterOp>(loc, v, t, m, op.getL1HintAttr(),
+                                             op.getL2HintAttr(),
+                                             op.getL3HintAttr());
     }
 
     rewriter.eraseOp(op);

>From 581b97c12bb61b89764b14501c734d9611f4cd67 Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Mon, 23 Jun 2025 19:14:40 +0000
Subject: [PATCH 3/6] cleanup the doc

---
 .../include/mlir/Dialect/XeGPU/IR/XeGPUOps.td | 80 +++----------------
 1 file changed, 11 insertions(+), 69 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index ffc08e9b90b56..e96d867c8d671 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -80,9 +80,6 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
         information e.g., memref<?x?xf16>, the strides information has to be explicitly
         passed via the "strides" and "const_strides" argument.
 
-    In SIMT mode, tensor descriptor is augmented with `LayoutAttr` which describes the
-    mapping of the tensor descriptor to the work items.
-
     Example 1 (suppose the tensor shape inferred by the compiler is 8x16):
     ```mlir
     %0 = memref.alloc() : memref<1024x1024xf32>
@@ -106,15 +103,6 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
     %c1 = arith.constant 1 : index
     %1 = xegpu.create_nd_tdesc %0[%c0, %c0], [%h, %w], [%w, %c1]: ui64 -> TensorDesc<8x16xf32>
     ```
-
-    Example 4 (SIMT mode):
-    ```mlir
-    %0 = memref.alloc() : memref<1024x1024xf32>
-    %c0 = arith.constant 0 : index
-    %c1 = arith.constant 8 : index
-    %1 = xegpu.create_nd_tdesc %0[%c0, %c0] : memref<1024x1024xf32>
-          -> !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
-    ```
   }];
 
   let arguments = (ins
@@ -301,9 +289,7 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [
     fp32 or fp64. It implies that vnni and transpose cannot exit at the
     same time.
 
-    In SIMT mode, LoadNdOp expects the tensor descriptor to be augmented with `LayoutAttr`
-    which describes the mapping of the tensor to the work items. In this case, result
-    vector represents the data to be loaded by each work-item.
+    In SIMT mode, the result vector represents the data to be loaded by each work-item.
 
     Example 1:
     ```mlir
@@ -317,8 +303,7 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [
     ```mlir
       xegpu.load_nd %1 {l1_hint = #xegpu.cache_hint<cached>,
                         l2_hint = #xegpu.cache_hint<uncached>}>
-        : !xegpu.tensor_desc<8x16xf32,
-          #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<8x1xf32>
+        : !xegpu.tensor_desc<8x16xf32> -> vector<8xf32>
     ```
 
 
@@ -359,9 +344,7 @@ def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [
     of cache, L1, L2 and L3. If hardware does not have a correspoding cache,
     Corresponding cache hint attribute will be masked.
 
-    In SIMT mode, StoreNdOp expects the tensor descriptor to be augmented with `LayoutAttr`
-    which describes the mapping of the tensor to the work items. In this case, input
-    vector represents the data to be stored by each work-item.
+    In SIMT mode, the input vector represents the data to be stored by each work-item.
 
     Example 1:
     ```mlir
@@ -375,8 +358,7 @@ def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [
       xegpu.store_nd %3, %2 {l1_hint = #xegpu.cache_hint<uncached>,
                              l2_hint = #xegpu.cache_hint<write_back>,
                              l3_hint = #xegpu.cache_hint<write_through>}
-                             : vector<8x1xf16>, !xegpu.tensor_desc<8x16xf16,
-                               #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+                             : vector<8xf16>, !xegpu.tensor_desc<8x16xf16>
     ```
 
 
@@ -410,15 +392,10 @@ def XeGPU_UpdateNdOffsetOp : XeGPU_Op<"update_nd_offset",
     The offsets are relative offset to the current position in the number
     of elements. It will result in a same type TensorDesc as the input.
 
-  Example 1:
+  Example:
   ```
     %2 = xegpu.update_nd_offset %1, [0, 16]: !xegpu.tensor_desc<8x16xf32>
   ```
-  Example 2 (SIMT mode):
-  ```
-    %2 = xegpu.update_nd_offset %1, [0, 16]:
-      !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
-  ```
   }];
 
   let arguments = (ins
@@ -476,11 +453,6 @@ def XeGPU_CreateDescOp: XeGPU_Op<"create_tdesc", [Pure, ViewLikeOpInterface]> {
     match the dimension of offsets. It may also has a second dimension corresponding to
     the chunk_size if the chunk size is larger than 1.
 
-    In SIMT mode, similar to `create_nd_tdesc` the resulting tensor descriptor is augmented
-    with `LayoutAttr` which describes the mapping of the tensor descriptor to the work items.
-    In this case, the first dimension of the tensor descriptor represents the work-items, and
-    the second dimension represents the chunk size.
-
     Example 1: It assumes subgroup size is 4, and accesses a[0], a[16], a[32], a[64]
     ```mlir
     %a = memref.alloc() : memref<1024xf32>
@@ -505,15 +477,6 @@ def XeGPU_CreateDescOp: XeGPU_Op<"create_tdesc", [Pure, ViewLikeOpInterface]> {
     %1 = xegpu.create_tdesc %0, %off : memref<1024xf32>, vector<4xindex>
           -> TensorDesc<4x8xf32, #xegpu.scattered_tdesc_attr<chunk_size = 8>>
     ```
-
-    Example 4: SIMT mode
-    ```mlir
-    %0 = memref.alloc() : memref<1024xf32>
-    %off = arith.constant dense<[0, 16, 32, 64]> : vector<4xindex>
-    %1 = xegpu.create_tdesc %0, %off : memref<1024xf32>, vector<4xindex>
-          -> TensorDesc<4x8xf32, #xegpu.scattered_tdesc_attr<chunk_size = 8>,
-          #xegpu.layout<lane_layout = [4, 1], lane_data = [1, 1]>>
-    ```
   }];
 
   let arguments = (ins XeGPU_BaseAddrType: $source,
@@ -614,10 +577,8 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
     The mask operand masks out memory access so that it is safe to pass out-of-boundary
     addresses/offsets as long as they are masked. It applies to slots of SIMD lanes.
 
-    In SIMT mode, LoadGatherOp expects the tensor descriptor to be augmented with `LayoutAttr`
-    which describes the mapping of the tensor to the work items. In this case, result vector
-    represents the data to be loaded by each work-item. Each work-item recieves a `chunk_size`
-    number of elements.
+    In SIMT mode, the result vector represents the data to be loaded by each work-item.
+    Each work-item receives a `chunk_size` number of elements.
 
   Example 1:
   ```mlir
@@ -641,8 +602,7 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
     %2 = xegpu.load %1, %0 {l1_hint = #xegpu.cache_hint<cached>,
                             l2_hint = #xegpu.cache_hint<uncached>,
                             l3_hint = #xegpu.cache_hint<uncached>}
-          : !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<memory_space=global, chunk_size=8>,
-            !xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>>
+          : !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<memory_space=global, chunk_size=8>>,
             vector<16xi1> -> vector<8xf32>
   ```
 
@@ -692,10 +652,8 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [
   has a transpose effect, which is similar to `load_gather`. Therefore, a transpose attribute is
   introduced on purpose, making sure users are aware of this implicit transformation.
 
-  In SIMT mode, StoreScatterOp expects the tensor descriptor to be augmented with `LayoutAttr`
-  which describes the mapping of the tensor to the work items. In this case, input vector
-  represents the data to be stored by each work-item. Each work-item recieves a `chunk_size`
-  number of elements.
+  In SIMT mode, the input vector represents the data to be stored by each work-item.
+  Each work-item stores a `chunk_size` number of elements.
 
   Example 1:
   ```mlir
@@ -712,15 +670,6 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [
                             l3_hint = #xegpu.cache_hint<write_through>}
           : vector<8x16xf32>, !xegpu.tensor_desc<16x8xf32, #xegpu.scattered_tdesc_attr<chunk_size=8>>, vector<16xi1>
   ```
-  Example 3 (SIMT mode):
-  ```mlir
-    xegpu.store %0, %1, %2 {l1_hint = #xegpu.cache_hint<uncached>,
-                            l2_hint = #xegpu.cache_hint<write_back>,
-                            l3_hint = #xegpu.cache_hint<write_through>}
-          : vector<8xf32>, !xegpu.tensor_desc<16x8xf32, #xegpu.scattered_tdesc_attr<chunk_size=8>,
-            !xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>> vector<16xi1>
-  ```
-
   }];
 
   let arguments = (ins
@@ -763,20 +712,13 @@ def XeGPU_UpdateOffsetOp: XeGPU_Op<"update_offset",
     update the offset per work-item, so its offsets contain values representing
     shifts for each work-item.
 
-    Example 1:
+    Example:
     ```mlir
       %off = arith.constant dense<[32, 32, 32, 32]> : vector<4xindex>
       %2 = xegpu.update_offset %1, %off :
               !xegpu.tensor_desc<4x2xf32, #xegpu.scattered_tdesc_attr<chunk_size=2>>, vector<4xindex>
     ```
 
-    Example 2 (SIMT mode):
-    ```mlir
-      %off = arith.constant dense<[32, 32, 32, 32]> : vector<4xindex>
-      %2 = xegpu.update_offset %1, %off :
-              !xegpu.tensor_desc<4x2xf32, #xegpu.scattered_tdesc_attr<chunk_size=2>,
-              #xegpu.layout<lane_layout = [4, 1], lane_data = [1, 1]>>, vector<4xindex>
-    ```
   }];
 
   let arguments = (ins XeGPU_TensorDesc: $TensorDesc,

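For context on the doc cleanup above: with the `LayoutAttr` dropped from the SIMT-mode examples, a scattered store in SIMT mode would read roughly as below. This is a minimal sketch assembled from the surviving examples (chunk_size = 8, 16 lanes, so each work-item supplies a vector<8xf32>), not a hunk from this patch:

```mlir
// per-lane data, a descriptor without a layout attribute, and the lane mask
xegpu.store %0, %1, %2 {l1_hint = #xegpu.cache_hint<uncached>,
                        l2_hint = #xegpu.cache_hint<write_back>,
                        l3_hint = #xegpu.cache_hint<write_through>}
      : vector<8xf32>, !xegpu.tensor_desc<16x8xf32, #xegpu.scattered_tdesc_attr<chunk_size=8>>, vector<16xi1>
```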
>From d5b5643dabb2080d5298da67d8cddbe53f8bf29a Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Mon, 23 Jun 2025 19:26:08 +0000
Subject: [PATCH 4/6] remove unnecessary change

---
 mlir/test/Dialect/XeGPU/ops.mlir | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/test/Dialect/XeGPU/ops.mlir b/mlir/test/Dialect/XeGPU/ops.mlir
index 5ceb548221758..aff8f63adc05b 100644
--- a/mlir/test/Dialect/XeGPU/ops.mlir
+++ b/mlir/test/Dialect/XeGPU/ops.mlir
@@ -199,8 +199,8 @@ gpu.func @simt_load_nd_7(%src: memref<24x32xf16>) {
 gpu.func @subgroup_load_nd_8(%src: memref<24x32xf32>) {
   // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<16x8xf32>
   %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<16x8xf32>
-  // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<16x8xf32> -> vector<16x8xf32>
-  %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<16x8xf32> -> vector<16x8xf32>
+  // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose = array<i64: 1, 0>}> : !xegpu.tensor_desc<16x8xf32> -> vector<8x16xf32>
+  %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose = array<i64: 1, 0>}> : !xegpu.tensor_desc<16x8xf32> -> vector<8x16xf32>
   gpu.return
 }
 

>From 3a670831fd7ee145b9044452928d2cb86a0889b7 Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Mon, 23 Jun 2025 19:58:50 +0000
Subject: [PATCH 5/6] refactor

---
 mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 25 +++++++++++++---------
 1 file changed, 15 insertions(+), 10 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 6790c5e3af2c0..87125cbaacd89 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/DialectImplementation.h"
 #include "llvm/ADT/TypeSwitch.h"
@@ -309,11 +310,23 @@ LogicalResult TensorDescType::verify(
     llvm::ArrayRef<int64_t> shape, mlir::Type elementType,
     mlir::Attribute encoding, mlir::Attribute layout) {
   size_t rank = shape.size();
-  // Low-precision types are packed in 32-bit units.
-  int32_t packingFactor = 32 / elementType.getIntOrFloatBitWidth();
   if (rank != 1 && rank != 2)
     return emitError() << "expected 1D or 2D tensor";
 
+  auto blockAttr = mlir::dyn_cast_if_present<BlockTensorDescAttr>(encoding);
+  if (blockAttr) {
+    MemorySpaceAttr memorySpaceAttr = blockAttr.getMemorySpace();
+    if (rank == 2 && memorySpaceAttr &&
+        memorySpaceAttr.getValue() == MemorySpace::SLM)
+      return emitError() << "SLM is not supported for 2D block tensor";
+  }
+
+  // For gather and scatter ops, low-precision types are packed in 32-bit units.
+  unsigned bitWidth = elementType.getIntOrFloatBitWidth();
+  int packingFactor =
+      bitWidth < targetinfo::packedSizeInBitsForGatherScatter
+          ? targetinfo::packedSizeInBitsForGatherScatter / bitWidth
+          : 1;
   auto scatterAttr = mlir::dyn_cast_if_present<ScatterTensorDescAttr>(encoding);
   if (scatterAttr) {
     // Expected tensor ranks for scattered data:
@@ -336,14 +349,6 @@ LogicalResult TensorDescType::verify(
     }
   }
 
-  auto blockAttr = mlir::dyn_cast_if_present<BlockTensorDescAttr>(encoding);
-  if (blockAttr) {
-    MemorySpaceAttr memorySpaceAttr = blockAttr.getMemorySpace();
-    if (rank == 2 && memorySpaceAttr &&
-        memorySpaceAttr.getValue() == MemorySpace::SLM)
-      return emitError() << "SLM is not supported for 2D block tensor";
-  }
-
   auto layoutAttr = llvm::dyn_cast_if_present<LayoutAttr>(layout);
   if (layoutAttr) {
     if (rank != (size_t)layoutAttr.getRank())

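A quick worked note on the packing-factor change above, assuming `targetinfo::packedSizeInBitsForGatherScatter` is 32 (matching the removed "packed in 32-bit units" comment): packingFactor = 32 / bitWidth when bitWidth < 32, otherwise 1. So f16 and bf16 give a factor of 2, i8 gives 4, and 32-bit types give 1.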
>From 0fe0f4b52453f308c4bd239a000166809b9cbb37 Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Tue, 24 Jun 2025 18:25:08 +0000
Subject: [PATCH 6/6] fix typos

---
 mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index e96d867c8d671..2c2e640f162b7 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -595,7 +595,7 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
                             l2_hint = #xegpu.cache_hint<uncached>,
                             l3_hint = #xegpu.cache_hint<uncached>}
           : !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<memory_space=global, chunk_size=8>>,
-            vector<16xi1> -> vector<8x16xf32>
+            vector<16xi1> -> vector<16x8xf32>
   ```
   Example 3 (SIMT mode):
   ```mlir
@@ -668,7 +668,7 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [
     xegpu.store %0, %1, %2 {l1_hint = #xegpu.cache_hint<uncached>,
                             l2_hint = #xegpu.cache_hint<write_back>,
                             l3_hint = #xegpu.cache_hint<write_through>}
-          : vector<8x16xf32>, !xegpu.tensor_desc<16x8xf32, #xegpu.scattered_tdesc_attr<chunk_size=8>>, vector<16xi1>
+          : vector<16x8xf32>, !xegpu.tensor_desc<16x8xf32, #xegpu.scattered_tdesc_attr<chunk_size=8>>, vector<16xi1>
   ```
   }];
 


