[Mlir-commits] [mlir] [MLIR][XeGPU] Extend unrolling support for scatter ops with chunk_size (PR #144447)

Jianhui Li llvmlistbot at llvm.org
Tue Jun 17 14:51:25 PDT 2025


https://github.com/Jianhui-Li updated https://github.com/llvm/llvm-project/pull/144447

>From bac8bc607a0b5f9171472375e5f42c24a5ad429d Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Mon, 9 Jun 2025 14:40:41 +0000
Subject: [PATCH 01/16] add unroll support for reduce and broadcast

---
 .../XeGPU/Transforms/XeGPUBlocking.cpp        |   6 +
 mlir/test/Dialect/XeGPU/xegpu-blocking.mlir   | 110 ++++++++++++++++++
 2 files changed, 116 insertions(+)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
index 7cd998eed2e08..a3826c56e1f62 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
@@ -169,6 +169,12 @@ XeGPUBlockingPass::getTileShape(Operation *op) const {
   if (OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() == 1)
     return getTileShape(op->getOpResult(0));
 
+  if (isa<vector::MultiDimReductionOp>(op))
+    return getTileShape(op->getOpOperand(0));
+
+  if (isa<vector::TransposeOp, vector::BroadcastOp>(op))
+    return getTileShape(op->getOpResult(0));
+
   return std::nullopt;
 }
 
diff --git a/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir b/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
index f9114988686c8..8e3673d04eacb 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
@@ -246,3 +246,113 @@ gpu.module @test_kernel {
     gpu.return
   }
 }
+
+// -----
+#l = #xegpu.layout<inst_data = [16, 16]>
+#r = #xegpu.layout<inst_data = [16]>
+
+gpu.module @kernel  attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
+  gpu.func @reduce_dim_0(%a: memref<16x512xf32>, %b: memref<512xf32>)  kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
+    %acc = arith.constant dense<0.0> : vector<64xf32>
+    %c64 = arith.constant 64 : index
+    %block_id_x = gpu.block_id x
+    %m = arith.muli %block_id_x, %c64 : index
+    %0 = xegpu.create_nd_tdesc %a[0, %m] : memref<16x512xf32> -> !xegpu.tensor_desc<16x64xf32, #l>
+    %1 = xegpu.load_nd %0: !xegpu.tensor_desc<16x64xf32, #l> -> vector<16x64xf32>
+    // CHECK: vector.multi_reduction <add>, {{.*}}, [[ACC:%[0-9A-Za-z]+]] [0] : vector<16x16xf32> to vector<16xf32>
+    // CHECK-COUNT-3: vector.multi_reduction <add>, {{.*}}, [[ACC]] [0] : vector<16x16xf32> to vector<16xf32>
+    %2 = vector.multi_reduction <add>, %1, %acc {layout_result_0 = #r} [0]: vector<16x64xf32> to vector<64xf32>
+    %3 = xegpu.create_nd_tdesc %b[%m] : memref<512xf32> -> !xegpu.tensor_desc<64xf32, #r>
+    xegpu.store_nd %2, %3: vector<64xf32>, !xegpu.tensor_desc<64xf32, #r>
+    gpu.return
+  }
+}
+
+// -----
+#l = #xegpu.layout<inst_data = [16, 16]>
+#r = #xegpu.layout<inst_data = [16]>
+
+gpu.module @kernel  attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
+  gpu.func @reduce_dim_1(%a: memref<512x32xf32>, %b: memref<512xf32>)  kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
+    %c1 = arith.constant 1 : index
+    %c32 = arith.constant 32 : index
+    %acc = arith.constant dense<0.0> : vector<32xf32>
+
+    %block_id_x = gpu.block_id x
+    %block_id_y = gpu.block_id y
+
+    %m = arith.muli %block_id_x, %c32 : index
+    %n = arith.muli %block_id_y, %c32 : index
+    %0 = xegpu.create_nd_tdesc %a[%m, %n] : memref<512x32xf32> -> !xegpu.tensor_desc<32x128xf32, #l>
+    %1 = xegpu.load_nd %0: !xegpu.tensor_desc<32x128xf32, #l> -> vector<32x128xf32>
+
+    // CHECK: vector.multi_reduction <add>, {{.*}}, [[INIT:%[0-9A-Za-z]+]] [1] : vector<16x16xf32> to vector<16xf32>
+    // CHECK-COUNT-1: vector.multi_reduction <add>, {{.*}}, [[INIT]] [1] : vector<16x16xf32> to vector<16xf32>
+
+    %2 = vector.multi_reduction <add>, %1, %acc {layout_result_0 = #r} [1]: vector<32x128xf32> to vector<32xf32>
+    %3 = xegpu.create_nd_tdesc %b[%n] : memref<512xf32> -> !xegpu.tensor_desc<32xf32, #r>
+    xegpu.store_nd %2, %3: vector<32xf32>, !xegpu.tensor_desc<32xf32, #r>
+    gpu.return
+  }
+}
+
+// -----
+#r = #xegpu.layout<inst_data = [16]>
+#l = #xegpu.layout<inst_data = [16, 16]>
+
+gpu.module @kernel  attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
+  gpu.func @broadcast_dim_0(%a: memref<512xf32>, %b: memref<16x512xf32>)  kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
+
+    %c64 = arith.constant 64 : index
+    %block_id_x = gpu.block_id x
+    %m = arith.muli %block_id_x, %c64 : index
+    %0 = xegpu.create_nd_tdesc %a[%m] : memref<512xf32> -> !xegpu.tensor_desc<64xf32, #r>
+    %1 = xegpu.load_nd %0: !xegpu.tensor_desc<64xf32, #r> -> vector<64xf32>
+    // CHECK-COUNT-4: vector.broadcast {{.*}} : vector<16xf32> to vector<16x16xf32>
+    %2 = vector.broadcast  %1 {layout_result_0 = #l} : vector<64xf32> to vector<16x64xf32>
+    %3 = xegpu.create_nd_tdesc %b[0, %m] : memref<16x512xf32> -> !xegpu.tensor_desc<16x64xf32, #l>
+    xegpu.store_nd %2, %3: vector<16x64xf32>, !xegpu.tensor_desc<16x64xf32, #l>
+    gpu.return
+  }
+}
+
+// -----
+#r = #xegpu.layout<inst_data = [16]>
+#l = #xegpu.layout<inst_data = [16, 16]>
+
+gpu.module @kernel  attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
+  gpu.func @broadcast_dim_1(%a: memref<512xf32>, %b: memref<16x512xf32>)  kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
+
+    %c32 = arith.constant 32 : index
+    %block_id_x = gpu.block_id x
+    %m = arith.muli %block_id_x, %c32 : index
+    %0 = xegpu.create_nd_tdesc %a[%m] : memref<512xf32> -> !xegpu.tensor_desc<32xf32, #r>
+    %1 = xegpu.load_nd %0: !xegpu.tensor_desc<32xf32, #r> -> vector<32xf32>
+    %11 = vector.shape_cast %1 :  vector<32xf32> to vector<32x1xf32>
+    // CHECK-COUNT-8: vector.broadcast {{.*}}: vector<16x1xf32> to vector<16x16xf32>
+    %2 = vector.broadcast  %11 {layout_result_0 = #l} : vector<32x1xf32> to vector<32x64xf32>
+    %3 = xegpu.create_nd_tdesc %b[0, %m] : memref<16x512xf32> -> !xegpu.tensor_desc<32x64xf32, #l>
+    xegpu.store_nd %2, %3: vector<32x64xf32>, !xegpu.tensor_desc<32x64xf32, #l>
+    gpu.return
+  }
+}
+
+// -----
+#l = #xegpu.layout<inst_data = [16, 8]>
+#t = #xegpu.layout<inst_data = [8, 16]>
+
+gpu.module @kernel  attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
+  gpu.func @transpose(%a: memref<512x8xf32>, %b: memref<8x512xf32>)  kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
+
+    %c32 = arith.constant 32 : index
+    %block_id_x = gpu.block_id x
+    %m = arith.muli %block_id_x, %c32 : index
+    %0 = xegpu.create_nd_tdesc %a[%m, 0] : memref<512x8xf32> -> !xegpu.tensor_desc<32x8xf32, #l>
+    %1 = xegpu.load_nd %0: !xegpu.tensor_desc<32x8xf32, #l> -> vector<32x8xf32>
+    // CHECK-COUNT-2: vector.transpose {{.*}}  [1, 0] : vector<16x8xf32> to vector<8x16xf32>
+    %2 = vector.transpose  %1, [1, 0] {layout_result_0 = #t} : vector<32x8xf32> to vector<8x32xf32>
+    %3 = xegpu.create_nd_tdesc %b[0, %m] : memref<8x512xf32> -> !xegpu.tensor_desc<8x32xf32, #t>
+    xegpu.store_nd %2, %3: vector<8x32xf32>, !xegpu.tensor_desc<8x32xf32, #t>
+    gpu.return
+  }
+}
\ No newline at end of file

>From 30b099ef6d1f8d036290590c73786bd45ea51b57 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Mon, 9 Jun 2025 23:14:58 +0000
Subject: [PATCH 02/16] add create_desc unrolling and test

---
 .../Dialect/XeGPU/Transforms/XeGPUUnroll.cpp  | 41 ++++++++++++++++++-
 .../lib/Dialect/XeGPU/TestXeGPUTransforms.cpp | 24 +++++++++++
 2 files changed, 64 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index 885477fe4cbd5..672b0fb731f31 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -396,11 +396,50 @@ struct UnrollDpasOp : public UnrollPattern<xegpu::DpasOp> {
   }
 };
 
+struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
+  using UnrollPattern<xegpu::CreateDescOp>::UnrollPattern;
+  LogicalResult matchAndRewrite(xegpu::CreateDescOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    xegpu::TensorDescType tdescTy = op.getType();
+
+    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
+    if (!targetShape)
+      return failure();
+
+    auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];
+
+    
+    TypedValue<::mlir::VectorType> indiceVec = op.getOffsets();
+
+    VectorType indiceVecTy = indiceVec.getType();
+
+    SmallVector<Type> convertedIndiceTypes =
+        getUnrolledTypes(indiceVecTy, *targetShape);
+
+    SmallVector<Value> convertedIndiceVec =
+        pack(indiceVec, convertedIndiceTypes, *targetShape, loc, rewriter);
+
+    SmallVector<Value> newOps;
+
+    for (auto indice : convertedIndiceVec) {
+      auto newOp = rewriter.create<xegpu::CreateDescOp>(loc, newTdescTy, op.getSource(), indice);
+      newOps.push_back(newOp);
+    }
+
+    Value castOp = unpack(newOps, tdescTy, *targetShape, loc, rewriter);
+    rewriter.replaceOp(op, castOp);
+
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::xegpu::populateXeGPUUnrollPatterns(
     RewritePatternSet &patterns, const xegpu::UnrollOptions &options) {
   patterns.add<UnrollCreateNdOp, UnrollUpdateNdOffsetOp, UnrollPrefetchNdOp,
-               UnrollLoadNdOp, UnrollStoreNdOp, UnrollDpasOp>(
+               UnrollLoadNdOp, UnrollStoreNdOp, UnrollDpasOp, 
+               UnrollCreateDescOp>(
       patterns.getContext(), options);
 }
diff --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
index 3f3461e92bc08..abdee098ab430 100644
--- a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
+++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
@@ -71,6 +71,30 @@ struct TestXeGPUUnrollingPatterns
             }
           }
 
+          if (isa<xegpu::CreateDescOp, xegpu::UpdateOffsetOp,
+                  xegpu::PrefetchOp, xegpu::LoadGatherOp, xegpu::StoreScatterOp>(op)) {
+            xegpu::TensorDescType tdescTy;
+            if (auto createOp = dyn_cast<xegpu::CreateDescOp>(op)) {
+              tdescTy = createOp.getType();
+            } else if (auto updateOp =
+                           dyn_cast<xegpu::UpdateOffsetOp>(op)) {
+              tdescTy = updateOp.getTensorDescType();
+            } else if (auto prefetchOp = dyn_cast<xegpu::PrefetchOp>(op)) {
+              tdescTy = prefetchOp.getTensorDescType();
+            } else if (auto loadOp = dyn_cast<xegpu::LoadGatherOp>(op)) {
+              tdescTy = loadOp.getTensorDescType();
+            } else if (auto storeOp = dyn_cast<xegpu::StoreScatterOp>(op)) {
+              tdescTy = storeOp.getTensorDescType();
+            }
+
+            if (auto layout = tdescTy.getLayoutAttr()) {
+              auto inst_data = layout.getInstData();
+              if (inst_data && layout.isSgLayout())
+                return SmallVector<int64_t>(inst_data.asArrayRef().begin(),
+                                            inst_data.asArrayRef().end());
+            }
+          }
+
           if (isa<xegpu::DpasOp>(op))
             return SmallVector<int64_t>{8, 16, 16};
 

>From 8a0e1455683d2a04744bb653348e39c353060704 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Tue, 10 Jun 2025 19:49:39 +0000
Subject: [PATCH 03/16] add unrolling support for scatter operations

---
 .../Dialect/XeGPU/Transforms/XeGPUUnroll.cpp  | 159 +++++++++++++++++-
 .../Dialect/XeGPU/xegpu-unroll-patterns.mlir  | 143 ++++++++++++++++
 2 files changed, 296 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index 672b0fb731f31..96b86f6509419 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -409,19 +409,14 @@ struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
 
     auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];
 
-    
     TypedValue<::mlir::VectorType> indiceVec = op.getOffsets();
-
     VectorType indiceVecTy = indiceVec.getType();
-
     SmallVector<Type> convertedIndiceTypes =
         getUnrolledTypes(indiceVecTy, *targetShape);
-
     SmallVector<Value> convertedIndiceVec =
         pack(indiceVec, convertedIndiceTypes, *targetShape, loc, rewriter);
 
     SmallVector<Value> newOps;
-
     for (auto indice : convertedIndiceVec) {
       auto newOp = rewriter.create<xegpu::CreateDescOp>(loc, newTdescTy, op.getSource(), indice);
       newOps.push_back(newOp);
@@ -434,12 +429,164 @@ struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
   }
 };
 
+struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
+  using UnrollPattern<xegpu::LoadGatherOp>::UnrollPattern;
+  LogicalResult matchAndRewrite(xegpu::LoadGatherOp op,
+                                PatternRewriter &rewriter) const override {
+
+    Location loc = op.getLoc();
+    VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
+    xegpu::TensorDescType tdescTy = op.getTensorDescType();
+
+    VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());
+
+    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
+    if (!targetShape)
+      return failure();
+
+    Type elemTy = tdescTy.getElementType();
+    VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);
+
+    SmallVector<Type> convertedTdescTypes =
+        getUnrolledTypes(tdescTy, *targetShape);
+    SmallVector<Value> convertedTdescs = pack(
+        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
+
+    SmallVector<Type> convertedMaskTypes =
+        getUnrolledTypes(maskTy, *targetShape);
+    SmallVector<Value> convertedMasks = pack(
+        op.getMask(), convertedMaskTypes, *targetShape, loc, rewriter);
+
+    SmallVector<Value> newOps;
+    for (auto [t, m] : llvm::zip(convertedTdescs, convertedMasks)) {
+      auto newOp =
+          rewriter.create<xegpu::LoadGatherOp>(loc, newValueTy, t, m,  
+            op.getTransposeAttr(), op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr());
+      newOps.push_back(newOp);
+    }
+
+    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
+
+    rewriter.replaceOp(op, castOp);
+    return success();
+  }
+};
+
+struct UnrollPrefetchOp : public UnrollPattern<xegpu::PrefetchOp> {
+  using UnrollPattern<xegpu::PrefetchOp>::UnrollPattern;
+  LogicalResult matchAndRewrite(xegpu::PrefetchOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    xegpu::TensorDescType tdescTy = op.getTensorDescType();
+
+    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
+    if (!targetShape)
+      return failure();
+
+    SmallVector<Type> convertedTdescTypes =
+        getUnrolledTypes(tdescTy, *targetShape);
+    SmallVector<Value> convertedTdesc = pack(
+        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
+
+    for (auto t : convertedTdesc)
+      rewriter.create<xegpu::PrefetchOp>(loc, TypeRange(), t, op->getAttrs());
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+struct UnrollStoreScatterOp : public UnrollPattern<xegpu::StoreScatterOp> {
+  using UnrollPattern<xegpu::StoreScatterOp>::UnrollPattern;
+  LogicalResult matchAndRewrite(xegpu::StoreScatterOp op,
+                                PatternRewriter &rewriter) const override {
+
+    Location loc = op.getLoc();
+    VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
+    xegpu::TensorDescType tdescTy = op.getTensorDescType();
+
+    VectorType maskTy;
+    if (op.getMask())
+      maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());
+
+    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
+    if (!targetShape)
+      return failure();
+
+    SmallVector<Type> convertedValTypes =
+        getUnrolledTypes(valueTy, *targetShape);
+    SmallVector<Type> convertedTdescTypes =
+        getUnrolledTypes(tdescTy, *targetShape);
+
+    SmallVector<Value> convertedValues =
+        pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
+    SmallVector<Value> convertedTdescs =
+        pack(op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
+
+    SmallVector<Value> convertedMasks;
+    if (op.getMask()) {
+      SmallVector<Type> convertedMaskTypes =
+          getUnrolledTypes(maskTy, *targetShape);
+      convertedMasks =
+          pack(op.getMask(), convertedMaskTypes, *targetShape, loc, rewriter);
+    }
+
+    for (size_t i = 0; i < convertedValues.size(); ++i) {
+      Value v = convertedValues[i];
+      Value t = convertedTdescs[i];
+      Value m = op.getMask() ? convertedMasks[i] : nullptr;
+      rewriter.create<xegpu::StoreScatterOp>(
+          loc, v, t, m, op.getTransposeAttr(), op.getL1HintAttr(),
+          op.getL2HintAttr(), op.getL3HintAttr());
+    }
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+struct UnrollUpdateOffsetOp : public UnrollPattern<xegpu::UpdateOffsetOp> {
+  using UnrollPattern<xegpu::UpdateOffsetOp>::UnrollPattern;
+  LogicalResult matchAndRewrite(xegpu::UpdateOffsetOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    xegpu::TensorDescType tdescTy = op.getTensorDescType();
+
+    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
+    if (!targetShape)
+      return failure();
+
+    SmallVector<Type> convertedTdescTypes =
+        getUnrolledTypes(tdescTy, *targetShape);
+    SmallVector<Value> convertedTdesc = pack(
+        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
+
+    TypedValue<::mlir::VectorType> offsetVec = op.getOffsets();
+    VectorType offsetVecTy = offsetVec.getType();
+    SmallVector<Type> convertedOffsetTypes =
+        getUnrolledTypes(offsetVecTy, *targetShape);
+    SmallVector<Value> convertedOffsetVec =
+        pack(offsetVec, convertedOffsetTypes, *targetShape, loc, rewriter);        
+
+    SmallVector<Value> newOps;
+    for (auto [t, o] : llvm::zip(convertedTdesc, convertedOffsetVec)) {
+      auto newOp = rewriter.create<xegpu::UpdateOffsetOp>(
+          loc, t.getType(), t, o);
+      newOps.push_back(newOp);
+    }
+    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
+    rewriter.replaceOp(op, castOp);
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::xegpu::populateXeGPUUnrollPatterns(
     RewritePatternSet &patterns, const xegpu::UnrollOptions &options) {
   patterns.add<UnrollCreateNdOp, UnrollUpdateNdOffsetOp, UnrollPrefetchNdOp,
                UnrollLoadNdOp, UnrollStoreNdOp, UnrollDpasOp, 
-               UnrollCreateDescOp>(
+               UnrollCreateDescOp, UnrollLoadGatherOp, 
+               UnrollStoreScatterOp, UnrollPrefetchOp, UnrollUpdateOffsetOp>(
       patterns.getContext(), options);
 }
diff --git a/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir b/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
index b911bb3bbdc1c..47c54bfcb89d0 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
@@ -158,4 +158,147 @@ gpu.module @test {
     %c = xegpu.dpas %a, %b : vector<32x32xf16>, vector<32x32xf16> -> vector<32x32xf32>
     gpu.return %c : vector<32x32xf32>
   }
+
+//-----
+
+  // CHECK-LABEL: test_create_tdesc_vec
+  // CHECK-SAME: [[arg0:%.+]]: ui64
+  // CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
+  gpu.func @test_create_tdesc_vec(%src: ui64) -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>> {
+    %cst = arith.constant dense<[
+    0,   8,  16,  24,  32,  40,  48,  56,
+    64,  72,  80,  88,  96, 104, 112, 120,
+    128, 136, 144, 152, 160, 168, 176, 184,
+    192, 200, 208, 216, 224, 232, 240, 248 
+    ]> : vector<32xindex>
+    %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32xf32,  #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
+    gpu.return %tdesc : !xegpu.tensor_desc<32xf32,  #xegpu.scatter_tdesc_attr<>,  #xegpu.layout<inst_data = [16]>>
+  }
+
+//-----
+
+  // CHECK-LABEL: test_create_tdesc_step
+  // CHECK-SAME: [[arg0:%.+]]: ui64
+  // CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
+  gpu.func @test_create_tdesc_step(%src: ui64) -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>> {
+    %step = arith.constant dense<8> : vector<32xindex>
+    %seq = vector.step  : vector<32xindex>
+    %cst = arith.muli %seq, %step : vector<32xindex>
+    %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32xf32,  #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
+    gpu.return %tdesc : !xegpu.tensor_desc<32xf32,  #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
+  }
+
+//-----
+
+    // CHECK-LABEL: test_load
+    // CHECK-SAME: [[arg0:%.+]]: ui64
+    // CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
+    // CHECK-COUNT-2: xegpu.load  {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> -> vector<16xf32>
+    gpu.func @test_load(%src: ui64) -> vector<32xf32> {
+      %cst = arith.constant dense<[
+      0,   8,  16,  24,  32,  40,  48,  56,
+      64,  72,  80,  88,  96, 104, 112, 120,
+      128, 136, 144, 152, 160, 168, 176, 184,
+      192, 200, 208, 216, 224, 232, 240, 248 
+      ]> : vector<32xindex>
+      
+      %c17 = arith.constant 17: index
+      %mask = vector.create_mask %c17: vector<32xi1>
+
+      %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32xf32,  #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
+      %ld = xegpu.load %tdesc, %mask: !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>, vector<32xi1> -> vector<32xf32>
+      
+      gpu.return %ld : vector<32xf32> 
+    }
+
+//-----
+
+  // CHECK-LABEL: test_prefetch
+  // CHECK-SAME: [[arg0:%.+]]: ui64
+  // CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
+  // CHECK-COUNT-2: xegpu.prefetch {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
+  gpu.func @test_prefetch(%src: ui64)  {
+
+      %cst = arith.constant dense<[
+      0,   8,  16,  24,  32,  40,  48,  56,
+      64,  72,  80,  88,  96, 104, 112, 120,
+      128, 136, 144, 152, 160, 168, 176, 184,
+      192, 200, 208, 216, 224, 232, 240, 248 
+      ]> : vector<32xindex>
+
+      %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32xf32,  #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
+
+      xegpu.prefetch %tdesc: !xegpu.tensor_desc<32xf32,  #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
+      gpu.return
+  }
+
+//-----
+
+  // CHECK-LABEL: test_store
+  // CHECK-SAME: [[arg0:%.+]]: ui64
+  // CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
+  // CHECK-COUNT-2: xegpu.store  {{.*}} : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1>
+  gpu.func @test_store(%src: ui64) {
+    %cst = arith.constant dense<[
+    0,   8,  16,  24,  32,  40,  48,  56,
+    64,  72,  80,  88,  96, 104, 112, 120,
+    128, 136, 144, 152, 160, 168, 176, 184,
+    192, 200, 208, 216, 224, 232, 240, 248 
+    ]> : vector<32xindex>
+    
+    %c17 = arith.constant 17: index
+    %mask = vector.create_mask %c17: vector<32xi1>
+
+    %st_vec = arith.constant dense<1023.>: vector<32xf32>
+    %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32xf32,  #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
+    xegpu.store %st_vec, %tdesc, %mask: vector<32xf32>, !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>, vector<32xi1>
+    
+    gpu.return
+  }
+
+//-----
+
+  // CHECK-LABEL: test_prefetch_load_store_update
+  // CHECK-SAME: [[arg0:%.+]]: ui64
+  // CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
+  // CHECK-COUNT-2: xegpu.prefetch {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
+   // CHECK-COUNT-2: xegpu.update_offset {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xindex>
+   // CHECK-COUNT-2: xegpu.load  {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> -> vector<16xf32>
+  // CHECK-COUNT-2: xegpu.store  {{.*}} : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1>
+
+  gpu.func @test_prefetch_load_store_update(%src: ui64)  {
+
+      %cst = arith.constant dense<[
+      0,   8,  16,  24,  32,  40,  48,  56,
+      64,  72,  80,  88,  96, 104, 112, 120,
+      128, 136, 144, 152, 160, 168, 176, 184,
+      192, 200, 208, 216, 224, 232, 240, 248 
+      ]> : vector<32xindex>
+
+      %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32xf32,  #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
+
+      xegpu.prefetch %tdesc: !xegpu.tensor_desc<32xf32,  #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
+   
+      %delta = arith.constant dense<[
+      32,   32,  32,  32,  32,  32,  32,  32,
+      32,   32,  32,  32,  32,  32,  32,  64,
+      128, 128, 128, 128, 128, 128, 128, 128,
+      128, 128, 128, 128, 128, 128, 128, 256 
+      ]> : vector<32xindex>
+      %new_tdesc = xegpu.update_offset %tdesc, %delta
+           : !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>, vector<32xindex>     
+ 
+      %c17 = arith.constant 17: index
+      %mask = vector.create_mask %c17: vector<32xi1>
+
+      %ld_vec = xegpu.load %new_tdesc, %mask: !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>, vector<32xi1> -> vector<32xf32>
+
+      %st_vec = arith.addf %ld_vec, %ld_vec : vector<32xf32>
+      xegpu.store %st_vec, %tdesc, %mask: 
+                 vector<32xf32>, 
+                 !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>, 
+                 vector<32xi1>
+  
+      gpu.return
+    }
 }

>From c91156c7fc789548676f367919ddd577ba06c5ed Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Tue, 10 Jun 2025 19:51:32 +0000
Subject: [PATCH 04/16] clang format

---
 .../Dialect/XeGPU/Transforms/XeGPUUnroll.cpp  | 31 ++++++++++---------
 .../lib/Dialect/XeGPU/TestXeGPUTransforms.cpp |  7 ++---
 2 files changed, 19 insertions(+), 19 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index 96b86f6509419..900ade8c171d5 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -418,7 +418,8 @@ struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
 
     SmallVector<Value> newOps;
     for (auto indice : convertedIndiceVec) {
-      auto newOp = rewriter.create<xegpu::CreateDescOp>(loc, newTdescTy, op.getSource(), indice);
+      auto newOp = rewriter.create<xegpu::CreateDescOp>(loc, newTdescTy,
+                                                        op.getSource(), indice);
       newOps.push_back(newOp);
     }
 
@@ -454,14 +455,14 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
 
     SmallVector<Type> convertedMaskTypes =
         getUnrolledTypes(maskTy, *targetShape);
-    SmallVector<Value> convertedMasks = pack(
-        op.getMask(), convertedMaskTypes, *targetShape, loc, rewriter);
+    SmallVector<Value> convertedMasks =
+        pack(op.getMask(), convertedMaskTypes, *targetShape, loc, rewriter);
 
     SmallVector<Value> newOps;
     for (auto [t, m] : llvm::zip(convertedTdescs, convertedMasks)) {
-      auto newOp =
-          rewriter.create<xegpu::LoadGatherOp>(loc, newValueTy, t, m,  
-            op.getTransposeAttr(), op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr());
+      auto newOp = rewriter.create<xegpu::LoadGatherOp>(
+          loc, newValueTy, t, m, op.getTransposeAttr(), op.getL1HintAttr(),
+          op.getL2HintAttr(), op.getL3HintAttr());
       newOps.push_back(newOp);
     }
 
@@ -520,8 +521,8 @@ struct UnrollStoreScatterOp : public UnrollPattern<xegpu::StoreScatterOp> {
 
     SmallVector<Value> convertedValues =
         pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
-    SmallVector<Value> convertedTdescs =
-        pack(op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
+    SmallVector<Value> convertedTdescs = pack(
+        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
 
     SmallVector<Value> convertedMasks;
     if (op.getMask()) {
@@ -566,12 +567,12 @@ struct UnrollUpdateOffsetOp : public UnrollPattern<xegpu::UpdateOffsetOp> {
     SmallVector<Type> convertedOffsetTypes =
         getUnrolledTypes(offsetVecTy, *targetShape);
     SmallVector<Value> convertedOffsetVec =
-        pack(offsetVec, convertedOffsetTypes, *targetShape, loc, rewriter);        
+        pack(offsetVec, convertedOffsetTypes, *targetShape, loc, rewriter);
 
     SmallVector<Value> newOps;
     for (auto [t, o] : llvm::zip(convertedTdesc, convertedOffsetVec)) {
-      auto newOp = rewriter.create<xegpu::UpdateOffsetOp>(
-          loc, t.getType(), t, o);
+      auto newOp =
+          rewriter.create<xegpu::UpdateOffsetOp>(loc, t.getType(), t, o);
       newOps.push_back(newOp);
     }
     Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
@@ -585,8 +586,8 @@ struct UnrollUpdateOffsetOp : public UnrollPattern<xegpu::UpdateOffsetOp> {
 void mlir::xegpu::populateXeGPUUnrollPatterns(
     RewritePatternSet &patterns, const xegpu::UnrollOptions &options) {
   patterns.add<UnrollCreateNdOp, UnrollUpdateNdOffsetOp, UnrollPrefetchNdOp,
-               UnrollLoadNdOp, UnrollStoreNdOp, UnrollDpasOp, 
-               UnrollCreateDescOp, UnrollLoadGatherOp, 
-               UnrollStoreScatterOp, UnrollPrefetchOp, UnrollUpdateOffsetOp>(
-      patterns.getContext(), options);
+               UnrollLoadNdOp, UnrollStoreNdOp, UnrollDpasOp,
+               UnrollCreateDescOp, UnrollLoadGatherOp, UnrollStoreScatterOp,
+               UnrollPrefetchOp, UnrollUpdateOffsetOp>(patterns.getContext(),
+                                                       options);
 }
diff --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
index abdee098ab430..57aaecbd7962f 100644
--- a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
+++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
@@ -71,13 +71,12 @@ struct TestXeGPUUnrollingPatterns
             }
           }
 
-          if (isa<xegpu::CreateDescOp, xegpu::UpdateOffsetOp,
-                  xegpu::PrefetchOp, xegpu::LoadGatherOp, xegpu::StoreScatterOp>(op)) {
+          if (isa<xegpu::CreateDescOp, xegpu::UpdateOffsetOp, xegpu::PrefetchOp,
+                  xegpu::LoadGatherOp, xegpu::StoreScatterOp>(op)) {
             xegpu::TensorDescType tdescTy;
             if (auto createOp = dyn_cast<xegpu::CreateDescOp>(op)) {
               tdescTy = createOp.getType();
-            } else if (auto updateOp =
-                           dyn_cast<xegpu::UpdateOffsetOp>(op)) {
+            } else if (auto updateOp = dyn_cast<xegpu::UpdateOffsetOp>(op)) {
               tdescTy = updateOp.getTensorDescType();
             } else if (auto prefetchOp = dyn_cast<xegpu::PrefetchOp>(op)) {
               tdescTy = prefetchOp.getTensorDescType();

>From 30cb8d816352b0d29578f408777a151c94b5f3be Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Thu, 12 Jun 2025 20:59:47 -0700
Subject: [PATCH 05/16] Update
 mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir

Co-authored-by: Adam Siemieniuk <adam.siemieniuk at intel.com>
---
 mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir b/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
index 47c54bfcb89d0..cca5339b086e4 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
@@ -249,7 +249,7 @@ gpu.module @test {
     %c17 = arith.constant 17: index
     %mask = vector.create_mask %c17: vector<32xi1>
 
-    %st_vec = arith.constant dense<1023.>: vector<32xf32>
+    %st_vec = arith.constant dense<1023.0>: vector<32xf32>
     %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32xf32,  #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
     xegpu.store %st_vec, %tdesc, %mask: vector<32xf32>, !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>, vector<32xi1>
     

>From f493e52fefa594aa5918471bf0aa57b2cc0331a4 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Thu, 12 Jun 2025 21:06:13 -0700
Subject: [PATCH 06/16] Update xegpu-unroll-patterns.mlir

correct indentation
---
 .../Dialect/XeGPU/xegpu-unroll-patterns.mlir  | 100 +++++++++---------
 1 file changed, 49 insertions(+), 51 deletions(-)

diff --git a/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir b/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
index cca5339b086e4..52ec3b856da49 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
@@ -190,26 +190,25 @@ gpu.module @test {
 
 //-----
 
-    // CHECK-LABEL: test_load
-    // CHECK-SAME: [[arg0:%.+]]: ui64
-    // CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
-    // CHECK-COUNT-2: xegpu.load  {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> -> vector<16xf32>
-    gpu.func @test_load(%src: ui64) -> vector<32xf32> {
-      %cst = arith.constant dense<[
-      0,   8,  16,  24,  32,  40,  48,  56,
-      64,  72,  80,  88,  96, 104, 112, 120,
-      128, 136, 144, 152, 160, 168, 176, 184,
-      192, 200, 208, 216, 224, 232, 240, 248 
-      ]> : vector<32xindex>
+  // CHECK-LABEL: test_load
+  // CHECK-SAME: [[arg0:%.+]]: ui64
+  // CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
+  // CHECK-COUNT-2: xegpu.load  {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> -> vector<16xf32>
+  gpu.func @test_load(%src: ui64) -> vector<32xf32> {
+    %cst = arith.constant dense<[
+    0,   8,  16,  24,  32,  40,  48,  56,
+    64,  72,  80,  88,  96, 104, 112, 120,
+    128, 136, 144, 152, 160, 168, 176, 184,
+    192, 200, 208, 216, 224, 232, 240, 248 
+    ]> : vector<32xindex>
       
-      %c17 = arith.constant 17: index
-      %mask = vector.create_mask %c17: vector<32xi1>
-
-      %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32xf32,  #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
-      %ld = xegpu.load %tdesc, %mask: !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>, vector<32xi1> -> vector<32xf32>
+    %c17 = arith.constant 17: index
+    %mask = vector.create_mask %c17: vector<32xi1>
+    %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32xf32,  #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
+    %ld = xegpu.load %tdesc, %mask: !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>, vector<32xi1> -> vector<32xf32>
       
-      gpu.return %ld : vector<32xf32> 
-    }
+    gpu.return %ld : vector<32xf32> 
+  }
 
 //-----
 
@@ -219,17 +218,17 @@ gpu.module @test {
   // CHECK-COUNT-2: xegpu.prefetch {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
   gpu.func @test_prefetch(%src: ui64)  {
 
-      %cst = arith.constant dense<[
-      0,   8,  16,  24,  32,  40,  48,  56,
-      64,  72,  80,  88,  96, 104, 112, 120,
-      128, 136, 144, 152, 160, 168, 176, 184,
-      192, 200, 208, 216, 224, 232, 240, 248 
-      ]> : vector<32xindex>
+    %cst = arith.constant dense<[
+    0,   8,  16,  24,  32,  40,  48,  56,
+    64,  72,  80,  88,  96, 104, 112, 120,
+    128, 136, 144, 152, 160, 168, 176, 184,
+    192, 200, 208, 216, 224, 232, 240, 248 
+    ]> : vector<32xindex>
 
-      %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32xf32,  #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
+    %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32xf32,  #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
 
-      xegpu.prefetch %tdesc: !xegpu.tensor_desc<32xf32,  #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
-      gpu.return
+    xegpu.prefetch %tdesc: !xegpu.tensor_desc<32xf32,  #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
+    gpu.return
   }
 
 //-----
@@ -268,37 +267,36 @@ gpu.module @test {
 
   gpu.func @test_prefetch_load_store_update(%src: ui64)  {
 
-      %cst = arith.constant dense<[
-      0,   8,  16,  24,  32,  40,  48,  56,
-      64,  72,  80,  88,  96, 104, 112, 120,
-      128, 136, 144, 152, 160, 168, 176, 184,
-      192, 200, 208, 216, 224, 232, 240, 248 
-      ]> : vector<32xindex>
-
-      %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32xf32,  #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
+    %cst = arith.constant dense<[
+    0,   8,  16,  24,  32,  40,  48,  56,
+    64,  72,  80,  88,  96, 104, 112, 120,
+    128, 136, 144, 152, 160, 168, 176, 184,
+    192, 200, 208, 216, 224, 232, 240, 248 
+    ]> : vector<32xindex>
 
-      xegpu.prefetch %tdesc: !xegpu.tensor_desc<32xf32,  #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
+    %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32xf32,  #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
+    xegpu.prefetch %tdesc: !xegpu.tensor_desc<32xf32,  #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
    
-      %delta = arith.constant dense<[
-      32,   32,  32,  32,  32,  32,  32,  32,
-      32,   32,  32,  32,  32,  32,  32,  64,
-      128, 128, 128, 128, 128, 128, 128, 128,
-      128, 128, 128, 128, 128, 128, 128, 256 
-      ]> : vector<32xindex>
-      %new_tdesc = xegpu.update_offset %tdesc, %delta
-           : !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>, vector<32xindex>     
+    %delta = arith.constant dense<[
+    32,   32,  32,  32,  32,  32,  32,  32,
+    32,   32,  32,  32,  32,  32,  32,  64,
+    128, 128, 128, 128, 128, 128, 128, 128,
+    128, 128, 128, 128, 128, 128, 128, 256 
+    ]> : vector<32xindex>
+    %new_tdesc = xegpu.update_offset %tdesc, %delta
+              : !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>, vector<32xindex>     
  
-      %c17 = arith.constant 17: index
-      %mask = vector.create_mask %c17: vector<32xi1>
+    %c17 = arith.constant 17: index
+    %mask = vector.create_mask %c17: vector<32xi1>
 
-      %ld_vec = xegpu.load %new_tdesc, %mask: !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>, vector<32xi1> -> vector<32xf32>
+    %ld_vec = xegpu.load %new_tdesc, %mask: !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>, vector<32xi1> -> vector<32xf32>
 
-      %st_vec = arith.addf %ld_vec, %ld_vec : vector<32xf32>
-      xegpu.store %st_vec, %tdesc, %mask: 
+    %st_vec = arith.addf %ld_vec, %ld_vec : vector<32xf32>
+    xegpu.store %st_vec, %tdesc, %mask: 
                  vector<32xf32>, 
                  !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>, 
                  vector<32xi1>
   
-      gpu.return
-    }
+    gpu.return
+  }
 }

>From 2606a4bf20ab35ba9b46054c4dc4255d30619e75 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Fri, 13 Jun 2025 21:21:42 +0000
Subject: [PATCH 07/16] reject support for chunk_size

---
 .../Dialect/XeGPU/Transforms/XeGPUUnroll.cpp  | 36 ++++++++++++++-----
 1 file changed, 28 insertions(+), 8 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index 900ade8c171d5..c8e1479f49f11 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -403,6 +403,11 @@ struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
     Location loc = op.getLoc();
     xegpu::TensorDescType tdescTy = op.getType();
 
+    //check if the tensor descriptor type is a 1d vector type
+    if (tdescTy.getRank() > 1) {
+      return failure();
+    }
+
     std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
     if (!targetShape)
       return failure();
@@ -439,6 +444,11 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
     VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
     xegpu::TensorDescType tdescTy = op.getTensorDescType();
 
+    //check if the tensor descriptor type is a 1d vector type
+    if (tdescTy.getRank() > 1) {
+      return failure();
+    }
+
     VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());
 
     std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
@@ -480,6 +490,11 @@ struct UnrollPrefetchOp : public UnrollPattern<xegpu::PrefetchOp> {
     Location loc = op.getLoc();
     xegpu::TensorDescType tdescTy = op.getTensorDescType();
 
+    //check if the tensor descriptor type is a 1d vector type
+    if (tdescTy.getRank() > 1) {
+      return failure();
+    }
+
     std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
     if (!targetShape)
       return failure();
@@ -506,9 +521,12 @@ struct UnrollStoreScatterOp : public UnrollPattern<xegpu::StoreScatterOp> {
     VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
     xegpu::TensorDescType tdescTy = op.getTensorDescType();
 
-    VectorType maskTy;
-    if (op.getMask())
-      maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());
+    //check if the tensor descriptor type is a 1d vector type
+    if (tdescTy.getRank() > 1) {
+      return failure();
+    }
+
+    VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());
 
     std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
     if (!targetShape)
@@ -524,13 +542,10 @@ struct UnrollStoreScatterOp : public UnrollPattern<xegpu::StoreScatterOp> {
     SmallVector<Value> convertedTdescs = pack(
         op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
 
-    SmallVector<Value> convertedMasks;
-    if (op.getMask()) {
-      SmallVector<Type> convertedMaskTypes =
+    SmallVector<Type> convertedMaskTypes =
           getUnrolledTypes(maskTy, *targetShape);
-      convertedMasks =
+    SmallVector<Value> convertedMasks =
           pack(op.getMask(), convertedMaskTypes, *targetShape, loc, rewriter);
-    }
 
     for (size_t i = 0; i < convertedValues.size(); ++i) {
       Value v = convertedValues[i];
@@ -553,6 +568,11 @@ struct UnrollUpdateOffsetOp : public UnrollPattern<xegpu::UpdateOffsetOp> {
     Location loc = op.getLoc();
     xegpu::TensorDescType tdescTy = op.getTensorDescType();
 
+    //check if the tensor descriptor type is a 1d vector type
+    if (tdescTy.getRank() > 1) {
+      return failure();
+    }
+
     std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
     if (!targetShape)
       return failure();

>From a3e064d4f051a46dbee8960d86685aca0700bbcc Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Fri, 13 Jun 2025 21:30:40 +0000
Subject: [PATCH 08/16] small fixes

---
 .../Dialect/XeGPU/Transforms/XeGPUUnroll.cpp  | 29 ++++++++-----------
 1 file changed, 12 insertions(+), 17 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index c8e1479f49f11..8a0048fbca8ed 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -403,10 +403,9 @@ struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
     Location loc = op.getLoc();
     xegpu::TensorDescType tdescTy = op.getType();
 
-    //check if the tensor descriptor type is a 1d vector type
-    if (tdescTy.getRank() > 1) {
+    // check if the tensor descriptor type is a 1d vector type
+    if (tdescTy.getRank() > 1)
       return failure();
-    }
 
     std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
     if (!targetShape)
@@ -444,10 +443,9 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
     VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
     xegpu::TensorDescType tdescTy = op.getTensorDescType();
 
-    //check if the tensor descriptor type is a 1d vector type
-    if (tdescTy.getRank() > 1) {
+    // check if the tensor descriptor type is a 1d vector type
+    if (tdescTy.getRank() > 1)
       return failure();
-    }
 
     VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());
 
@@ -490,10 +488,9 @@ struct UnrollPrefetchOp : public UnrollPattern<xegpu::PrefetchOp> {
     Location loc = op.getLoc();
     xegpu::TensorDescType tdescTy = op.getTensorDescType();
 
-    //check if the tensor descriptor type is a 1d vector type
-    if (tdescTy.getRank() > 1) {
+    // check if the tensor descriptor type is a 1d vector type
+    if (tdescTy.getRank() > 1)
       return failure();
-    }
 
     std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
     if (!targetShape)
@@ -521,10 +518,9 @@ struct UnrollStoreScatterOp : public UnrollPattern<xegpu::StoreScatterOp> {
     VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
     xegpu::TensorDescType tdescTy = op.getTensorDescType();
 
-    //check if the tensor descriptor type is a 1d vector type
-    if (tdescTy.getRank() > 1) {
+    // check if the tensor descriptor type is a 1d vector type
+    if (tdescTy.getRank() > 1)
       return failure();
-    }
 
     VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());
 
@@ -543,9 +539,9 @@ struct UnrollStoreScatterOp : public UnrollPattern<xegpu::StoreScatterOp> {
         op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
 
     SmallVector<Type> convertedMaskTypes =
-          getUnrolledTypes(maskTy, *targetShape);
+        getUnrolledTypes(maskTy, *targetShape);
     SmallVector<Value> convertedMasks =
-          pack(op.getMask(), convertedMaskTypes, *targetShape, loc, rewriter);
+        pack(op.getMask(), convertedMaskTypes, *targetShape, loc, rewriter);
 
     for (size_t i = 0; i < convertedValues.size(); ++i) {
       Value v = convertedValues[i];
@@ -568,10 +564,9 @@ struct UnrollUpdateOffsetOp : public UnrollPattern<xegpu::UpdateOffsetOp> {
     Location loc = op.getLoc();
     xegpu::TensorDescType tdescTy = op.getTensorDescType();
 
-    //check if the tensor descriptor type is a 1d vector type
-    if (tdescTy.getRank() > 1) {
+    // check if the tensor descriptor type is a 1d vector type
+    if (tdescTy.getRank() > 1)
       return failure();
-    }
 
     std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
     if (!targetShape)

>From b3d59370fde678cff0c34130f2e5cc71900e0d66 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Sat, 14 Jun 2025 01:38:51 +0000
Subject: [PATCH 09/16] small fix

---
 mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp | 1 +
 1 file changed, 1 insertion(+)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index 8a0048fbca8ed..9c234c1e866b9 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -415,6 +415,7 @@ struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
 
     TypedValue<::mlir::VectorType> indiceVec = op.getOffsets();
     VectorType indiceVecTy = indiceVec.getType();
+
     SmallVector<Type> convertedIndiceTypes =
         getUnrolledTypes(indiceVecTy, *targetShape);
     SmallVector<Value> convertedIndiceVec =

>From 248981b2f39a4c8b462735d6d894c93e5ff0c9da Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Sun, 15 Jun 2025 04:05:18 +0000
Subject: [PATCH 10/16] add support for create_tdesc with chunk_size

---
 .../Dialect/XeGPU/Transforms/XeGPUUnroll.cpp  | 53 +++++++++++++----
 .../lib/Dialect/XeGPU/TestXeGPUTransforms.cpp | 58 +++++++++++++------
 2 files changed, 81 insertions(+), 30 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index 9c234c1e866b9..8241866c005bb 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -404,7 +404,7 @@ struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
     xegpu::TensorDescType tdescTy = op.getType();
 
     // check if the tensor descriptor type is a 1d vector type
-    if (tdescTy.getRank() > 1)
+    if (tdescTy.getRank() > 2)
       return failure();
 
     std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
@@ -416,16 +416,47 @@ struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
     TypedValue<::mlir::VectorType> indiceVec = op.getOffsets();
     VectorType indiceVecTy = indiceVec.getType();
 
-    SmallVector<Type> convertedIndiceTypes =
-        getUnrolledTypes(indiceVecTy, *targetShape);
-    SmallVector<Value> convertedIndiceVec =
-        pack(indiceVec, convertedIndiceTypes, *targetShape, loc, rewriter);
-
+    SmallVector<Type> convertedIndiceTypes;
+    SmallVector<Value> convertedIndiceVec;
     SmallVector<Value> newOps;
-    for (auto indice : convertedIndiceVec) {
-      auto newOp = rewriter.create<xegpu::CreateDescOp>(loc, newTdescTy,
+
+    if (tdescTy.getRank() == 2) {
+      SmallVector<int64_t> oneDShape(targetShape->begin(), targetShape->end() - 1);
+      convertedIndiceTypes = getUnrolledTypes(indiceVecTy, oneDShape);
+      convertedIndiceVec = pack(indiceVec, convertedIndiceTypes, oneDShape, loc, rewriter);
+      // Assume tdescTy, targetShape, and convertedIndiceVec are defined
+      int64_t outerDim = tdescTy.getShape().back();
+      int64_t innerDim = targetShape->back();
+      int64_t numInnerLoops = outerDim / innerDim;
+
+      // Get element size in bytes
+      int64_t elemSize = tdescTy.getElementType().getIntOrFloatBitWidth() / 8;
+
+      for (auto [indice, indiceType] : llvm::zip(convertedIndiceVec, convertedIndiceTypes)) {
+        for (int64_t i = 0; i < numInnerLoops; ++i) {
+          // Compute the offset
+          Value inc = rewriter.create<arith::ConstantIndexOp>(loc, i * innerDim);
+          Value incVec = rewriter.create<vector::SplatOp>(loc, indiceType, inc);
+          Value offsetIndice = rewriter.create<arith::AddIOp>(loc, indice, incVec);
+
+          auto chunkSizeAttr = rewriter.getI64IntegerAttr(innerDim);
+              auto newOp = rewriter.create<xegpu::CreateDescOp>(
+              loc, newTdescTy, op.getSource(), offsetIndice);
+
+          newOps.push_back(newOp);
+        }
+      }
+    } else if (tdescTy.getRank() == 1) {
+      convertedIndiceTypes = getUnrolledTypes(indiceVecTy, *targetShape);
+      convertedIndiceVec = pack(indiceVec, convertedIndiceTypes, *targetShape, loc, rewriter);
+      for (auto indice : convertedIndiceVec) {
+        auto newOp = rewriter.create<xegpu::CreateDescOp>(loc, newTdescTy,
                                                         op.getSource(), indice);
-      newOps.push_back(newOp);
+        newOps.push_back(newOp);
+      }
+    } else {
+      // Unsupported rank for tensor descriptor
+      return failure();
     }
 
     Value castOp = unpack(newOps, tdescTy, *targetShape, loc, rewriter);
@@ -445,9 +476,9 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
     xegpu::TensorDescType tdescTy = op.getTensorDescType();
 
     // check if the tensor descriptor type is a 1d vector type
-    if (tdescTy.getRank() > 1)
+    if (tdescTy.getRank() > 2)
       return failure();
-
+    
     VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());
 
     std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
diff --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
index 57aaecbd7962f..758e774eb8e01 100644
--- a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
+++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
@@ -19,6 +19,12 @@ using namespace mlir::xegpu;
 
 namespace {
 
+
+#define DEBUG_TYPE "test-xegpu-unroll"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
+
+
 struct TestXeGPUUnrollingPatterns
     : public PassWrapper<TestXeGPUUnrollingPatterns,
                          OperationPass<gpu::GPUModuleOp>> {
@@ -48,12 +54,13 @@ struct TestXeGPUUnrollingPatterns
     options.setNativeShapeFn(
         [&](Operation *op) -> std::optional<SmallVector<int64_t>> {
           if (isa<xegpu::CreateNdDescOp, xegpu::UpdateNdOffsetOp,
-                  xegpu::PrefetchNdOp, xegpu::LoadNdOp, xegpu::StoreNdOp>(op)) {
+                  xegpu::PrefetchNdOp, xegpu::LoadNdOp, xegpu::StoreNdOp,
+                  xegpu::CreateDescOp, xegpu::UpdateOffsetOp, xegpu::PrefetchOp,
+                  xegpu::LoadGatherOp, xegpu::StoreScatterOp>(op)) {
             xegpu::TensorDescType tdescTy;
             if (auto createNdOp = dyn_cast<xegpu::CreateNdDescOp>(op)) {
               tdescTy = createNdOp.getType();
-            } else if (auto updateNdOp =
-                           dyn_cast<xegpu::UpdateNdOffsetOp>(op)) {
+            } else if (auto updateNdOp = dyn_cast<xegpu::UpdateNdOffsetOp>(op)) {
               tdescTy = updateNdOp.getTensorDescType();
             } else if (auto prefetchNdOp = dyn_cast<xegpu::PrefetchNdOp>(op)) {
               tdescTy = prefetchNdOp.getTensorDescType();
@@ -61,20 +68,7 @@ struct TestXeGPUUnrollingPatterns
               tdescTy = loadNdOp.getTensorDescType();
             } else if (auto storeNdOp = dyn_cast<xegpu::StoreNdOp>(op)) {
               tdescTy = storeNdOp.getTensorDescType();
-            }
-
-            if (auto layout = tdescTy.getLayoutAttr()) {
-              auto inst_data = layout.getInstData();
-              if (inst_data && layout.isSgLayout())
-                return SmallVector<int64_t>(inst_data.asArrayRef().begin(),
-                                            inst_data.asArrayRef().end());
-            }
-          }
-
-          if (isa<xegpu::CreateDescOp, xegpu::UpdateOffsetOp, xegpu::PrefetchOp,
-                  xegpu::LoadGatherOp, xegpu::StoreScatterOp>(op)) {
-            xegpu::TensorDescType tdescTy;
-            if (auto createOp = dyn_cast<xegpu::CreateDescOp>(op)) {
+            } else if (auto createOp = dyn_cast<xegpu::CreateDescOp>(op)) {
               tdescTy = createOp.getType();
             } else if (auto updateOp = dyn_cast<xegpu::UpdateOffsetOp>(op)) {
               tdescTy = updateOp.getTensorDescType();
@@ -111,14 +105,40 @@ struct TestXeGPUUnrollingPatterns
             Attribute encoding = tdescTy.getEncoding();
             auto layout = llvm::dyn_cast_if_present<xegpu::LayoutAttr>(
                 tdescTy.getLayout());
+            
+            int64_t newChunkSize = 0;
+            auto instData = layout.getInstData();
+            if (!instData.empty())
+              newChunkSize = instData.asArrayRef().back();   
+            
             if (layout) {
               if (layout.getLaneLayout() == nullptr)
                 layout = xegpu::LayoutAttr();
               else
                 layout = layout.dropInstData();
             }
-            newTy = xegpu::TensorDescType::get(ctx, tileShape, elemTy, encoding,
-                                               layout);
+
+            SmallVector<NamedAttribute> attrs;
+            auto scatterAttr = mlir::dyn_cast<xegpu::ScatterTensorDescAttr>(encoding);
+            if (scatterAttr) {
+              int64_t chunkSize = scatterAttr.getChunkSize().getInt();
+              
+              if (chunkSize > 1) {
+
+                auto chunkSizeAttr = mlir::IntegerAttr::get(
+                  mlir::IntegerType::get(ctx, 64), newChunkSize);
+
+                // To create a new attribute with a different chunk_size:
+                auto newEncoding = xegpu::ScatterTensorDescAttr::get(
+                    ctx, scatterAttr.getMemorySpace(), chunkSizeAttr);
+
+                encoding = newEncoding;
+
+              }
+
+            }
+            newTy = xegpu::TensorDescType::get(ctx, tileShape, elemTy, encoding, layout);
+
           } else {
             newTy = type.clone(tileShape, elemTy);
           }

>From 66ab4aa1c4e6b25cb26040a145236217ea0b114f Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Sun, 15 Jun 2025 16:08:36 +0000
Subject: [PATCH 11/16] add unrolling support for load/store/prefetch/update
 with chunk_size

---
 .../Dialect/XeGPU/Transforms/XeGPUUnroll.cpp  | 131 ++++++++++++------
 1 file changed, 90 insertions(+), 41 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index 8241866c005bb..5d0c1095bef4f 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -421,32 +421,28 @@ struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
     SmallVector<Value> newOps;
 
     if (tdescTy.getRank() == 2) {
-      SmallVector<int64_t> oneDShape(targetShape->begin(), targetShape->end() - 1);
-      convertedIndiceTypes = getUnrolledTypes(indiceVecTy, oneDShape);
-      convertedIndiceVec = pack(indiceVec, convertedIndiceTypes, oneDShape, loc, rewriter);
-      // Assume tdescTy, targetShape, and convertedIndiceVec are defined
-      int64_t outerDim = tdescTy.getShape().back();
-      int64_t innerDim = targetShape->back();
-      int64_t numInnerLoops = outerDim / innerDim;
+      SmallVector<int64_t> shape1D(targetShape->begin(), targetShape->end() - 1);
+      convertedIndiceTypes = getUnrolledTypes(indiceVecTy, shape1D);
+      convertedIndiceVec = pack(indiceVec, convertedIndiceTypes, shape1D, loc, rewriter);
 
-      // Get element size in bytes
-      int64_t elemSize = tdescTy.getElementType().getIntOrFloatBitWidth() / 8;
+      int64_t wholeChunk = tdescTy.getShape().back();
+      int64_t blockedChunk = targetShape->back();
+      int64_t numInnerLoops = wholeChunk / blockedChunk;
 
       for (auto [indice, indiceType] : llvm::zip(convertedIndiceVec, convertedIndiceTypes)) {
         for (int64_t i = 0; i < numInnerLoops; ++i) {
           // Compute the offset
-          Value inc = rewriter.create<arith::ConstantIndexOp>(loc, i * innerDim);
+          Value inc = rewriter.create<arith::ConstantIndexOp>(loc, i * blockedChunk);
           Value incVec = rewriter.create<vector::SplatOp>(loc, indiceType, inc);
           Value offsetIndice = rewriter.create<arith::AddIOp>(loc, indice, incVec);
 
-          auto chunkSizeAttr = rewriter.getI64IntegerAttr(innerDim);
-              auto newOp = rewriter.create<xegpu::CreateDescOp>(
+          auto newOp = rewriter.create<xegpu::CreateDescOp>(
               loc, newTdescTy, op.getSource(), offsetIndice);
 
           newOps.push_back(newOp);
         }
       }
-    } else if (tdescTy.getRank() == 1) {
+    } else {
       convertedIndiceTypes = getUnrolledTypes(indiceVecTy, *targetShape);
       convertedIndiceVec = pack(indiceVec, convertedIndiceTypes, *targetShape, loc, rewriter);
       for (auto indice : convertedIndiceVec) {
@@ -454,10 +450,7 @@ struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
                                                         op.getSource(), indice);
         newOps.push_back(newOp);
       }
-    } else {
-      // Unsupported rank for tensor descriptor
-      return failure();
-    }
+    } 
 
     Value castOp = unpack(newOps, tdescTy, *targetShape, loc, rewriter);
     rewriter.replaceOp(op, castOp);
@@ -493,10 +486,29 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
     SmallVector<Value> convertedTdescs = pack(
         op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
 
-    SmallVector<Type> convertedMaskTypes =
-        getUnrolledTypes(maskTy, *targetShape);
-    SmallVector<Value> convertedMasks =
-        pack(op.getMask(), convertedMaskTypes, *targetShape, loc, rewriter);
+    SmallVector<Type> convertedMaskTypes; 
+    SmallVector<Value> convertedMasks; 
+
+    if (tdescTy.getRank() == 2) {
+      convertedMaskTypes = getUnrolledTypes(maskTy, (*targetShape)[0]);
+      SmallVector<Value> convertedMasks1D = pack(op.getMask(), convertedMaskTypes, (*targetShape)[0], loc, rewriter);
+      int64_t wholeChunk = tdescTy.getShape().back();
+      int64_t blockedChunk = targetShape->back();
+      int64_t numInnerLoops = wholeChunk / blockedChunk;
+
+      for (auto mask : convertedMasks1D) {
+        for (int64_t i = 0; i < numInnerLoops; ++i) {
+          convertedMasks.push_back(mask);
+        }
+      }
+      if (targetShape && targetShape->size() > 1) {
+        std::swap((*targetShape)[0], (*targetShape)[1]);
+        newValueTy = valueTy.cloneWith(*targetShape, elemTy);
+      }
+    } else {
+      convertedMaskTypes = getUnrolledTypes(maskTy, *targetShape);
+      convertedMasks = pack(op.getMask(), convertedMaskTypes, *targetShape, loc, rewriter);
+    }
 
     SmallVector<Value> newOps;
     for (auto [t, m] : llvm::zip(convertedTdescs, convertedMasks)) {
@@ -505,9 +517,8 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
           op.getL2HintAttr(), op.getL3HintAttr());
       newOps.push_back(newOp);
     }
-
+    
     Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
-
     rewriter.replaceOp(op, castOp);
     return success();
   }
@@ -521,7 +532,7 @@ struct UnrollPrefetchOp : public UnrollPattern<xegpu::PrefetchOp> {
     xegpu::TensorDescType tdescTy = op.getTensorDescType();
 
     // check if the tensor descriptor type is a 1d vector type
-    if (tdescTy.getRank() > 1)
+    if (tdescTy.getRank() > 2)
       return failure();
 
     std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
@@ -551,7 +562,7 @@ struct UnrollStoreScatterOp : public UnrollPattern<xegpu::StoreScatterOp> {
     xegpu::TensorDescType tdescTy = op.getTensorDescType();
 
     // check if the tensor descriptor type is a 1d vector type
-    if (tdescTy.getRank() > 1)
+    if (tdescTy.getRank() > 2)
       return failure();
 
     VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());
@@ -559,21 +570,40 @@ struct UnrollStoreScatterOp : public UnrollPattern<xegpu::StoreScatterOp> {
     std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
     if (!targetShape)
       return failure();
-
-    SmallVector<Type> convertedValTypes =
-        getUnrolledTypes(valueTy, *targetShape);
+      
     SmallVector<Type> convertedTdescTypes =
         getUnrolledTypes(tdescTy, *targetShape);
-
-    SmallVector<Value> convertedValues =
-        pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
     SmallVector<Value> convertedTdescs = pack(
         op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
 
-    SmallVector<Type> convertedMaskTypes =
-        getUnrolledTypes(maskTy, *targetShape);
-    SmallVector<Value> convertedMasks =
-        pack(op.getMask(), convertedMaskTypes, *targetShape, loc, rewriter);
+
+    SmallVector<Type> convertedMaskTypes; 
+    SmallVector<Value> convertedMasks; 
+
+    if (tdescTy.getRank() == 2) {
+      convertedMaskTypes = getUnrolledTypes(maskTy, (*targetShape)[0]);
+      SmallVector<Value> convertedMasks1D = pack(op.getMask(), convertedMaskTypes, (*targetShape)[0], loc, rewriter);
+      int64_t wholeChunk = tdescTy.getShape().back();
+      int64_t blockedChunk = targetShape->back();
+      int64_t numInnerLoops = wholeChunk / blockedChunk;
+
+      for (auto mask : convertedMasks1D) {
+        for (int64_t i = 0; i < numInnerLoops; ++i) {
+          convertedMasks.push_back(mask);
+        }
+      }
+
+      std::swap((*targetShape)[0], (*targetShape)[1]);
+
+    } else {
+      convertedMaskTypes = getUnrolledTypes(maskTy, *targetShape);
+      convertedMasks = pack(op.getMask(), convertedMaskTypes, *targetShape, loc, rewriter);
+    }
+
+    SmallVector<Type> convertedValTypes =
+        getUnrolledTypes(valueTy, *targetShape);
+    SmallVector<Value> convertedValues =
+        pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);   
 
     for (size_t i = 0; i < convertedValues.size(); ++i) {
       Value v = convertedValues[i];
@@ -597,7 +627,7 @@ struct UnrollUpdateOffsetOp : public UnrollPattern<xegpu::UpdateOffsetOp> {
     xegpu::TensorDescType tdescTy = op.getTensorDescType();
 
     // check if the tensor descriptor type is a 1d vector type
-    if (tdescTy.getRank() > 1)
+    if (tdescTy.getRank() > 2)
       return failure();
 
     std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
@@ -611,17 +641,36 @@ struct UnrollUpdateOffsetOp : public UnrollPattern<xegpu::UpdateOffsetOp> {
 
     TypedValue<::mlir::VectorType> offsetVec = op.getOffsets();
     VectorType offsetVecTy = offsetVec.getType();
-    SmallVector<Type> convertedOffsetTypes =
-        getUnrolledTypes(offsetVecTy, *targetShape);
-    SmallVector<Value> convertedOffsetVec =
-        pack(offsetVec, convertedOffsetTypes, *targetShape, loc, rewriter);
-
+    SmallVector<Type> convertedOffsetTypes;
+    SmallVector<Value> convertedOffsetVec;
     SmallVector<Value> newOps;
+
+    if (tdescTy.getRank() == 2) {
+      SmallVector<int64_t> shape1D(targetShape->begin(), targetShape->end() - 1);
+      convertedOffsetTypes = getUnrolledTypes(offsetVecTy, shape1D);
+      SmallVector<Value> convertedOffsetVec1D = pack(offsetVec, convertedOffsetTypes, shape1D, loc, rewriter);
+
+      int64_t wholeChunk = tdescTy.getShape().back();
+      int64_t blockedChunk = targetShape->back();
+      int64_t numInnerLoops = wholeChunk / blockedChunk;
+
+      for (auto offset : convertedOffsetVec1D) {
+        for (int64_t i = 0; i < numInnerLoops; ++i) {
+          convertedOffsetVec.push_back(offset);
+        }
+      }
+
+    } else {
+      convertedOffsetTypes = getUnrolledTypes(offsetVecTy, *targetShape);
+      convertedOffsetVec = pack(offsetVec, convertedOffsetTypes, *targetShape, loc, rewriter);
+    } 
+
     for (auto [t, o] : llvm::zip(convertedTdesc, convertedOffsetVec)) {
       auto newOp =
           rewriter.create<xegpu::UpdateOffsetOp>(loc, t.getType(), t, o);
       newOps.push_back(newOp);
     }
+
     Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
     rewriter.replaceOp(op, castOp);
     return success();

>From 7cbeebba475cdf33f8ba0d4256eeb248f53a056f Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Mon, 16 Jun 2025 19:30:41 +0000
Subject: [PATCH 12/16] small bug fixes

---
 .../Dialect/XeGPU/Transforms/XeGPUUnroll.cpp  | 158 ++++++++++++------
 .../lib/Dialect/XeGPU/TestXeGPUTransforms.cpp |  16 +-
 2 files changed, 110 insertions(+), 64 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index 5d0c1095bef4f..33ebd6219abd7 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -267,7 +267,7 @@ struct UnrollLoadNdOp : public UnrollPattern<xegpu::LoadNdOp> {
     return success();
   }
 };
-
+/*
 struct UnrollStoreNdOp : public UnrollPattern<xegpu::StoreNdOp> {
   using UnrollPattern<xegpu::StoreNdOp>::UnrollPattern;
   LogicalResult matchAndRewrite(xegpu::StoreNdOp op,
@@ -298,6 +298,49 @@ struct UnrollStoreNdOp : public UnrollPattern<xegpu::StoreNdOp> {
     return success();
   }
 };
+*/
+
+struct UnrollStoreNdOp : public UnrollPattern<xegpu::StoreNdOp> {
+  using UnrollPattern<xegpu::StoreNdOp>::UnrollPattern;
+  LogicalResult matchAndRewrite(xegpu::StoreNdOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    VectorType valueTy = op.getValueType();
+    xegpu::TensorDescType tdescTy = op.getTensorDescType();
+
+    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
+    LDBG("UnrollStoreNdOp: targetShape present? " << (targetShape.has_value() ? "yes" : "no"));
+    if (!targetShape)
+      return failure();
+
+    LDBG("targetShape: ");
+    for (auto v : *targetShape) LDBG("  " << v);
+
+    SmallVector<Type> convertedValTypes =
+        getUnrolledTypes(valueTy, *targetShape);
+    LDBG("convertedValTypes size: " << convertedValTypes.size());
+    SmallVector<Type> convertedTdescTypes =
+        getUnrolledTypes(tdescTy, *targetShape);
+    LDBG("convertedTdescTypes size: " << convertedTdescTypes.size());
+
+    SmallVector<Value> convertedValues =
+        pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
+    LDBG("convertedValues size: " << convertedValues.size());
+    SmallVector<Value> convertedTdescs = pack(
+        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
+    LDBG("convertedTdescs size: " << convertedTdescs.size());
+
+    for (auto [v, t] : llvm::zip(convertedValues, convertedTdescs)) {
+      LDBG("Creating StoreNdOp with value: " << v << ", tdesc: " << t);
+      rewriter.create<xegpu::StoreNdOp>(loc, v, t, op.getL1HintAttr(),
+                                        op.getL2HintAttr(), op.getL3HintAttr());
+    }
+
+    LDBG("Erasing original StoreNdOp: " << op);
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
 
 struct UnrollDpasOp : public UnrollPattern<xegpu::DpasOp> {
   using UnrollPattern<xegpu::DpasOp>::UnrollPattern;
@@ -402,37 +445,40 @@ struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
     xegpu::TensorDescType tdescTy = op.getType();
+    TypedValue<::mlir::VectorType> indiceVec = op.getOffsets();
+    VectorType indiceVecTy = indiceVec.getType();
 
-    // check if the tensor descriptor type is a 1d vector type
-    if (tdescTy.getRank() > 2)
+    if (!tdescTy.isScattered())
       return failure();
-
+    
     std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
     if (!targetShape)
       return failure();
+ 
+    SmallVector<int64_t> targetIndiceShape(*targetShape);
+    int64_t originalChunkSize = tdescTy.getChunkSize();
+    // indiceVec has one fewer dimension than tdescTy when chunkSize > 1.
+    if (originalChunkSize > 1)
+      targetIndiceShape.pop_back();
 
     auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];
-
-    TypedValue<::mlir::VectorType> indiceVec = op.getOffsets();
-    VectorType indiceVecTy = indiceVec.getType();
-
-    SmallVector<Type> convertedIndiceTypes;
-    SmallVector<Value> convertedIndiceVec;
+    SmallVector<Type> convertedIndiceTypes = 
+        getUnrolledTypes(indiceVecTy, targetIndiceShape);
+    SmallVector<Value> convertedIndiceVec = 
+        pack(indiceVec, convertedIndiceTypes, targetIndiceShape, loc, rewriter);
+    
     SmallVector<Value> newOps;
 
-    if (tdescTy.getRank() == 2) {
-      SmallVector<int64_t> shape1D(targetShape->begin(), targetShape->end() - 1);
-      convertedIndiceTypes = getUnrolledTypes(indiceVecTy, shape1D);
-      convertedIndiceVec = pack(indiceVec, convertedIndiceTypes, shape1D, loc, rewriter);
-
-      int64_t wholeChunk = tdescTy.getShape().back();
-      int64_t blockedChunk = targetShape->back();
-      int64_t numInnerLoops = wholeChunk / blockedChunk;
+    // More indices are needed when chunkSize > 1, since a big load from one
+    // address may be broken into multiple smaller loads.
+    if (originalChunkSize > 1) {
+      int64_t blockedChunkSize = targetShape->back();
+      int64_t numNewChunks = originalChunkSize/blockedChunkSize;
 
       for (auto [indice, indiceType] : llvm::zip(convertedIndiceVec, convertedIndiceTypes)) {
-        for (int64_t i = 0; i < numInnerLoops; ++i) {
+        for (int64_t i = 0; i < numNewChunks; ++i) {
           // Compute the offset
-          Value inc = rewriter.create<arith::ConstantIndexOp>(loc, i * blockedChunk);
+          Value inc = rewriter.create<arith::ConstantIndexOp>(loc, i * blockedChunkSize);
           Value incVec = rewriter.create<vector::SplatOp>(loc, indiceType, inc);
           Value offsetIndice = rewriter.create<arith::AddIOp>(loc, indice, incVec);
 
@@ -443,8 +489,6 @@ struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
         }
       }
     } else {
-      convertedIndiceTypes = getUnrolledTypes(indiceVecTy, *targetShape);
-      convertedIndiceVec = pack(indiceVec, convertedIndiceTypes, *targetShape, loc, rewriter);
       for (auto indice : convertedIndiceVec) {
         auto newOp = rewriter.create<xegpu::CreateDescOp>(loc, newTdescTy,
                                                         op.getSource(), indice);
@@ -468,15 +512,17 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
     VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
     xegpu::TensorDescType tdescTy = op.getTensorDescType();
 
-    // check if the tensor descriptor type is a 1d vector type
-    if (tdescTy.getRank() > 2)
+    if (!tdescTy.isScattered())
       return failure();
     
-    VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());
-
     std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
     if (!targetShape)
       return failure();
+ 
+    SmallVector<int64_t> targetMaskShape(*targetShape);
+    int64_t originalChunkSize = tdescTy.getChunkSize();
+
+    VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());
 
     Type elemTy = tdescTy.getElementType();
     VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);
@@ -489,25 +535,26 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
     SmallVector<Type> convertedMaskTypes; 
     SmallVector<Value> convertedMasks; 
 
-    if (tdescTy.getRank() == 2) {
-      convertedMaskTypes = getUnrolledTypes(maskTy, (*targetShape)[0]);
-      SmallVector<Value> convertedMasks1D = pack(op.getMask(), convertedMaskTypes, (*targetShape)[0], loc, rewriter);
-      int64_t wholeChunk = tdescTy.getShape().back();
-      int64_t blockedChunk = targetShape->back();
-      int64_t numInnerLoops = wholeChunk / blockedChunk;
+    if (originalChunkSize > 1) {
+      targetMaskShape.pop_back();
+      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
+      SmallVector<Value> convertedMasks1D = pack(op.getMask(), convertedMaskTypes, targetMaskShape, loc, rewriter);
+      int64_t blockedChunkSize = targetShape->back();
+      int64_t numNewChunks = originalChunkSize/blockedChunkSize;
 
       for (auto mask : convertedMasks1D) {
-        for (int64_t i = 0; i < numInnerLoops; ++i) {
+        for (int64_t i = 0; i < numNewChunks; ++i) {
           convertedMasks.push_back(mask);
         }
       }
+      // This is to handle the transpose effect when chunkSize > 1. 
       if (targetShape && targetShape->size() > 1) {
         std::swap((*targetShape)[0], (*targetShape)[1]);
         newValueTy = valueTy.cloneWith(*targetShape, elemTy);
       }
     } else {
-      convertedMaskTypes = getUnrolledTypes(maskTy, *targetShape);
-      convertedMasks = pack(op.getMask(), convertedMaskTypes, *targetShape, loc, rewriter);
+      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
+      convertedMasks = pack(op.getMask(), convertedMaskTypes, targetMaskShape, loc, rewriter);
     }
 
     SmallVector<Value> newOps;
@@ -561,38 +608,38 @@ struct UnrollStoreScatterOp : public UnrollPattern<xegpu::StoreScatterOp> {
     VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
     xegpu::TensorDescType tdescTy = op.getTensorDescType();
 
-    // check if the tensor descriptor type is a 1d vector type
-    if (tdescTy.getRank() > 2)
+    if (!tdescTy.isScattered())
       return failure();
 
-    VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());
-
     std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
     if (!targetShape)
       return failure();
-      
+
+    SmallVector<int64_t> targetIndiceShape(*targetShape);
+    int64_t originalChunkSize = tdescTy.getChunkSize();
+
+    VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());
+     
     SmallVector<Type> convertedTdescTypes =
         getUnrolledTypes(tdescTy, *targetShape);
     SmallVector<Value> convertedTdescs = pack(
         op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
 
-
     SmallVector<Type> convertedMaskTypes; 
     SmallVector<Value> convertedMasks; 
 
-    if (tdescTy.getRank() == 2) {
+    if (originalChunkSize > 1) {
+      int64_t blockedChunkSize = targetShape->back();
+      int64_t numNewChunks = originalChunkSize/blockedChunkSize;
       convertedMaskTypes = getUnrolledTypes(maskTy, (*targetShape)[0]);
       SmallVector<Value> convertedMasks1D = pack(op.getMask(), convertedMaskTypes, (*targetShape)[0], loc, rewriter);
-      int64_t wholeChunk = tdescTy.getShape().back();
-      int64_t blockedChunk = targetShape->back();
-      int64_t numInnerLoops = wholeChunk / blockedChunk;
 
       for (auto mask : convertedMasks1D) {
-        for (int64_t i = 0; i < numInnerLoops; ++i) {
+        for (int64_t i = 0; i < numNewChunks; ++i) {
           convertedMasks.push_back(mask);
         }
       }
-
+      // This is to handle the transpose effect when chunkSize > 1. 
       std::swap((*targetShape)[0], (*targetShape)[1]);
 
     } else {
@@ -626,8 +673,10 @@ struct UnrollUpdateOffsetOp : public UnrollPattern<xegpu::UpdateOffsetOp> {
     Location loc = op.getLoc();
     xegpu::TensorDescType tdescTy = op.getTensorDescType();
 
-    // check if the tensor descriptor type is a 1d vector type
-    if (tdescTy.getRank() > 2)
+    if (tdescTy.getRank() >2)
+      return failure();
+
+    if (!tdescTy.isScattered())
       return failure();
 
     std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
@@ -644,18 +693,17 @@ struct UnrollUpdateOffsetOp : public UnrollPattern<xegpu::UpdateOffsetOp> {
     SmallVector<Type> convertedOffsetTypes;
     SmallVector<Value> convertedOffsetVec;
     SmallVector<Value> newOps;
-
-    if (tdescTy.getRank() == 2) {
+    int64_t originalChunkSize = tdescTy.getChunkSize();
+    if (originalChunkSize > 1) {
       SmallVector<int64_t> shape1D(targetShape->begin(), targetShape->end() - 1);
       convertedOffsetTypes = getUnrolledTypes(offsetVecTy, shape1D);
       SmallVector<Value> convertedOffsetVec1D = pack(offsetVec, convertedOffsetTypes, shape1D, loc, rewriter);
 
-      int64_t wholeChunk = tdescTy.getShape().back();
-      int64_t blockedChunk = targetShape->back();
-      int64_t numInnerLoops = wholeChunk / blockedChunk;
+      int64_t blockedChunkSize = targetShape->back();
+      int64_t numNewChunks = originalChunkSize/blockedChunkSize;
 
       for (auto offset : convertedOffsetVec1D) {
-        for (int64_t i = 0; i < numInnerLoops; ++i) {
+        for (int64_t i = 0; i < numNewChunks; ++i) {
           convertedOffsetVec.push_back(offset);
         }
       }
diff --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
index 758e774eb8e01..78e2b7601a1bb 100644
--- a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
+++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
@@ -105,12 +105,7 @@ struct TestXeGPUUnrollingPatterns
             Attribute encoding = tdescTy.getEncoding();
             auto layout = llvm::dyn_cast_if_present<xegpu::LayoutAttr>(
                 tdescTy.getLayout());
-            
-            int64_t newChunkSize = 0;
-            auto instData = layout.getInstData();
-            if (!instData.empty())
-              newChunkSize = instData.asArrayRef().back();   
-            
+           
             if (layout) {
               if (layout.getLaneLayout() == nullptr)
                 layout = xegpu::LayoutAttr();
@@ -118,12 +113,15 @@ struct TestXeGPUUnrollingPatterns
                 layout = layout.dropInstData();
             }
 
-            SmallVector<NamedAttribute> attrs;
-            auto scatterAttr = mlir::dyn_cast<xegpu::ScatterTensorDescAttr>(encoding);
-            if (scatterAttr) {
+            if (encoding && mlir::isa<xegpu::ScatterTensorDescAttr>(encoding)) {
+              auto scatterAttr = mlir::dyn_cast<xegpu::ScatterTensorDescAttr>(encoding);
               int64_t chunkSize = scatterAttr.getChunkSize().getInt();
               
               if (chunkSize > 1) {
+                int64_t newChunkSize = chunkSize;
+                auto instData = layout.getInstData();
+                if (!instData.empty())
+                  newChunkSize = instData.asArrayRef().back();   
 
                 auto chunkSizeAttr = mlir::IntegerAttr::get(
                   mlir::IntegerType::get(ctx, 64), newChunkSize);

>From 47fe1438afdb37724b94e54383721dac608a35f5 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Mon, 16 Jun 2025 20:56:01 +0000
Subject: [PATCH 13/16] add tests

---
 .../Dialect/XeGPU/Transforms/XeGPUUnroll.cpp  |  45 +------
 .../Dialect/XeGPU/xegpu-unroll-patterns.mlir  | 110 ++++++++++++++++++
 .../lib/Dialect/XeGPU/TestXeGPUTransforms.cpp |  14 +--
 3 files changed, 118 insertions(+), 51 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index 33ebd6219abd7..18b6c38b7f5e4 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -267,38 +267,6 @@ struct UnrollLoadNdOp : public UnrollPattern<xegpu::LoadNdOp> {
     return success();
   }
 };
-/*
-struct UnrollStoreNdOp : public UnrollPattern<xegpu::StoreNdOp> {
-  using UnrollPattern<xegpu::StoreNdOp>::UnrollPattern;
-  LogicalResult matchAndRewrite(xegpu::StoreNdOp op,
-                                PatternRewriter &rewriter) const override {
-    Location loc = op.getLoc();
-    VectorType valueTy = op.getValueType();
-    xegpu::TensorDescType tdescTy = op.getTensorDescType();
-
-    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
-    if (!targetShape)
-      return failure();
-
-    SmallVector<Type> convertedValTypes =
-        getUnrolledTypes(valueTy, *targetShape);
-    SmallVector<Type> convertedTdescTypes =
-        getUnrolledTypes(tdescTy, *targetShape);
-
-    SmallVector<Value> convertedValues =
-        pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
-    SmallVector<Value> convertedTdescs = pack(
-        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
-
-    for (auto [v, t] : llvm::zip(convertedValues, convertedTdescs))
-      rewriter.create<xegpu::StoreNdOp>(loc, v, t, op.getL1HintAttr(),
-                                        op.getL2HintAttr(), op.getL3HintAttr());
-
-    rewriter.eraseOp(op);
-    return success();
-  }
-};
-*/
 
 struct UnrollStoreNdOp : public UnrollPattern<xegpu::StoreNdOp> {
   using UnrollPattern<xegpu::StoreNdOp>::UnrollPattern;
@@ -309,34 +277,23 @@ struct UnrollStoreNdOp : public UnrollPattern<xegpu::StoreNdOp> {
     xegpu::TensorDescType tdescTy = op.getTensorDescType();
 
     std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
-    LDBG("UnrollStoreNdOp: targetShape present? " << (targetShape.has_value() ? "yes" : "no"));
     if (!targetShape)
       return failure();
 
-    LDBG("targetShape: ");
-    for (auto v : *targetShape) LDBG("  " << v);
-
     SmallVector<Type> convertedValTypes =
         getUnrolledTypes(valueTy, *targetShape);
-    LDBG("convertedValTypes size: " << convertedValTypes.size());
     SmallVector<Type> convertedTdescTypes =
         getUnrolledTypes(tdescTy, *targetShape);
-    LDBG("convertedTdescTypes size: " << convertedTdescTypes.size());
 
     SmallVector<Value> convertedValues =
         pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
-    LDBG("convertedValues size: " << convertedValues.size());
     SmallVector<Value> convertedTdescs = pack(
         op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
-    LDBG("convertedTdescs size: " << convertedTdescs.size());
 
-    for (auto [v, t] : llvm::zip(convertedValues, convertedTdescs)) {
-      LDBG("Creating StoreNdOp with value: " << v << ", tdesc: " << t);
+    for (auto [v, t] : llvm::zip(convertedValues, convertedTdescs))
       rewriter.create<xegpu::StoreNdOp>(loc, v, t, op.getL1HintAttr(),
                                         op.getL2HintAttr(), op.getL3HintAttr());
-    }
 
-    LDBG("Erasing original StoreNdOp: " << op);
     rewriter.eraseOp(op);
     return success();
   }
diff --git a/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir b/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
index 52ec3b856da49..0d3bd4fdb311e 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
@@ -299,4 +299,114 @@ gpu.module @test {
   
     gpu.return
   }
+
+//-----
+  // CHECK-LABEL: test_create_tdesc_step_chunk
+  // CHECK-SAME: [[arg0:%.+]]: ui64
+  // CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x4xf32, #xegpu.scatter_tdesc_attr<chunk_size = 4 : i64>>
+  gpu.func @test_create_tdesc_step_chunk(%src: ui64) -> !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 4]>> {
+    %step = arith.constant dense<8> : vector<32xindex>
+    %seq = vector.step  : vector<32xindex>
+    %cst = arith.muli %seq, %step : vector<32xindex>
+    %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32x4xf32,  #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 4]>>
+    gpu.return %tdesc : !xegpu.tensor_desc<32x4xf32,  #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 4]>>
+  }
+
+//-----
+  // CHECK-LABEL: test_create_tdesc_step_chunk2
+  // CHECK-SAME: [[arg0:%.+]]: ui64
+  // CHECK-COUNT-4: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
+  gpu.func @test_create_tdesc_step_chunk2(%src: ui64) -> !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>> {
+    %step = arith.constant dense<8> : vector<32xindex>
+    %seq = vector.step  : vector<32xindex>
+    %cst = arith.muli %seq, %step : vector<32xindex>
+    %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32x4xf32,  #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>
+    gpu.return %tdesc : !xegpu.tensor_desc<32x4xf32,  #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>
+  }
+
+//-----
+  // CHECK-LABEL: test_load_chunk
+  // CHECK-SAME: [[arg0:%.+]]: ui64
+  // CHECK-COUNT-4: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
+  // CHECK-COUNT-4: xegpu.load  {{.*}} <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>, vector<16xi1> -> vector<2x16xf32>
+
+  gpu.func @test_load_chunk(%src: ui64) -> vector<4x32xf32> {
+    %cst = arith.constant dense<[
+        0,   8,  16,  24,  32,  40,  48,  56,
+        64,  72,  80,  88,  96, 104, 112, 120,
+        128, 136, 144, 152, 160, 168, 176, 184,
+        192, 200, 208, 216, 224, 232, 240, 248 
+    ]> : vector<32xindex>
+    
+    %c17 = arith.constant 17: index
+    %mask = vector.create_mask %c17: vector<32xi1>
+
+    %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32x4xf32,  #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>> 
+    %ld = xegpu.load %tdesc, %mask <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}>: !xegpu.tensor_desc<32x4xf32,  #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>, vector<32xi1> -> vector<4x32xf32>
+    
+    gpu.return %ld : vector<4x32xf32> 
+   }
+
+//-----
+  // CHECK-LABEL: test_store_chunk
+  // CHECK-SAME: [[arg0:%.+]]: ui64
+  // CHECK-COUNT-4: xegpu.create_tdesc [[arg0]], {{.*}} :  ui64, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
+  // CHECK-COUNT-4: xegpu.store  {{.*}} <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : vector<2x16xf32>, !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>, vector<16xi1>
+  gpu.func @test_store_chunk(%src: ui64) {
+    %cst = arith.constant dense<[
+      0,   8,  16,  24,  32,  40,  48,  56,
+      64,  72,  80,  88,  96, 104, 112, 120,
+      128, 136, 144, 152, 160, 168, 176, 184,
+      192, 200, 208, 216, 224, 232, 240, 248 
+    ]> : vector<32xindex>
+    
+    %c17 = arith.constant 17: index
+    %mask = vector.create_mask %c17: vector<32xi1>
+
+    %st_vec = arith.constant dense<1023.>: vector<4x32xf32>
+    %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32x4xf32,  #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>
+    xegpu.store %st_vec, %tdesc, %mask <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}>: vector<4x32xf32>, !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16,2]>>, vector<32xi1>
+    
+    gpu.return
+  }
+
+//-----
+  // CHECK-LABEL: test_prefetch_chunk
+  // CHECK-SAME: [[arg0:%.+]]: ui64
+  // CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
+  // CHECK-COUNT-2: xegpu.prefetch {{.*}} : !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
+  gpu.func @test_prefetch_chunk(%src: ui64)  {
+    %cst = arith.constant dense<[
+      0,   8,  16,  24,  32,  40,  48,  56,
+      64,  72,  80,  88,  96, 104, 112, 120,
+      128, 136, 144, 152, 160, 168, 176, 184,
+      192, 200, 208, 216, 224, 232, 240, 248 
+      ]> : vector<32xindex>
+    %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32x4xf32,  #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>
+    xegpu.prefetch %tdesc: !xegpu.tensor_desc<32x4xf32,  #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>
+    
+    gpu.return
+  }
+
+//-----
+  // CHECK-LABEL: test_update_chunk
+  // CHECK-SAME: [[arg0:%.+]]: ui64
+  // CHECK-COUNT-4: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
+  // CHECK-COUNT-4: xegpu.update_offset {{.*}} : !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>, vector<16xindex>
+  gpu.func @test_update_chunk(%src: ui64) -> !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>> {
+    %cst = arith.constant dense<[
+      0,   8,  16,  24,  32,  40,  48,  56,
+      64,  72,  80,  88,  96, 104, 112, 120,
+      128, 136, 144, 152, 160, 168, 176, 184,
+      192, 200, 208, 216, 224, 232, 240, 248 
+    ]> : vector<32xindex>
+    %delta = arith.constant dense<32>: vector<32xindex>
+    %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32x4xf32,  #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>
+
+    %new_tdesc = xegpu.update_offset %tdesc, %delta
+        : !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>, vector<32xindex>
+
+    gpu.return %new_tdesc : !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>
+  }  
 }
+
diff --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
index 78e2b7601a1bb..8be37609fc806 100644
--- a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
+++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
@@ -106,13 +106,6 @@ struct TestXeGPUUnrollingPatterns
             auto layout = llvm::dyn_cast_if_present<xegpu::LayoutAttr>(
                 tdescTy.getLayout());
            
-            if (layout) {
-              if (layout.getLaneLayout() == nullptr)
-                layout = xegpu::LayoutAttr();
-              else
-                layout = layout.dropInstData();
-            }
-
             if (encoding && mlir::isa<xegpu::ScatterTensorDescAttr>(encoding)) {
               auto scatterAttr = mlir::dyn_cast<xegpu::ScatterTensorDescAttr>(encoding);
               int64_t chunkSize = scatterAttr.getChunkSize().getInt();
@@ -135,6 +128,13 @@ struct TestXeGPUUnrollingPatterns
               }
 
             }
+            if (layout) {
+              if (layout.getLaneLayout() == nullptr)
+                layout = xegpu::LayoutAttr();
+              else
+                layout = layout.dropInstData();
+            }
+
             newTy = xegpu::TensorDescType::get(ctx, tileShape, elemTy, encoding, layout);
 
           } else {

>From a569bc8acd38337a3f52fbb21ade42e9595e6586 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Mon, 16 Jun 2025 21:05:10 +0000
Subject: [PATCH 14/16] clang format fix

---
 .../Dialect/XeGPU/Transforms/XeGPUUnroll.cpp  | 80 +++++++++++--------
 .../lib/Dialect/XeGPU/TestXeGPUTransforms.cpp | 21 +++--
 2 files changed, 55 insertions(+), 46 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index 18b6c38b7f5e4..ed8de1639c35e 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -407,11 +407,11 @@ struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
 
     if (!tdescTy.isScattered())
       return failure();
-    
+
     std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
     if (!targetShape)
       return failure();
- 
+
     SmallVector<int64_t> targetIndiceShape(*targetShape);
     int64_t originalChunkSize = tdescTy.getChunkSize();
     // IndiceVec is 1 dim lower than tdescTy when chunkSize is larger than 1.
@@ -419,25 +419,28 @@ struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
       targetIndiceShape.pop_back();
 
     auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];
-    SmallVector<Type> convertedIndiceTypes = 
+    SmallVector<Type> convertedIndiceTypes =
         getUnrolledTypes(indiceVecTy, targetIndiceShape);
-    SmallVector<Value> convertedIndiceVec = 
+    SmallVector<Value> convertedIndiceVec =
         pack(indiceVec, convertedIndiceTypes, targetIndiceShape, loc, rewriter);
-    
+
     SmallVector<Value> newOps;
 
     // more indices is need when chunkSize > 1. Since a big load from one
     // address could be break into multiple small loads.
     if (originalChunkSize > 1) {
       int64_t blockedChunkSize = targetShape->back();
-      int64_t numNewChunks = originalChunkSize/blockedChunkSize;
+      int64_t numNewChunks = originalChunkSize / blockedChunkSize;
 
-      for (auto [indice, indiceType] : llvm::zip(convertedIndiceVec, convertedIndiceTypes)) {
+      for (auto [indice, indiceType] :
+           llvm::zip(convertedIndiceVec, convertedIndiceTypes)) {
         for (int64_t i = 0; i < numNewChunks; ++i) {
           // Compute the offset
-          Value inc = rewriter.create<arith::ConstantIndexOp>(loc, i * blockedChunkSize);
+          Value inc = rewriter.create<arith::ConstantIndexOp>(
+              loc, i * blockedChunkSize);
           Value incVec = rewriter.create<vector::SplatOp>(loc, indiceType, inc);
-          Value offsetIndice = rewriter.create<arith::AddIOp>(loc, indice, incVec);
+          Value offsetIndice =
+              rewriter.create<arith::AddIOp>(loc, indice, incVec);
 
           auto newOp = rewriter.create<xegpu::CreateDescOp>(
               loc, newTdescTy, op.getSource(), offsetIndice);
@@ -447,11 +450,11 @@ struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
       }
     } else {
       for (auto indice : convertedIndiceVec) {
-        auto newOp = rewriter.create<xegpu::CreateDescOp>(loc, newTdescTy,
-                                                        op.getSource(), indice);
+        auto newOp = rewriter.create<xegpu::CreateDescOp>(
+            loc, newTdescTy, op.getSource(), indice);
         newOps.push_back(newOp);
       }
-    } 
+    }
 
     Value castOp = unpack(newOps, tdescTy, *targetShape, loc, rewriter);
     rewriter.replaceOp(op, castOp);
@@ -471,11 +474,11 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
 
     if (!tdescTy.isScattered())
       return failure();
-    
+
     std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
     if (!targetShape)
       return failure();
- 
+
     SmallVector<int64_t> targetMaskShape(*targetShape);
     int64_t originalChunkSize = tdescTy.getChunkSize();
 
@@ -489,29 +492,31 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
     SmallVector<Value> convertedTdescs = pack(
         op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
 
-    SmallVector<Type> convertedMaskTypes; 
-    SmallVector<Value> convertedMasks; 
+    SmallVector<Type> convertedMaskTypes;
+    SmallVector<Value> convertedMasks;
 
     if (originalChunkSize > 1) {
       targetMaskShape.pop_back();
       convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
-      SmallVector<Value> convertedMasks1D = pack(op.getMask(), convertedMaskTypes, targetMaskShape, loc, rewriter);
+      SmallVector<Value> convertedMasks1D = pack(
+          op.getMask(), convertedMaskTypes, targetMaskShape, loc, rewriter);
       int64_t blockedChunkSize = targetShape->back();
-      int64_t numNewChunks = originalChunkSize/blockedChunkSize;
+      int64_t numNewChunks = originalChunkSize / blockedChunkSize;
 
       for (auto mask : convertedMasks1D) {
         for (int64_t i = 0; i < numNewChunks; ++i) {
           convertedMasks.push_back(mask);
         }
       }
-      // This is to handle the transpose effect when chunkSize > 1. 
+      // This is to handle the transpose effect when chunkSize > 1.
       if (targetShape && targetShape->size() > 1) {
         std::swap((*targetShape)[0], (*targetShape)[1]);
         newValueTy = valueTy.cloneWith(*targetShape, elemTy);
       }
     } else {
       convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
-      convertedMasks = pack(op.getMask(), convertedMaskTypes, targetMaskShape, loc, rewriter);
+      convertedMasks = pack(op.getMask(), convertedMaskTypes, targetMaskShape,
+                            loc, rewriter);
     }
 
     SmallVector<Value> newOps;
@@ -521,7 +526,7 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
           op.getL2HintAttr(), op.getL3HintAttr());
       newOps.push_back(newOp);
     }
-    
+
     Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
     rewriter.replaceOp(op, castOp);
     return success();
@@ -576,38 +581,40 @@ struct UnrollStoreScatterOp : public UnrollPattern<xegpu::StoreScatterOp> {
     int64_t originalChunkSize = tdescTy.getChunkSize();
 
     VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());
-     
+
     SmallVector<Type> convertedTdescTypes =
         getUnrolledTypes(tdescTy, *targetShape);
     SmallVector<Value> convertedTdescs = pack(
         op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
 
-    SmallVector<Type> convertedMaskTypes; 
-    SmallVector<Value> convertedMasks; 
+    SmallVector<Type> convertedMaskTypes;
+    SmallVector<Value> convertedMasks;
 
     if (originalChunkSize > 1) {
       int64_t blockedChunkSize = targetShape->back();
-      int64_t numNewChunks = originalChunkSize/blockedChunkSize;
+      int64_t numNewChunks = originalChunkSize / blockedChunkSize;
       convertedMaskTypes = getUnrolledTypes(maskTy, (*targetShape)[0]);
-      SmallVector<Value> convertedMasks1D = pack(op.getMask(), convertedMaskTypes, (*targetShape)[0], loc, rewriter);
+      SmallVector<Value> convertedMasks1D = pack(
+          op.getMask(), convertedMaskTypes, (*targetShape)[0], loc, rewriter);
 
       for (auto mask : convertedMasks1D) {
         for (int64_t i = 0; i < numNewChunks; ++i) {
           convertedMasks.push_back(mask);
         }
       }
-      // This is to handle the transpose effect when chunkSize > 1. 
+      // This is to handle the transpose effect when chunkSize > 1.
       std::swap((*targetShape)[0], (*targetShape)[1]);
 
     } else {
       convertedMaskTypes = getUnrolledTypes(maskTy, *targetShape);
-      convertedMasks = pack(op.getMask(), convertedMaskTypes, *targetShape, loc, rewriter);
+      convertedMasks =
+          pack(op.getMask(), convertedMaskTypes, *targetShape, loc, rewriter);
     }
 
     SmallVector<Type> convertedValTypes =
         getUnrolledTypes(valueTy, *targetShape);
     SmallVector<Value> convertedValues =
-        pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);   
+        pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
 
     for (size_t i = 0; i < convertedValues.size(); ++i) {
       Value v = convertedValues[i];
@@ -630,7 +637,7 @@ struct UnrollUpdateOffsetOp : public UnrollPattern<xegpu::UpdateOffsetOp> {
     Location loc = op.getLoc();
     xegpu::TensorDescType tdescTy = op.getTensorDescType();
 
-    if (tdescTy.getRank() >2)
+    if (tdescTy.getRank() > 2)
       return failure();
 
     if (!tdescTy.isScattered())
@@ -652,12 +659,14 @@ struct UnrollUpdateOffsetOp : public UnrollPattern<xegpu::UpdateOffsetOp> {
     SmallVector<Value> newOps;
     int64_t originalChunkSize = tdescTy.getChunkSize();
     if (originalChunkSize > 1) {
-      SmallVector<int64_t> shape1D(targetShape->begin(), targetShape->end() - 1);
+      SmallVector<int64_t> shape1D(targetShape->begin(),
+                                   targetShape->end() - 1);
       convertedOffsetTypes = getUnrolledTypes(offsetVecTy, shape1D);
-      SmallVector<Value> convertedOffsetVec1D = pack(offsetVec, convertedOffsetTypes, shape1D, loc, rewriter);
+      SmallVector<Value> convertedOffsetVec1D =
+          pack(offsetVec, convertedOffsetTypes, shape1D, loc, rewriter);
 
       int64_t blockedChunkSize = targetShape->back();
-      int64_t numNewChunks = originalChunkSize/blockedChunkSize;
+      int64_t numNewChunks = originalChunkSize / blockedChunkSize;
 
       for (auto offset : convertedOffsetVec1D) {
         for (int64_t i = 0; i < numNewChunks; ++i) {
@@ -667,8 +676,9 @@ struct UnrollUpdateOffsetOp : public UnrollPattern<xegpu::UpdateOffsetOp> {
 
     } else {
       convertedOffsetTypes = getUnrolledTypes(offsetVecTy, *targetShape);
-      convertedOffsetVec = pack(offsetVec, convertedOffsetTypes, *targetShape, loc, rewriter);
-    } 
+      convertedOffsetVec =
+          pack(offsetVec, convertedOffsetTypes, *targetShape, loc, rewriter);
+    }
 
     for (auto [t, o] : llvm::zip(convertedTdesc, convertedOffsetVec)) {
       auto newOp =
diff --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
index 8be37609fc806..4fe81a01d3544 100644
--- a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
+++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
@@ -19,12 +19,10 @@ using namespace mlir::xegpu;
 
 namespace {
 
-
 #define DEBUG_TYPE "test-xegpu-unroll"
 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
 
-
 struct TestXeGPUUnrollingPatterns
     : public PassWrapper<TestXeGPUUnrollingPatterns,
                          OperationPass<gpu::GPUModuleOp>> {
@@ -60,7 +58,8 @@ struct TestXeGPUUnrollingPatterns
             xegpu::TensorDescType tdescTy;
             if (auto createNdOp = dyn_cast<xegpu::CreateNdDescOp>(op)) {
               tdescTy = createNdOp.getType();
-            } else if (auto updateNdOp = dyn_cast<xegpu::UpdateNdOffsetOp>(op)) {
+            } else if (auto updateNdOp =
+                           dyn_cast<xegpu::UpdateNdOffsetOp>(op)) {
               tdescTy = updateNdOp.getTensorDescType();
             } else if (auto prefetchNdOp = dyn_cast<xegpu::PrefetchNdOp>(op)) {
               tdescTy = prefetchNdOp.getTensorDescType();
@@ -105,28 +104,27 @@ struct TestXeGPUUnrollingPatterns
             Attribute encoding = tdescTy.getEncoding();
             auto layout = llvm::dyn_cast_if_present<xegpu::LayoutAttr>(
                 tdescTy.getLayout());
-           
+
             if (encoding && mlir::isa<xegpu::ScatterTensorDescAttr>(encoding)) {
-              auto scatterAttr = mlir::dyn_cast<xegpu::ScatterTensorDescAttr>(encoding);
+              auto scatterAttr =
+                  mlir::dyn_cast<xegpu::ScatterTensorDescAttr>(encoding);
               int64_t chunkSize = scatterAttr.getChunkSize().getInt();
-              
+
               if (chunkSize > 1) {
                 int64_t newChunkSize = chunkSize;
                 auto instData = layout.getInstData();
                 if (!instData.empty())
-                  newChunkSize = instData.asArrayRef().back();   
+                  newChunkSize = instData.asArrayRef().back();
 
                 auto chunkSizeAttr = mlir::IntegerAttr::get(
-                  mlir::IntegerType::get(ctx, 64), newChunkSize);
+                    mlir::IntegerType::get(ctx, 64), newChunkSize);
 
                 // To create a new attribute with a different chunk_size:
                 auto newEncoding = xegpu::ScatterTensorDescAttr::get(
                     ctx, scatterAttr.getMemorySpace(), chunkSizeAttr);
 
                 encoding = newEncoding;
-
               }
-
             }
             if (layout) {
               if (layout.getLaneLayout() == nullptr)
@@ -135,7 +133,8 @@ struct TestXeGPUUnrollingPatterns
                 layout = layout.dropInstData();
             }
 
-            newTy = xegpu::TensorDescType::get(ctx, tileShape, elemTy, encoding, layout);
+            newTy = xegpu::TensorDescType::get(ctx, tileShape, elemTy, encoding,
+                                               layout);
 
           } else {
             newTy = type.clone(tileShape, elemTy);

>From 41b839de279e94689fe628e45238ae3e6d56f5a3 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Mon, 16 Jun 2025 22:34:05 +0000
Subject: [PATCH 15/16] remove 1 complex test

---
 .../Dialect/XeGPU/xegpu-unroll-patterns.mlir  | 45 -------------------
 1 file changed, 45 deletions(-)

diff --git a/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir b/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
index 0d3bd4fdb311e..1726b654c4746 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
@@ -255,51 +255,6 @@ gpu.module @test {
     gpu.return
   }
 
-//-----
-
-  // CHECK-LABEL: test_prefetch_load_store_update
-  // CHECK-SAME: [[arg0:%.+]]: ui64
-  // CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
-  // CHECK-COUNT-2: xegpu.prefetch {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
-   // CHECK-COUNT-2: xegpu.update_offset {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xindex>
-   // CHECK-COUNT-2: xegpu.load  {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> -> vector<16xf32>
-  // CHECK-COUNT-2: xegpu.store  {{.*}} : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1>
-
-  gpu.func @test_prefetch_load_store_update(%src: ui64)  {
-
-    %cst = arith.constant dense<[
-    0,   8,  16,  24,  32,  40,  48,  56,
-    64,  72,  80,  88,  96, 104, 112, 120,
-    128, 136, 144, 152, 160, 168, 176, 184,
-    192, 200, 208, 216, 224, 232, 240, 248 
-    ]> : vector<32xindex>
-
-    %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32xf32,  #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
-    xegpu.prefetch %tdesc: !xegpu.tensor_desc<32xf32,  #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
-   
-    %delta = arith.constant dense<[
-    32,   32,  32,  32,  32,  32,  32,  32,
-    32,   32,  32,  32,  32,  32,  32,  64,
-    128, 128, 128, 128, 128, 128, 128, 128,
-    128, 128, 128, 128, 128, 128, 128, 256 
-    ]> : vector<32xindex>
-    %new_tdesc = xegpu.update_offset %tdesc, %delta
-              : !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>, vector<32xindex>     
- 
-    %c17 = arith.constant 17: index
-    %mask = vector.create_mask %c17: vector<32xi1>
-
-    %ld_vec = xegpu.load %new_tdesc, %mask: !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>, vector<32xi1> -> vector<32xf32>
-
-    %st_vec = arith.addf %ld_vec, %ld_vec : vector<32xf32>
-    xegpu.store %st_vec, %tdesc, %mask: 
-                 vector<32xf32>, 
-                 !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>, 
-                 vector<32xi1>
-  
-    gpu.return
-  }
-
 //-----
   // CHECK-LABEL: test_create_tdesc_step_chunk
   // CHECK-SAME: [[arg0:%.+]]: ui64

>From 3c3023c36e8e73b6eb1b3feac575f302fed6a09c Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Tue, 17 Jun 2025 21:50:59 +0000
Subject: [PATCH 16/16] address review feedback

---
 .../Dialect/XeGPU/Transforms/XeGPUUnroll.cpp  |  14 +--
 .../Dialect/XeGPU/xegpu-unroll-patterns.mlir  | 109 ++++++++++--------
 .../lib/Dialect/XeGPU/TestXeGPUTransforms.cpp |  31 +----
 3 files changed, 73 insertions(+), 81 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index faeec1b27501c..91e0f8196f898 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -426,7 +426,7 @@ struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
 
     SmallVector<Value> newOps;
 
-    // more indices is need when chunkSize > 1. Since a big load from one
+    // More indices are needed when chunkSize > 1, since a big load from one
     // address could be break into multiple small loads.
     if (originalChunkSize > 1) {
       int64_t blockedChunkSize = targetShape->back();
@@ -504,15 +504,12 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
       int64_t numNewChunks = originalChunkSize / blockedChunkSize;
 
       for (auto mask : convertedMasks1D) {
-        for (int64_t i = 0; i < numNewChunks; ++i) {
+        for (int64_t i = 0; i < numNewChunks; ++i) 
           convertedMasks.push_back(mask);
-        }
       }
       // This is to handle the transpose effect when chunkSize > 1.
-      if (targetShape && targetShape->size() > 1) {
-        std::swap((*targetShape)[0], (*targetShape)[1]);
-        newValueTy = valueTy.cloneWith(*targetShape, elemTy);
-      }
+      std::swap((*targetShape)[0], (*targetShape)[1]);
+      newValueTy = valueTy.cloneWith(*targetShape, elemTy);
     } else {
       convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
       convertedMasks = pack(op.getMask(), convertedMaskTypes, targetMaskShape,
@@ -540,8 +537,7 @@ struct UnrollPrefetchOp : public UnrollPattern<xegpu::PrefetchOp> {
     Location loc = op.getLoc();
     xegpu::TensorDescType tdescTy = op.getTensorDescType();
 
-    // check if the tensor descriptor type is a 1d vector type
-    if (tdescTy.getRank() > 2)
+    if (!tdescTy.isScattered())
       return failure();
 
     std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
diff --git a/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir b/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
index 1726b654c4746..41414d802f212 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
@@ -2,7 +2,7 @@
 
 gpu.module @test {
 
-  // CHECK-LABEL: test_create_nd_tdesc
+  // CHECK-LABEL: create_nd_tdesc
   // CHECK-SAME: [[arg0:%.+]]: memref<24x32xf32>
   // CHECK-COUNT-6: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
   // CHECK: [[cast:%.+]] = builtin.unrealized_conversion_cast
@@ -10,31 +10,31 @@ gpu.module @test {
   // CHECK-SAME: !xegpu.tensor_desc<8x16xf32>, !xegpu.tensor_desc<8x16xf32>,
   // CHECK-SAME: !xegpu.tensor_desc<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
   // CHECK-SAME: to !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>> {__xegpu_blocking_tile_shape__ = array<i64: 8, 16>, __xegpu_blocking_unpack__}
-  gpu.func @test_create_nd_tdesc(%src: memref<24x32xf32>) -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>> {
+  gpu.func @create_nd_tdesc(%src: memref<24x32xf32>) -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>> {
     %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
     gpu.return %tdesc : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
   }
 
   //-----
 
-  // CHECK-LABEL: test_create_nd_tdesc_1d
+  // CHECK-LABEL: create_nd_tdesc_1d
   // CHECK-SAME: [[arg0:%.+]]: memref<64xf32>
   // CHECK-COUNT-2: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<64xf32> -> !xegpu.tensor_desc<16xf32>
   // CHECK: [[cast:%.+]] = builtin.unrealized_conversion_cast
   // CHECK-SAME: !xegpu.tensor_desc<16xf32>, !xegpu.tensor_desc<16xf32>
   // CHECK-SAME: to !xegpu.tensor_desc<32xf32, #xegpu.layout<inst_data = [16]>> {__xegpu_blocking_tile_shape__ = array<i64: 16>, __xegpu_blocking_unpack__}
-  gpu.func @test_create_nd_tdesc_1d(%src: memref<64xf32>) -> !xegpu.tensor_desc<32xf32, #xegpu.layout<inst_data = [16]>> {
+  gpu.func @create_nd_tdesc_1d(%src: memref<64xf32>) -> !xegpu.tensor_desc<32xf32, #xegpu.layout<inst_data = [16]>> {
     %tdesc = xegpu.create_nd_tdesc %src[0] : memref<64xf32> -> !xegpu.tensor_desc<32xf32, #xegpu.layout<inst_data = [16]>>
     gpu.return %tdesc : !xegpu.tensor_desc<32xf32, #xegpu.layout<inst_data = [16]>>
   }
 
   //-----
 
-  // CHECK-LABEL: test_update_nd_tdesc
+  // CHECK-LABEL: update_nd_tdesc
   // CHECK-SAME: [[arg0:%.+]]: memref<24x32xf32>
   // CHECK-COUNT-6: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
   // CHECK-COUNT-6: [[update:%.+]] = xegpu.update_nd_offset {{.*}} : !xegpu.tensor_desc<8x16xf32>
-  gpu.func @test_update_nd_tdesc(%src: memref<24x32xf32>) -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>> {
+  gpu.func @update_nd_tdesc(%src: memref<24x32xf32>) -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>> {
     %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
     %update = xegpu.update_nd_offset %tdesc, [0, 16] : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
     gpu.return %update : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
@@ -42,11 +42,11 @@ gpu.module @test {
 
   //-----
 
-  // CHECK-LABEL: test_update_nd_tdesc_1d
+  // CHECK-LABEL: update_nd_tdesc_1d
   // CHECK-SAME: [[arg0:%.+]]: memref<64xf32>
   // CHECK-COUNT-2: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<64xf32> -> !xegpu.tensor_desc<16xf32>
   // CHECK-COUNT-2: [[update:%.+]] = xegpu.update_nd_offset {{.*}} : !xegpu.tensor_desc<16xf32>
-  gpu.func @test_update_nd_tdesc_1d(%src: memref<64xf32>) -> !xegpu.tensor_desc<32xf32, #xegpu.layout<inst_data = [16]>> {
+  gpu.func @update_nd_tdesc_1d(%src: memref<64xf32>) -> !xegpu.tensor_desc<32xf32, #xegpu.layout<inst_data = [16]>> {
     %tdesc = xegpu.create_nd_tdesc %src[0] : memref<64xf32> -> !xegpu.tensor_desc<32xf32, #xegpu.layout<inst_data = [16]>>
     %update = xegpu.update_nd_offset %tdesc, [32] : !xegpu.tensor_desc<32xf32, #xegpu.layout<inst_data = [16]>>
     gpu.return %update : !xegpu.tensor_desc<32xf32, #xegpu.layout<inst_data = [16]>>
@@ -54,11 +54,11 @@ gpu.module @test {
 
   //-----
 
-  // CHECK-LABEL: test_prefetch_nd_tdesc
+  // CHECK-LABEL: prefetch_nd_tdesc
   // CHECK-SAME: [[arg0:%.+]]: memref<24x32xf32>
   // CHECK-COUNT-6: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
   // CHECK-COUNT-6: xegpu.prefetch_nd {{.*}} : !xegpu.tensor_desc<8x16xf32>
-  gpu.func @test_prefetch_nd_tdesc(%src: memref<24x32xf32>) {
+  gpu.func @prefetch_nd_tdesc(%src: memref<24x32xf32>) {
     %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
     xegpu.prefetch_nd %tdesc : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
     gpu.return
@@ -66,23 +66,23 @@ gpu.module @test {
 
   //-----
 
-  // CHECK-LABEL: test_prefetch_nd_tdesc_1d
+  // CHECK-LABEL: prefetch_nd_tdesc_1d
   // CHECK-SAME: [[arg0:%.+]]: memref<64xf32>
   // CHECK-COUNT-4: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<64xf32> -> !xegpu.tensor_desc<16xf32>
   // CHECK-COUNT-4: xegpu.prefetch_nd {{.*}} : !xegpu.tensor_desc<16xf32>
-  gpu.func @test_prefetch_nd_tdesc_1d(%src: memref<64xf32>) {
+  gpu.func @prefetch_nd_tdesc_1d(%src: memref<64xf32>) {
     %tdesc = xegpu.create_nd_tdesc %src[0] : memref<64xf32> -> !xegpu.tensor_desc<64xf32, #xegpu.layout<inst_data = [16]>>
     xegpu.prefetch_nd %tdesc : !xegpu.tensor_desc<64xf32, #xegpu.layout<inst_data = [16]>>
     gpu.return
   }
 
   //-----
-  // CHECK-LABEL: test_load_nd
+  // CHECK-LABEL: load_nd
   // CHECK-SAME: [[arg0:%.+]]: memref<24x32xf32>
   // CHECK-COUNT-6: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
   // CHECK-COUNT-6: [[ld:%.+]] = xegpu.load_nd {{.*}}  : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
   // CHECK-COUNT-6: [[insert:%.+]] = vector.insert_strided_slice {{.*}} : vector<8x16xf32> into vector<24x32xf32>
-  gpu.func @test_load_nd(%src: memref<24x32xf32>) -> vector<24x32xf32> {
+  gpu.func @load_nd(%src: memref<24x32xf32>) -> vector<24x32xf32> {
     %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
     %ld = xegpu.load_nd %tdesc: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>> -> vector<24x32xf32>
     gpu.return %ld : vector<24x32xf32>
@@ -90,12 +90,12 @@ gpu.module @test {
 
   //-----
 
-  // CHECK-LABEL: test_load_nd_1d
+  // CHECK-LABEL: load_nd_1d
   // CHECK-SAME: [[arg0:%.+]]: memref<64xf32>
   // CHECK-COUNT-4: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<64xf32> -> !xegpu.tensor_desc<16xf32>
   // CHECK-COUNT-4: [[ld:%.+]] = xegpu.load_nd {{.*}}  : !xegpu.tensor_desc<16xf32> -> vector<16xf32>
   // CHECK-COUNT-4: [[insert:%.+]] = vector.insert_strided_slice {{.*}} : vector<16xf32> into vector<64xf32>
-  gpu.func @test_load_nd_1d(%src: memref<64xf32>) -> vector<64xf32> {
+  gpu.func @load_nd_1d(%src: memref<64xf32>) -> vector<64xf32> {
     %tdesc = xegpu.create_nd_tdesc %src[0] : memref<64xf32> -> !xegpu.tensor_desc<64xf32, #xegpu.layout<inst_data = [16]>>
     %data = xegpu.load_nd %tdesc: !xegpu.tensor_desc<64xf32, #xegpu.layout<inst_data = [16]>> -> vector<64xf32>
     gpu.return %data : vector<64xf32>
@@ -103,11 +103,11 @@ gpu.module @test {
 
   //-----
 
-  // CHECK-LABEL: test_store_nd
+  // CHECK-LABEL: store_nd
   // CHECK-SAME: [[arg0:%.+]]: memref<24x32xf32>
   // CHECK-COUNT-6: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
   // CHECK-COUNT-6: xegpu.store_nd {{.*}}  : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
-  gpu.func @test_store_nd(%src: memref<24x32xf32>) {
+  gpu.func @store_nd(%src: memref<24x32xf32>) {
     %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
     %data = arith.constant dense<9.0> : vector<24x32xf32>
     xegpu.store_nd %data, %tdesc: vector<24x32xf32>, !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
@@ -116,11 +116,11 @@ gpu.module @test {
 
   //-----
 
-  // CHECK-LABEL: test_store_nd_1d
+  // CHECK-LABEL: store_nd_1d
   // CHECK-SAME: [[arg0:%.+]]: memref<64xf32>
   // CHECK-COUNT-4: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<64xf32> -> !xegpu.tensor_desc<16xf32>
   // CHECK-COUNT-4: xegpu.store_nd {{.*}}  : vector<16xf32>, !xegpu.tensor_desc<16xf32>
-  gpu.func @test_store_nd_1d(%src: memref<64xf32>) {
+  gpu.func @store_nd_1d(%src: memref<64xf32>) {
     %tdesc = xegpu.create_nd_tdesc %src[0] : memref<64xf32> -> !xegpu.tensor_desc<64xf32, #xegpu.layout<inst_data = [16]>>
     %data = arith.constant dense<9.0> : vector<64xf32>
     xegpu.store_nd %data, %tdesc: vector<64xf32>, !xegpu.tensor_desc<64xf32, #xegpu.layout<inst_data = [16]>>
@@ -129,7 +129,7 @@ gpu.module @test {
 
   //-----
 
-  // CHECK-LABEL: test_createNd_loadNd_storeNd
+  // CHECK-LABEL: createNd_loadNd_storeNd
   // CHECK-SAME: [[arg0:%.+]]: memref<24x32xf32>
   //CHECK-COUNT-6: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
   //CHECK-COUNT-6: [[data:%.+]] = xegpu.load_nd {{.*}}  : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
@@ -137,7 +137,7 @@ gpu.module @test {
   //CHECK: [[add:%.+]] = arith.addf {{.*}} : vector<24x32xf32>
   //CHECK-COUNT-6: [[extract:%.+]] = vector.extract_strided_slice {{.*}} : vector<24x32xf32> to vector<8x16xf32>
   //CHECK-COUNT-6: xegpu.store_nd {{.*}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
-  gpu.func @test_createNd_loadNd_storeNd(%src: memref<24x32xf32>) {
+  gpu.func @createNd_loadNd_storeNd(%src: memref<24x32xf32>) {
     %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
     %data = arith.constant dense<9.0> : vector<24x32xf32>
     %ld = xegpu.load_nd %tdesc: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>> -> vector<24x32xf32>
@@ -148,23 +148,23 @@ gpu.module @test {
 
   //-----
 
-  // CHECK-LABEL: test_dpas
+  // CHECK-LABEL: dpas
   // CHECK-SAME: [[arg0:%.+]]: vector<32x32xf16>, [[arg1:%.+]]: vector<32x32xf16>
   //CHECK-COUNT-8: [[extract1:%.+]] = vector.extract_strided_slice [[arg0]] {{.*}} : vector<32x32xf16> to vector<8x16xf16>
   //CHECK-COUNT-4: [[extract2:%.+]] = vector.extract_strided_slice [[arg1]] {{.*}} : vector<32x32xf16> to vector<16x16xf16>
   //CHECK-COUNT-16: [[dpas:%.+]] = xegpu.dpas {{.*}} -> vector<8x16xf32>
   //CHECK-COUNT-8: [[insert:%.+]] = vector.insert_strided_slice {{.*}} : vector<8x16xf32> into vector<32x32xf32>
-  gpu.func @test_dpas(%a: vector<32x32xf16>, %b: vector<32x32xf16>) -> vector<32x32xf32> {
+  gpu.func @dpas(%a: vector<32x32xf16>, %b: vector<32x32xf16>) -> vector<32x32xf32> {
     %c = xegpu.dpas %a, %b : vector<32x32xf16>, vector<32x32xf16> -> vector<32x32xf32>
     gpu.return %c : vector<32x32xf32>
   }
 
 //-----
 
-  // CHECK-LABEL: test_create_tdesc_vec
+  // CHECK-LABEL: create_tdesc_vec
   // CHECK-SAME: [[arg0:%.+]]: ui64
   // CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
-  gpu.func @test_create_tdesc_vec(%src: ui64) -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>> {
+  gpu.func @create_tdesc_vec(%src: ui64) -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>> {
     %cst = arith.constant dense<[
     0,   8,  16,  24,  32,  40,  48,  56,
     64,  72,  80,  88,  96, 104, 112, 120,
@@ -177,10 +177,10 @@ gpu.module @test {
 
 //-----
 
-  // CHECK-LABEL: test_create_tdesc_step
+  // CHECK-LABEL: create_tdesc_step
   // CHECK-SAME: [[arg0:%.+]]: ui64
   // CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
-  gpu.func @test_create_tdesc_step(%src: ui64) -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>> {
+  gpu.func @create_tdesc_step(%src: ui64) -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>> {
     %step = arith.constant dense<8> : vector<32xindex>
     %seq = vector.step  : vector<32xindex>
     %cst = arith.muli %seq, %step : vector<32xindex>
@@ -190,11 +190,11 @@ gpu.module @test {
 
 //-----
 
-  // CHECK-LABEL: test_load
+  // CHECK-LABEL: load
   // CHECK-SAME: [[arg0:%.+]]: ui64
   // CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
   // CHECK-COUNT-2: xegpu.load  {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> -> vector<16xf32>
-  gpu.func @test_load(%src: ui64) -> vector<32xf32> {
+  gpu.func @load(%src: ui64) -> vector<32xf32> {
     %cst = arith.constant dense<[
     0,   8,  16,  24,  32,  40,  48,  56,
     64,  72,  80,  88,  96, 104, 112, 120,
@@ -212,11 +212,11 @@ gpu.module @test {
 
 //-----
 
-  // CHECK-LABEL: test_prefetch
+  // CHECK-LABEL: prefetch
   // CHECK-SAME: [[arg0:%.+]]: ui64
   // CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
   // CHECK-COUNT-2: xegpu.prefetch {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
-  gpu.func @test_prefetch(%src: ui64)  {
+  gpu.func @prefetch(%src: ui64)  {
 
     %cst = arith.constant dense<[
     0,   8,  16,  24,  32,  40,  48,  56,
@@ -233,11 +233,11 @@ gpu.module @test {
 
 //-----
 
-  // CHECK-LABEL: test_store
+  // CHECK-LABEL: store
   // CHECK-SAME: [[arg0:%.+]]: ui64
   // CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
   // CHECK-COUNT-2: xegpu.store  {{.*}} : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1>
-  gpu.func @test_store(%src: ui64) {
+  gpu.func @store(%src: ui64) {
     %cst = arith.constant dense<[
     0,   8,  16,  24,  32,  40,  48,  56,
     64,  72,  80,  88,  96, 104, 112, 120,
@@ -256,10 +256,10 @@ gpu.module @test {
   }
 
 //-----
-  // CHECK-LABEL: test_create_tdesc_step_chunk
+  // CHECK-LABEL: create_tdesc_step_chunk
   // CHECK-SAME: [[arg0:%.+]]: ui64
   // CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x4xf32, #xegpu.scatter_tdesc_attr<chunk_size = 4 : i64>>
-  gpu.func @test_create_tdesc_step_chunk(%src: ui64) -> !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 4]>> {
+  gpu.func @create_tdesc_step_chunk(%src: ui64) -> !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 4]>> {
     %step = arith.constant dense<8> : vector<32xindex>
     %seq = vector.step  : vector<32xindex>
     %cst = arith.muli %seq, %step : vector<32xindex>
@@ -268,10 +268,10 @@ gpu.module @test {
   }
 
 //-----
-  // CHECK-LABEL: test_create_tdesc_step_chunk2
+  // CHECK-LABEL: create_tdesc_step_chunk2
   // CHECK-SAME: [[arg0:%.+]]: ui64
   // CHECK-COUNT-4: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
-  gpu.func @test_create_tdesc_step_chunk2(%src: ui64) -> !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>> {
+  gpu.func @create_tdesc_step_chunk2(%src: ui64) -> !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>> {
     %step = arith.constant dense<8> : vector<32xindex>
     %seq = vector.step  : vector<32xindex>
     %cst = arith.muli %seq, %step : vector<32xindex>
@@ -279,13 +279,30 @@ gpu.module @test {
     gpu.return %tdesc : !xegpu.tensor_desc<32x4xf32,  #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>
   }
 
+  // CHECK-LABEL: create_tdesc_step_chunk3
+  // CHECK-SAME: [[arg0:%.+]]: ui64
+  // CHECK: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
+  // CHECK: arith.addi %{{.*}}, %{{.*}} : vector<16xindex>
+  // CHECK: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
+  // CHECK: arith.addi %{{.*}}, %{{.*}} : vector<16xindex>
+  // CHECK: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
+  // CHECK: arith.addi %{{.*}}, %{{.*}} : vector<16xindex>
+  // CHECK: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
+  gpu.func @create_tdesc_step_chunk3(%src: ui64) -> !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size=8>, #xegpu.layout<inst_data = [16, 2]>> {
+    %step = arith.constant dense<8> : vector<16xindex>
+    %seq = vector.step  : vector<16xindex>
+    %cst = arith.muli %seq, %step : vector<16xindex>
+    %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x8xf32,  #xegpu.scatter_tdesc_attr<chunk_size=8>, #xegpu.layout<inst_data = [16, 2]>>
+    gpu.return %tdesc : !xegpu.tensor_desc<16x8xf32,  #xegpu.scatter_tdesc_attr<chunk_size=8>, #xegpu.layout<inst_data = [16, 2]>>
+  }
+
 //-----
-  // CHECK-LABEL: test_load_chunk
+  // CHECK-LABEL: load_chunk
   // CHECK-SAME: [[arg0:%.+]]: ui64
   // CHECK-COUNT-4: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
   // CHECK-COUNT-4: xegpu.load  {{.*}} <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>, vector<16xi1> -> vector<2x16xf32>
 
-  gpu.func @test_load_chunk(%src: ui64) -> vector<4x32xf32> {
+  gpu.func @load_chunk(%src: ui64) -> vector<4x32xf32> {
     %cst = arith.constant dense<[
         0,   8,  16,  24,  32,  40,  48,  56,
         64,  72,  80,  88,  96, 104, 112, 120,
@@ -303,11 +320,11 @@ gpu.module @test {
    }
 
 //-----
-  // CHECK-LABEL: test_store_chunk
+  // CHECK-LABEL: store_chunk
   // CHECK-SAME: [[arg0:%.+]]: ui64
   // CHECK-COUNT-4: xegpu.create_tdesc [[arg0]], {{.*}} :  ui64, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
   // CHECK-COUNT-4: xegpu.store  {{.*}} <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : vector<2x16xf32>, !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>, vector<16xi1>
-  gpu.func @test_store_chunk(%src: ui64) {
+  gpu.func @store_chunk(%src: ui64) {
     %cst = arith.constant dense<[
       0,   8,  16,  24,  32,  40,  48,  56,
       64,  72,  80,  88,  96, 104, 112, 120,
@@ -326,11 +343,11 @@ gpu.module @test {
   }
 
 //-----
-  // CHECK-LABEL: test_prefetch_chunk
+  // CHECK-LABEL: prefetch_chunk
   // CHECK-SAME: [[arg0:%.+]]: ui64
   // CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
   // CHECK-COUNT-2: xegpu.prefetch {{.*}} : !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
-  gpu.func @test_prefetch_chunk(%src: ui64)  {
+  gpu.func @prefetch_chunk(%src: ui64)  {
     %cst = arith.constant dense<[
       0,   8,  16,  24,  32,  40,  48,  56,
       64,  72,  80,  88,  96, 104, 112, 120,
@@ -344,11 +361,11 @@ gpu.module @test {
   }
 
 //-----
-  // CHECK-LABEL: test_update_chunk
+  // CHECK-LABEL: update_chunk
   // CHECK-SAME: [[arg0:%.+]]: ui64
   // CHECK-COUNT-4: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
   // CHECK-COUNT-4: xegpu.update_offset {{.*}} : !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>, vector<16xindex>
-  gpu.func @test_update_chunk(%src: ui64) -> !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>> {
+  gpu.func @update_chunk(%src: ui64) -> !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>> {
     %cst = arith.constant dense<[
       0,   8,  16,  24,  32,  40,  48,  56,
       64,  72,  80,  88,  96, 104, 112, 120,
diff --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
index 8110ca237dcd3..4400d6d9625f7 100644
--- a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
+++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
@@ -87,29 +87,6 @@ struct TestXeGPUUnrollingPatterns
             }
           }
 
-          if (isa<xegpu::CreateDescOp, xegpu::UpdateOffsetOp, xegpu::PrefetchOp,
-                  xegpu::LoadGatherOp, xegpu::StoreScatterOp>(op)) {
-            xegpu::TensorDescType tdescTy;
-            if (auto createOp = dyn_cast<xegpu::CreateDescOp>(op)) {
-              tdescTy = createOp.getType();
-            } else if (auto updateOp = dyn_cast<xegpu::UpdateOffsetOp>(op)) {
-              tdescTy = updateOp.getTensorDescType();
-            } else if (auto prefetchOp = dyn_cast<xegpu::PrefetchOp>(op)) {
-              tdescTy = prefetchOp.getTensorDescType();
-            } else if (auto loadOp = dyn_cast<xegpu::LoadGatherOp>(op)) {
-              tdescTy = loadOp.getTensorDescType();
-            } else if (auto storeOp = dyn_cast<xegpu::StoreScatterOp>(op)) {
-              tdescTy = storeOp.getTensorDescType();
-            }
-
-            if (auto layout = tdescTy.getLayoutAttr()) {
-              auto inst_data = layout.getInstData();
-              if (inst_data && layout.isSgLayout())
-                return SmallVector<int64_t>(inst_data.asArrayRef().begin(),
-                                            inst_data.asArrayRef().end());
-            }
-          }
-
           if (isa<xegpu::DpasOp>(op))
             return SmallVector<int64_t>{8, 16, 16};
 
@@ -128,19 +105,21 @@ struct TestXeGPUUnrollingPatterns
             auto layout = llvm::dyn_cast_if_present<xegpu::LayoutAttr>(
                 tdescTy.getLayout());
 
+            // If the encoding is a ScatterTensorDescAttr, we need to
+            // potentially adjust the chunk size based on the inst_data.
             if (encoding && mlir::isa<xegpu::ScatterTensorDescAttr>(encoding)) {
               auto scatterAttr =
                   mlir::dyn_cast<xegpu::ScatterTensorDescAttr>(encoding);
               int64_t chunkSize = scatterAttr.getChunkSize().getInt();
 
               if (chunkSize > 1) {
-                int64_t newChunkSize = chunkSize;
+                int64_t blockedChunkSize = chunkSize;
                 auto instData = layout.getInstData();
                 if (!instData.empty())
-                  newChunkSize = instData.asArrayRef().back();
+                  blockedChunkSize = instData.asArrayRef().back();
 
                 auto chunkSizeAttr = mlir::IntegerAttr::get(
-                    mlir::IntegerType::get(ctx, 64), newChunkSize);
+                    mlir::IntegerType::get(ctx, 64), blockedChunkSize);
 
                 // To create a new attribute with a different chunk_size:
                 auto newEncoding = xegpu::ScatterTensorDescAttr::get(



More information about the Mlir-commits mailing list