[Mlir-commits] [mlir] f25f2f7 - [MLIR][XeGPU] Extend unrolling support for scatter ops with chunk_size (#144447)

llvmlistbot at llvm.org
Tue Jun 17 15:46:38 PDT 2025


Author: Jianhui Li
Date: 2025-06-17T17:46:35-05:00
New Revision: f25f2f7de4f8264d89ba3c4dc9daddb10a90c13f

URL: https://github.com/llvm/llvm-project/commit/f25f2f7de4f8264d89ba3c4dc9daddb10a90c13f
DIFF: https://github.com/llvm/llvm-project/commit/f25f2f7de4f8264d89ba3c4dc9daddb10a90c13f.diff

LOG: [MLIR][XeGPU] Extend unrolling support for scatter ops with chunk_size (#144447)

Add unrolling support for load/store with chunk_size. This requires special
handling in the operand blocking, since the offsets and masks are n-D while
the tensor is (n+1)-D. Supported operations include create_tdesc,
update_tdesc, load, store, and prefetch.
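
To make the blocking concrete, here is a rough sketch of how a chunked
descriptor and a chunked load are unrolled under inst_data = [16, 2]
(shapes match the updated tests below; the SSA names are illustrative and
the cache-hint attributes are omitted):

    // Before unrolling: 32 addresses, 4 contiguous elements per address.
    %tdesc = xegpu.create_tdesc %src, %offsets : ui64, vector<32xindex>
        -> !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size = 4>,
                              #xegpu.layout<inst_data = [16, 2]>>

    // After unrolling: 2 offset splits x 2 chunk splits = 4 descriptors. The
    // 1-D offsets are split in half, and each half is reused with an extra
    // increment of i * blockedChunkSize (here 0 and 2) for the chunk splits.
    %t0 = xegpu.create_tdesc %src, %off0 : ui64, vector<16xindex>
        -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>

    // Masks stay 1-D (one bit per address), so each unrolled 16-element mask
    // is reused for every chunk split; the loaded tile comes back transposed.
    %v0 = xegpu.load %t0, %m0 <{transpose}>
        : !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>,
          vector<16xi1> -> vector<2x16xf32>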

---------

Co-authored-by: Adam Siemieniuk <adam.siemieniuk at intel.com>

Added: 
    

Modified: 
    mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
    mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
    mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index 9c234c1e866b9..0457f8128b908 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -402,30 +402,58 @@ struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
     xegpu::TensorDescType tdescTy = op.getType();
+    TypedValue<::mlir::VectorType> indiceVec = op.getOffsets();
+    VectorType indiceVecTy = indiceVec.getType();
 
-    // check if the tensor descriptor type is a 1d vector type
-    if (tdescTy.getRank() > 1)
+    if (!tdescTy.isScattered())
       return failure();
 
     std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
     if (!targetShape)
       return failure();
 
-    auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];
-
-    TypedValue<::mlir::VectorType> indiceVec = op.getOffsets();
-    VectorType indiceVecTy = indiceVec.getType();
+    SmallVector<int64_t> targetIndiceShape(*targetShape);
+    int64_t originalChunkSize = tdescTy.getChunkSize();
+    // indiceVec has one fewer dimension than tdescTy when chunkSize > 1.
+    if (originalChunkSize > 1)
+      targetIndiceShape.pop_back();
 
+    auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];
     SmallVector<Type> convertedIndiceTypes =
-        getUnrolledTypes(indiceVecTy, *targetShape);
+        getUnrolledTypes(indiceVecTy, targetIndiceShape);
     SmallVector<Value> convertedIndiceVec =
-        pack(indiceVec, convertedIndiceTypes, *targetShape, loc, rewriter);
+        pack(indiceVec, convertedIndiceTypes, targetIndiceShape, loc, rewriter);
 
     SmallVector<Value> newOps;
-    for (auto indice : convertedIndiceVec) {
-      auto newOp = rewriter.create<xegpu::CreateDescOp>(loc, newTdescTy,
-                                                        op.getSource(), indice);
-      newOps.push_back(newOp);
+
+    // More indices are needed when chunkSize > 1, since one big load from a
+    // single address may be broken into multiple smaller loads.
+    if (originalChunkSize > 1) {
+      int64_t blockedChunkSize = targetShape->back();
+      int64_t numNewChunks = originalChunkSize / blockedChunkSize;
+
+      for (auto [indice, indiceType] :
+           llvm::zip(convertedIndiceVec, convertedIndiceTypes)) {
+        for (int64_t i = 0; i < numNewChunks; ++i) {
+          // Shift the base indices by i * blockedChunkSize for this split.
+          Value inc = rewriter.create<arith::ConstantIndexOp>(
+              loc, i * blockedChunkSize);
+          Value incVec = rewriter.create<vector::SplatOp>(loc, indiceType, inc);
+          Value offsetIndice =
+              rewriter.create<arith::AddIOp>(loc, indice, incVec);
+
+          auto newOp = rewriter.create<xegpu::CreateDescOp>(
+              loc, newTdescTy, op.getSource(), offsetIndice);
+
+          newOps.push_back(newOp);
+        }
+      }
+    } else {
+      for (auto indice : convertedIndiceVec) {
+        auto newOp = rewriter.create<xegpu::CreateDescOp>(
+            loc, newTdescTy, op.getSource(), indice);
+        newOps.push_back(newOp);
+      }
     }
 
     Value castOp = unpack(newOps, tdescTy, *targetShape, loc, rewriter);
@@ -444,16 +472,18 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
     VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
     xegpu::TensorDescType tdescTy = op.getTensorDescType();
 
-    // check if the tensor descriptor type is a 1d vector type
-    if (tdescTy.getRank() > 1)
+    if (!tdescTy.isScattered())
       return failure();
 
-    VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());
-
     std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
     if (!targetShape)
       return failure();
 
+    SmallVector<int64_t> targetMaskShape(*targetShape);
+    int64_t originalChunkSize = tdescTy.getChunkSize();
+
+    VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());
+
     Type elemTy = tdescTy.getElementType();
     VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);
 
@@ -462,10 +492,29 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
     SmallVector<Value> convertedTdescs = pack(
         op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
 
-    SmallVector<Type> convertedMaskTypes =
-        getUnrolledTypes(maskTy, *targetShape);
-    SmallVector<Value> convertedMasks =
-        pack(op.getMask(), convertedMaskTypes, *targetShape, loc, rewriter);
+    SmallVector<Type> convertedMaskTypes;
+    SmallVector<Value> convertedMasks;
+
+    if (originalChunkSize > 1) {
+      targetMaskShape.pop_back();
+      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
+      SmallVector<Value> convertedMasks1D = pack(
+          op.getMask(), convertedMaskTypes, targetMaskShape, loc, rewriter);
+      int64_t blockedChunkSize = targetShape->back();
+      int64_t numNewChunks = originalChunkSize / blockedChunkSize;
+
+      for (auto mask : convertedMasks1D) {
+        for (int64_t i = 0; i < numNewChunks; ++i)
+          convertedMasks.push_back(mask);
+      }
+      // This is to handle the transpose effect when chunkSize > 1.
+      std::swap((*targetShape)[0], (*targetShape)[1]);
+      newValueTy = valueTy.cloneWith(*targetShape, elemTy);
+    } else {
+      convertedMaskTypes = getUnrolledTypes(maskTy, targetMaskShape);
+      convertedMasks = pack(op.getMask(), convertedMaskTypes, targetMaskShape,
+                            loc, rewriter);
+    }
 
     SmallVector<Value> newOps;
     for (auto [t, m] : llvm::zip(convertedTdescs, convertedMasks)) {
@@ -476,7 +525,6 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
     }
 
     Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
-
     rewriter.replaceOp(op, castOp);
     return success();
   }
@@ -489,8 +537,7 @@ struct UnrollPrefetchOp : public UnrollPattern<xegpu::PrefetchOp> {
     Location loc = op.getLoc();
     xegpu::TensorDescType tdescTy = op.getTensorDescType();
 
-    // check if the tensor descriptor type is a 1d vector type
-    if (tdescTy.getRank() > 1)
+    if (!tdescTy.isScattered())
       return failure();
 
     std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
@@ -519,30 +566,51 @@ struct UnrollStoreScatterOp : public UnrollPattern<xegpu::StoreScatterOp> {
     VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
     xegpu::TensorDescType tdescTy = op.getTensorDescType();
 
-    // check if the tensor descriptor type is a 1d vector type
-    if (tdescTy.getRank() > 1)
+    if (!tdescTy.isScattered())
       return failure();
 
-    VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());
-
     std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
     if (!targetShape)
       return failure();
 
-    SmallVector<Type> convertedValTypes =
-        getUnrolledTypes(valueTy, *targetShape);
+    SmallVector<int64_t> targetIndiceShape(*targetShape);
+    int64_t originalChunkSize = tdescTy.getChunkSize();
+
+    VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());
+
     SmallVector<Type> convertedTdescTypes =
         getUnrolledTypes(tdescTy, *targetShape);
-
-    SmallVector<Value> convertedValues =
-        pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
     SmallVector<Value> convertedTdescs = pack(
         op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
 
-    SmallVector<Type> convertedMaskTypes =
-        getUnrolledTypes(maskTy, *targetShape);
-    SmallVector<Value> convertedMasks =
-        pack(op.getMask(), convertedMaskTypes, *targetShape, loc, rewriter);
+    SmallVector<Type> convertedMaskTypes;
+    SmallVector<Value> convertedMasks;
+
+    if (originalChunkSize > 1) {
+      int64_t blockedChunkSize = targetShape->back();
+      int64_t numNewChunks = originalChunkSize / blockedChunkSize;
+      convertedMaskTypes = getUnrolledTypes(maskTy, (*targetShape)[0]);
+      SmallVector<Value> convertedMasks1D = pack(
+          op.getMask(), convertedMaskTypes, (*targetShape)[0], loc, rewriter);
+
+      for (auto mask : convertedMasks1D) {
+        for (int64_t i = 0; i < numNewChunks; ++i) {
+          convertedMasks.push_back(mask);
+        }
+      }
+      // This is to handle the transpose effect when chunkSize > 1.
+      std::swap((*targetShape)[0], (*targetShape)[1]);
+
+    } else {
+      convertedMaskTypes = getUnrolledTypes(maskTy, *targetShape);
+      convertedMasks =
+          pack(op.getMask(), convertedMaskTypes, *targetShape, loc, rewriter);
+    }
+
+    SmallVector<Type> convertedValTypes =
+        getUnrolledTypes(valueTy, *targetShape);
+    SmallVector<Value> convertedValues =
+        pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
 
     for (size_t i = 0; i < convertedValues.size(); ++i) {
       Value v = convertedValues[i];
@@ -565,8 +633,10 @@ struct UnrollUpdateOffsetOp : public UnrollPattern<xegpu::UpdateOffsetOp> {
     Location loc = op.getLoc();
     xegpu::TensorDescType tdescTy = op.getTensorDescType();
 
-    // check if the tensor descriptor type is a 1d vector type
-    if (tdescTy.getRank() > 1)
+    if (tdescTy.getRank() > 2)
+      return failure();
+
+    if (!tdescTy.isScattered())
       return failure();
 
     std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
@@ -580,12 +650,32 @@ struct UnrollUpdateOffsetOp : public UnrollPattern<xegpu::UpdateOffsetOp> {
 
     TypedValue<::mlir::VectorType> offsetVec = op.getOffsets();
     VectorType offsetVecTy = offsetVec.getType();
-    SmallVector<Type> convertedOffsetTypes =
-        getUnrolledTypes(offsetVecTy, *targetShape);
-    SmallVector<Value> convertedOffsetVec =
-        pack(offsetVec, convertedOffsetTypes, *targetShape, loc, rewriter);
-
+    SmallVector<Type> convertedOffsetTypes;
+    SmallVector<Value> convertedOffsetVec;
     SmallVector<Value> newOps;
+    int64_t originalChunkSize = tdescTy.getChunkSize();
+    if (originalChunkSize > 1) {
+      SmallVector<int64_t> shape1D(targetShape->begin(),
+                                   targetShape->end() - 1);
+      convertedOffsetTypes = getUnrolledTypes(offsetVecTy, shape1D);
+      SmallVector<Value> convertedOffsetVec1D =
+          pack(offsetVec, convertedOffsetTypes, shape1D, loc, rewriter);
+
+      int64_t blockedChunkSize = targetShape->back();
+      int64_t numNewChunks = originalChunkSize / blockedChunkSize;
+
+      for (auto offset : convertedOffsetVec1D) {
+        for (int64_t i = 0; i < numNewChunks; ++i) {
+          convertedOffsetVec.push_back(offset);
+        }
+      }
+
+    } else {
+      convertedOffsetTypes = getUnrolledTypes(offsetVecTy, *targetShape);
+      convertedOffsetVec =
+          pack(offsetVec, convertedOffsetTypes, *targetShape, loc, rewriter);
+    }
+
     for (auto [t, o] : llvm::zip(convertedTdesc, convertedOffsetVec)) {
       auto newOp =
           rewriter.create<xegpu::UpdateOffsetOp>(loc, t.getType(), t, o);

diff --git a/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir b/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
index 52ec3b856da49..41414d802f212 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
@@ -2,7 +2,7 @@
 
 gpu.module @test {
 
-  // CHECK-LABEL: test_create_nd_tdesc
+  // CHECK-LABEL: create_nd_tdesc
   // CHECK-SAME: [[arg0:%.+]]: memref<24x32xf32>
   // CHECK-COUNT-6: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
   // CHECK: [[cast:%.+]] = builtin.unrealized_conversion_cast
@@ -10,31 +10,31 @@ gpu.module @test {
   // CHECK-SAME: !xegpu.tensor_desc<8x16xf32>, !xegpu.tensor_desc<8x16xf32>,
   // CHECK-SAME: !xegpu.tensor_desc<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
   // CHECK-SAME: to !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>> {__xegpu_blocking_tile_shape__ = array<i64: 8, 16>, __xegpu_blocking_unpack__}
-  gpu.func @test_create_nd_tdesc(%src: memref<24x32xf32>) -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>> {
+  gpu.func @create_nd_tdesc(%src: memref<24x32xf32>) -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>> {
     %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
     gpu.return %tdesc : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
   }
 
   //-----
 
-  // CHECK-LABEL: test_create_nd_tdesc_1d
+  // CHECK-LABEL: create_nd_tdesc_1d
   // CHECK-SAME: [[arg0:%.+]]: memref<64xf32>
   // CHECK-COUNT-2: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<64xf32> -> !xegpu.tensor_desc<16xf32>
   // CHECK: [[cast:%.+]] = builtin.unrealized_conversion_cast
   // CHECK-SAME: !xegpu.tensor_desc<16xf32>, !xegpu.tensor_desc<16xf32>
   // CHECK-SAME: to !xegpu.tensor_desc<32xf32, #xegpu.layout<inst_data = [16]>> {__xegpu_blocking_tile_shape__ = array<i64: 16>, __xegpu_blocking_unpack__}
-  gpu.func @test_create_nd_tdesc_1d(%src: memref<64xf32>) -> !xegpu.tensor_desc<32xf32, #xegpu.layout<inst_data = [16]>> {
+  gpu.func @create_nd_tdesc_1d(%src: memref<64xf32>) -> !xegpu.tensor_desc<32xf32, #xegpu.layout<inst_data = [16]>> {
     %tdesc = xegpu.create_nd_tdesc %src[0] : memref<64xf32> -> !xegpu.tensor_desc<32xf32, #xegpu.layout<inst_data = [16]>>
     gpu.return %tdesc : !xegpu.tensor_desc<32xf32, #xegpu.layout<inst_data = [16]>>
   }
 
   //-----
 
-  // CHECK-LABEL: test_update_nd_tdesc
+  // CHECK-LABEL: update_nd_tdesc
   // CHECK-SAME: [[arg0:%.+]]: memref<24x32xf32>
   // CHECK-COUNT-6: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
   // CHECK-COUNT-6: [[update:%.+]] = xegpu.update_nd_offset {{.*}} : !xegpu.tensor_desc<8x16xf32>
-  gpu.func @test_update_nd_tdesc(%src: memref<24x32xf32>) -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>> {
+  gpu.func @update_nd_tdesc(%src: memref<24x32xf32>) -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>> {
     %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
     %update = xegpu.update_nd_offset %tdesc, [0, 16] : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
     gpu.return %update : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
@@ -42,11 +42,11 @@ gpu.module @test {
 
   //-----
 
-  // CHECK-LABEL: test_update_nd_tdesc_1d
+  // CHECK-LABEL: update_nd_tdesc_1d
   // CHECK-SAME: [[arg0:%.+]]: memref<64xf32>
   // CHECK-COUNT-2: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<64xf32> -> !xegpu.tensor_desc<16xf32>
   // CHECK-COUNT-2: [[update:%.+]] = xegpu.update_nd_offset {{.*}} : !xegpu.tensor_desc<16xf32>
-  gpu.func @test_update_nd_tdesc_1d(%src: memref<64xf32>) -> !xegpu.tensor_desc<32xf32, #xegpu.layout<inst_data = [16]>> {
+  gpu.func @update_nd_tdesc_1d(%src: memref<64xf32>) -> !xegpu.tensor_desc<32xf32, #xegpu.layout<inst_data = [16]>> {
     %tdesc = xegpu.create_nd_tdesc %src[0] : memref<64xf32> -> !xegpu.tensor_desc<32xf32, #xegpu.layout<inst_data = [16]>>
     %update = xegpu.update_nd_offset %tdesc, [32] : !xegpu.tensor_desc<32xf32, #xegpu.layout<inst_data = [16]>>
     gpu.return %update : !xegpu.tensor_desc<32xf32, #xegpu.layout<inst_data = [16]>>
@@ -54,11 +54,11 @@ gpu.module @test {
 
   //-----
 
-  // CHECK-LABEL: test_prefetch_nd_tdesc
+  // CHECK-LABEL: prefetch_nd_tdesc
   // CHECK-SAME: [[arg0:%.+]]: memref<24x32xf32>
   // CHECK-COUNT-6: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
   // CHECK-COUNT-6: xegpu.prefetch_nd {{.*}} : !xegpu.tensor_desc<8x16xf32>
-  gpu.func @test_prefetch_nd_tdesc(%src: memref<24x32xf32>) {
+  gpu.func @prefetch_nd_tdesc(%src: memref<24x32xf32>) {
     %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
     xegpu.prefetch_nd %tdesc : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
     gpu.return
@@ -66,23 +66,23 @@ gpu.module @test {
 
   //-----
 
-  // CHECK-LABEL: test_prefetch_nd_tdesc_1d
+  // CHECK-LABEL: prefetch_nd_tdesc_1d
   // CHECK-SAME: [[arg0:%.+]]: memref<64xf32>
   // CHECK-COUNT-4: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<64xf32> -> !xegpu.tensor_desc<16xf32>
   // CHECK-COUNT-4: xegpu.prefetch_nd {{.*}} : !xegpu.tensor_desc<16xf32>
-  gpu.func @test_prefetch_nd_tdesc_1d(%src: memref<64xf32>) {
+  gpu.func @prefetch_nd_tdesc_1d(%src: memref<64xf32>) {
     %tdesc = xegpu.create_nd_tdesc %src[0] : memref<64xf32> -> !xegpu.tensor_desc<64xf32, #xegpu.layout<inst_data = [16]>>
     xegpu.prefetch_nd %tdesc : !xegpu.tensor_desc<64xf32, #xegpu.layout<inst_data = [16]>>
     gpu.return
   }
 
   //-----
-  // CHECK-LABEL: test_load_nd
+  // CHECK-LABEL: load_nd
   // CHECK-SAME: [[arg0:%.+]]: memref<24x32xf32>
   // CHECK-COUNT-6: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
   // CHECK-COUNT-6: [[ld:%.+]] = xegpu.load_nd {{.*}}  : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
   // CHECK-COUNT-6: [[insert:%.+]] = vector.insert_strided_slice {{.*}} : vector<8x16xf32> into vector<24x32xf32>
-  gpu.func @test_load_nd(%src: memref<24x32xf32>) -> vector<24x32xf32> {
+  gpu.func @load_nd(%src: memref<24x32xf32>) -> vector<24x32xf32> {
     %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
     %ld = xegpu.load_nd %tdesc: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>> -> vector<24x32xf32>
     gpu.return %ld : vector<24x32xf32>
@@ -90,12 +90,12 @@ gpu.module @test {
 
   //-----
 
-  // CHECK-LABEL: test_load_nd_1d
+  // CHECK-LABEL: load_nd_1d
   // CHECK-SAME: [[arg0:%.+]]: memref<64xf32>
   // CHECK-COUNT-4: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<64xf32> -> !xegpu.tensor_desc<16xf32>
   // CHECK-COUNT-4: [[ld:%.+]] = xegpu.load_nd {{.*}}  : !xegpu.tensor_desc<16xf32> -> vector<16xf32>
   // CHECK-COUNT-4: [[insert:%.+]] = vector.insert_strided_slice {{.*}} : vector<16xf32> into vector<64xf32>
-  gpu.func @test_load_nd_1d(%src: memref<64xf32>) -> vector<64xf32> {
+  gpu.func @load_nd_1d(%src: memref<64xf32>) -> vector<64xf32> {
     %tdesc = xegpu.create_nd_tdesc %src[0] : memref<64xf32> -> !xegpu.tensor_desc<64xf32, #xegpu.layout<inst_data = [16]>>
     %data = xegpu.load_nd %tdesc: !xegpu.tensor_desc<64xf32, #xegpu.layout<inst_data = [16]>> -> vector<64xf32>
     gpu.return %data : vector<64xf32>
@@ -103,11 +103,11 @@ gpu.module @test {
 
   //-----
 
-  // CHECK-LABEL: test_store_nd
+  // CHECK-LABEL: store_nd
   // CHECK-SAME: [[arg0:%.+]]: memref<24x32xf32>
   // CHECK-COUNT-6: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
   // CHECK-COUNT-6: xegpu.store_nd {{.*}}  : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
-  gpu.func @test_store_nd(%src: memref<24x32xf32>) {
+  gpu.func @store_nd(%src: memref<24x32xf32>) {
     %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
     %data = arith.constant dense<9.0> : vector<24x32xf32>
     xegpu.store_nd %data, %tdesc: vector<24x32xf32>, !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
@@ -116,11 +116,11 @@ gpu.module @test {
 
   //-----
 
-  // CHECK-LABEL: test_store_nd_1d
+  // CHECK-LABEL: store_nd_1d
   // CHECK-SAME: [[arg0:%.+]]: memref<64xf32>
   // CHECK-COUNT-4: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<64xf32> -> !xegpu.tensor_desc<16xf32>
   // CHECK-COUNT-4: xegpu.store_nd {{.*}}  : vector<16xf32>, !xegpu.tensor_desc<16xf32>
-  gpu.func @test_store_nd_1d(%src: memref<64xf32>) {
+  gpu.func @store_nd_1d(%src: memref<64xf32>) {
     %tdesc = xegpu.create_nd_tdesc %src[0] : memref<64xf32> -> !xegpu.tensor_desc<64xf32, #xegpu.layout<inst_data = [16]>>
     %data = arith.constant dense<9.0> : vector<64xf32>
     xegpu.store_nd %data, %tdesc: vector<64xf32>, !xegpu.tensor_desc<64xf32, #xegpu.layout<inst_data = [16]>>
@@ -129,7 +129,7 @@ gpu.module @test {
 
   //-----
 
-  // CHECK-LABEL: test_createNd_loadNd_storeNd
+  // CHECK-LABEL: createNd_loadNd_storeNd
   // CHECK-SAME: [[arg0:%.+]]: memref<24x32xf32>
   //CHECK-COUNT-6: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
   //CHECK-COUNT-6: [[data:%.+]] = xegpu.load_nd {{.*}}  : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
@@ -137,7 +137,7 @@ gpu.module @test {
   //CHECK: [[add:%.+]] = arith.addf {{.*}} : vector<24x32xf32>
   //CHECK-COUNT-6: [[extract:%.+]] = vector.extract_strided_slice {{.*}} : vector<24x32xf32> to vector<8x16xf32>
   //CHECK-COUNT-6: xegpu.store_nd {{.*}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
-  gpu.func @test_createNd_loadNd_storeNd(%src: memref<24x32xf32>) {
+  gpu.func @createNd_loadNd_storeNd(%src: memref<24x32xf32>) {
     %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
     %data = arith.constant dense<9.0> : vector<24x32xf32>
     %ld = xegpu.load_nd %tdesc: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>> -> vector<24x32xf32>
@@ -148,23 +148,23 @@ gpu.module @test {
 
   //-----
 
-  // CHECK-LABEL: test_dpas
+  // CHECK-LABEL: dpas
   // CHECK-SAME: [[arg0:%.+]]: vector<32x32xf16>, [[arg1:%.+]]: vector<32x32xf16>
   //CHECK-COUNT-8: [[extract1:%.+]] = vector.extract_strided_slice [[arg0]] {{.*}} : vector<32x32xf16> to vector<8x16xf16>
   //CHECK-COUNT-4: [[extract2:%.+]] = vector.extract_strided_slice [[arg1]] {{.*}} : vector<32x32xf16> to vector<16x16xf16>
   //CHECK-COUNT-16: [[dpas:%.+]] = xegpu.dpas {{.*}} -> vector<8x16xf32>
   //CHECK-COUNT-8: [[insert:%.+]] = vector.insert_strided_slice {{.*}} : vector<8x16xf32> into vector<32x32xf32>
-  gpu.func @test_dpas(%a: vector<32x32xf16>, %b: vector<32x32xf16>) -> vector<32x32xf32> {
+  gpu.func @dpas(%a: vector<32x32xf16>, %b: vector<32x32xf16>) -> vector<32x32xf32> {
     %c = xegpu.dpas %a, %b : vector<32x32xf16>, vector<32x32xf16> -> vector<32x32xf32>
     gpu.return %c : vector<32x32xf32>
   }
 
 //-----
 
-  // CHECK-LABEL: test_create_tdesc_vec
+  // CHECK-LABEL: create_tdesc_vec
   // CHECK-SAME: [[arg0:%.+]]: ui64
   // CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
-  gpu.func @test_create_tdesc_vec(%src: ui64) -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>> {
+  gpu.func @create_tdesc_vec(%src: ui64) -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>> {
     %cst = arith.constant dense<[
     0,   8,  16,  24,  32,  40,  48,  56,
     64,  72,  80,  88,  96, 104, 112, 120,
@@ -177,10 +177,10 @@ gpu.module @test {
 
 //-----
 
-  // CHECK-LABEL: test_create_tdesc_step
+  // CHECK-LABEL: create_tdesc_step
   // CHECK-SAME: [[arg0:%.+]]: ui64
   // CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
-  gpu.func @test_create_tdesc_step(%src: ui64) -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>> {
+  gpu.func @create_tdesc_step(%src: ui64) -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>> {
     %step = arith.constant dense<8> : vector<32xindex>
     %seq = vector.step  : vector<32xindex>
     %cst = arith.muli %seq, %step : vector<32xindex>
@@ -190,11 +190,11 @@ gpu.module @test {
 
 //-----
 
-  // CHECK-LABEL: test_load
+  // CHECK-LABEL: load
   // CHECK-SAME: [[arg0:%.+]]: ui64
   // CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
   // CHECK-COUNT-2: xegpu.load  {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> -> vector<16xf32>
-  gpu.func @test_load(%src: ui64) -> vector<32xf32> {
+  gpu.func @load(%src: ui64) -> vector<32xf32> {
     %cst = arith.constant dense<[
     0,   8,  16,  24,  32,  40,  48,  56,
     64,  72,  80,  88,  96, 104, 112, 120,
@@ -212,11 +212,11 @@ gpu.module @test {
 
 //-----
 
-  // CHECK-LABEL: test_prefetch
+  // CHECK-LABEL: prefetch
   // CHECK-SAME: [[arg0:%.+]]: ui64
   // CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
   // CHECK-COUNT-2: xegpu.prefetch {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
-  gpu.func @test_prefetch(%src: ui64)  {
+  gpu.func @prefetch(%src: ui64)  {
 
     %cst = arith.constant dense<[
     0,   8,  16,  24,  32,  40,  48,  56,
@@ -233,11 +233,11 @@ gpu.module @test {
 
 //-----
 
-  // CHECK-LABEL: test_store
+  // CHECK-LABEL: store
   // CHECK-SAME: [[arg0:%.+]]: ui64
   // CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
   // CHECK-COUNT-2: xegpu.store  {{.*}} : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1>
-  gpu.func @test_store(%src: ui64) {
+  gpu.func @store(%src: ui64) {
     %cst = arith.constant dense<[
     0,   8,  16,  24,  32,  40,  48,  56,
     64,  72,  80,  88,  96, 104, 112, 120,
@@ -256,47 +256,129 @@ gpu.module @test {
   }
 
 //-----
+  // CHECK-LABEL: create_tdesc_step_chunk
+  // CHECK-SAME: [[arg0:%.+]]: ui64
+  // CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x4xf32, #xegpu.scatter_tdesc_attr<chunk_size = 4 : i64>>
+  gpu.func @create_tdesc_step_chunk(%src: ui64) -> !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 4]>> {
+    %step = arith.constant dense<8> : vector<32xindex>
+    %seq = vector.step  : vector<32xindex>
+    %cst = arith.muli %seq, %step : vector<32xindex>
+    %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32x4xf32,  #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 4]>>
+    gpu.return %tdesc : !xegpu.tensor_desc<32x4xf32,  #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 4]>>
+  }
 
-  // CHECK-LABEL: test_prefetch_load_store_update
+//-----
+  // CHECK-LABEL: create_tdesc_step_chunk2
   // CHECK-SAME: [[arg0:%.+]]: ui64
-  // CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
-  // CHECK-COUNT-2: xegpu.prefetch {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
-   // CHECK-COUNT-2: xegpu.update_offset {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xindex>
-   // CHECK-COUNT-2: xegpu.load  {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> -> vector<16xf32>
-  // CHECK-COUNT-2: xegpu.store  {{.*}} : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1>
+  // CHECK-COUNT-4: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
+  gpu.func @create_tdesc_step_chunk2(%src: ui64) -> !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>> {
+    %step = arith.constant dense<8> : vector<32xindex>
+    %seq = vector.step  : vector<32xindex>
+    %cst = arith.muli %seq, %step : vector<32xindex>
+    %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32x4xf32,  #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>
+    gpu.return %tdesc : !xegpu.tensor_desc<32x4xf32,  #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>
+  }
 
-  gpu.func @test_prefetch_load_store_update(%src: ui64)  {
+  // CHECK-LABEL: create_tdesc_step_chunk3
+  // CHECK-SAME: [[arg0:%.+]]: ui64
+  // CHECK: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
+  // CHECK: arith.addi %{{.*}}, %{{.*}} : vector<16xindex>
+  // CHECK: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
+  // CHECK: arith.addi %{{.*}}, %{{.*}} : vector<16xindex>
+  // CHECK: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
+  // CHECK: arith.addi %{{.*}}, %{{.*}} : vector<16xindex>
+  // CHECK: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
+  gpu.func @create_tdesc_step_chunk3(%src: ui64) -> !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size=8>, #xegpu.layout<inst_data = [16, 2]>> {
+    %step = arith.constant dense<8> : vector<16xindex>
+    %seq = vector.step  : vector<16xindex>
+    %cst = arith.muli %seq, %step : vector<16xindex>
+    %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x8xf32,  #xegpu.scatter_tdesc_attr<chunk_size=8>, #xegpu.layout<inst_data = [16, 2]>>
+    gpu.return %tdesc : !xegpu.tensor_desc<16x8xf32,  #xegpu.scatter_tdesc_attr<chunk_size=8>, #xegpu.layout<inst_data = [16, 2]>>
+  }
+
+//-----
+  // CHECK-LABEL: load_chunk
+  // CHECK-SAME: [[arg0:%.+]]: ui64
+  // CHECK-COUNT-4: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
+  // CHECK-COUNT-4: xegpu.load  {{.*}} <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>, vector<16xi1> -> vector<2x16xf32>
 
+  gpu.func @load_chunk(%src: ui64) -> vector<4x32xf32> {
     %cst = arith.constant dense<[
-    0,   8,  16,  24,  32,  40,  48,  56,
-    64,  72,  80,  88,  96, 104, 112, 120,
-    128, 136, 144, 152, 160, 168, 176, 184,
-    192, 200, 208, 216, 224, 232, 240, 248 
+        0,   8,  16,  24,  32,  40,  48,  56,
+        64,  72,  80,  88,  96, 104, 112, 120,
+        128, 136, 144, 152, 160, 168, 176, 184,
+        192, 200, 208, 216, 224, 232, 240, 248 
     ]> : vector<32xindex>
+    
+    %c17 = arith.constant 17: index
+    %mask = vector.create_mask %c17: vector<32xi1>
 
-    %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32xf32,  #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
-    xegpu.prefetch %tdesc: !xegpu.tensor_desc<32xf32,  #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
-   
-    %delta = arith.constant dense<[
-    32,   32,  32,  32,  32,  32,  32,  32,
-    32,   32,  32,  32,  32,  32,  32,  64,
-    128, 128, 128, 128, 128, 128, 128, 128,
-    128, 128, 128, 128, 128, 128, 128, 256 
+    %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32x4xf32,  #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>> 
+    %ld = xegpu.load %tdesc, %mask <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}>: !xegpu.tensor_desc<32x4xf32,  #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>, vector<32xi1> -> vector<4x32xf32>
+    
+    gpu.return %ld : vector<4x32xf32> 
+   }
+
+//-----
+  // CHECK-LABEL: store_chunk
+  // CHECK-SAME: [[arg0:%.+]]: ui64
+  // CHECK-COUNT-4: xegpu.create_tdesc [[arg0]], {{.*}} :  ui64, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
+  // CHECK-COUNT-4: xegpu.store  {{.*}} <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}> : vector<2x16xf32>, !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>, vector<16xi1>
+  gpu.func @store_chunk(%src: ui64) {
+    %cst = arith.constant dense<[
+      0,   8,  16,  24,  32,  40,  48,  56,
+      64,  72,  80,  88,  96, 104, 112, 120,
+      128, 136, 144, 152, 160, 168, 176, 184,
+      192, 200, 208, 216, 224, 232, 240, 248 
     ]> : vector<32xindex>
-    %new_tdesc = xegpu.update_offset %tdesc, %delta
-              : !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>, vector<32xindex>     
- 
+    
     %c17 = arith.constant 17: index
     %mask = vector.create_mask %c17: vector<32xi1>
 
-    %ld_vec = xegpu.load %new_tdesc, %mask: !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>, vector<32xi1> -> vector<32xf32>
+    %st_vec = arith.constant dense<1023.>: vector<4x32xf32>
+    %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32x4xf32,  #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>
+    xegpu.store %st_vec, %tdesc, %mask <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}>: vector<4x32xf32>, !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16,2]>>, vector<32xi1>
+    
+    gpu.return
+  }
 
-    %st_vec = arith.addf %ld_vec, %ld_vec : vector<32xf32>
-    xegpu.store %st_vec, %tdesc, %mask: 
-                 vector<32xf32>, 
-                 !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>, 
-                 vector<32xi1>
-  
+//-----
+  // CHECK-LABEL: prefetch_chunk
+  // CHECK-SAME: [[arg0:%.+]]: ui64
+  // CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
+  // CHECK-COUNT-2: xegpu.prefetch {{.*}} : !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
+  gpu.func @prefetch_chunk(%src: ui64)  {
+    %cst = arith.constant dense<[
+      0,   8,  16,  24,  32,  40,  48,  56,
+      64,  72,  80,  88,  96, 104, 112, 120,
+      128, 136, 144, 152, 160, 168, 176, 184,
+      192, 200, 208, 216, 224, 232, 240, 248 
+      ]> : vector<32xindex>
+    %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32x4xf32,  #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>
+    xegpu.prefetch %tdesc: !xegpu.tensor_desc<32x4xf32,  #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>
+    
     gpu.return
   }
+
+//-----
+  // CHECK-LABEL: update_chunk
+  // CHECK-SAME: [[arg0:%.+]]: ui64
+  // CHECK-COUNT-4: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
+  // CHECK-COUNT-4: xegpu.update_offset {{.*}} : !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>, vector<16xindex>
+  gpu.func @update_chunk(%src: ui64) -> !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>> {
+    %cst = arith.constant dense<[
+      0,   8,  16,  24,  32,  40,  48,  56,
+      64,  72,  80,  88,  96, 104, 112, 120,
+      128, 136, 144, 152, 160, 168, 176, 184,
+      192, 200, 208, 216, 224, 232, 240, 248 
+    ]> : vector<32xindex>
+    %delta = arith.constant dense<32>: vector<32xindex>
+    %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32x4xf32,  #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>
+
+    %new_tdesc = xegpu.update_offset %tdesc, %delta
+        : !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>, vector<32xindex>
+
+    gpu.return %new_tdesc : !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>
+  }  
 }
+

diff --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
index 57aaecbd7962f..4400d6d9625f7 100644
--- a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
+++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
@@ -19,6 +19,10 @@ using namespace mlir::xegpu;
 
 namespace {
 
+#define DEBUG_TYPE "test-xegpu-unroll"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
+
 struct TestXeGPUUnrollingPatterns
     : public PassWrapper<TestXeGPUUnrollingPatterns,
                          OperationPass<gpu::GPUModuleOp>> {
@@ -48,7 +52,9 @@ struct TestXeGPUUnrollingPatterns
     options.setNativeShapeFn(
         [&](Operation *op) -> std::optional<SmallVector<int64_t>> {
           if (isa<xegpu::CreateNdDescOp, xegpu::UpdateNdOffsetOp,
-                  xegpu::PrefetchNdOp, xegpu::LoadNdOp, xegpu::StoreNdOp>(op)) {
+                  xegpu::PrefetchNdOp, xegpu::LoadNdOp, xegpu::StoreNdOp,
+                  xegpu::CreateDescOp, xegpu::UpdateOffsetOp, xegpu::PrefetchOp,
+                  xegpu::LoadGatherOp, xegpu::StoreScatterOp>(op)) {
             xegpu::TensorDescType tdescTy;
             if (auto createNdOp = dyn_cast<xegpu::CreateNdDescOp>(op)) {
               tdescTy = createNdOp.getType();
@@ -61,20 +67,7 @@ struct TestXeGPUUnrollingPatterns
               tdescTy = loadNdOp.getTensorDescType();
             } else if (auto storeNdOp = dyn_cast<xegpu::StoreNdOp>(op)) {
               tdescTy = storeNdOp.getTensorDescType();
-            }
-
-            if (auto layout = tdescTy.getLayoutAttr()) {
-              auto inst_data = layout.getInstData();
-              if (inst_data && layout.isSgLayout())
-                return SmallVector<int64_t>(inst_data.asArrayRef().begin(),
-                                            inst_data.asArrayRef().end());
-            }
-          }
-
-          if (isa<xegpu::CreateDescOp, xegpu::UpdateOffsetOp, xegpu::PrefetchOp,
-                  xegpu::LoadGatherOp, xegpu::StoreScatterOp>(op)) {
-            xegpu::TensorDescType tdescTy;
-            if (auto createOp = dyn_cast<xegpu::CreateDescOp>(op)) {
+            } else if (auto createOp = dyn_cast<xegpu::CreateDescOp>(op)) {
               tdescTy = createOp.getType();
             } else if (auto updateOp = dyn_cast<xegpu::UpdateOffsetOp>(op)) {
               tdescTy = updateOp.getTensorDescType();
@@ -111,14 +104,40 @@ struct TestXeGPUUnrollingPatterns
             Attribute encoding = tdescTy.getEncoding();
             auto layout = llvm::dyn_cast_if_present<xegpu::LayoutAttr>(
                 tdescTy.getLayout());
+
+            // If the encoding is a ScatterTensorDescAttr, we need to
+            // potentially adjust the chunk size based on the inst_data.
+            if (encoding && mlir::isa<xegpu::ScatterTensorDescAttr>(encoding)) {
+              auto scatterAttr =
+                  mlir::dyn_cast<xegpu::ScatterTensorDescAttr>(encoding);
+              int64_t chunkSize = scatterAttr.getChunkSize().getInt();
+
+              if (chunkSize > 1) {
+                int64_t blockedChunkSize = chunkSize;
+                auto instData = layout.getInstData();
+                if (!instData.empty())
+                  blockedChunkSize = instData.asArrayRef().back();
+
+                auto chunkSizeAttr = mlir::IntegerAttr::get(
+                    mlir::IntegerType::get(ctx, 64), blockedChunkSize);
+
+                // To create a new attribute with a different chunk_size:
+                auto newEncoding = xegpu::ScatterTensorDescAttr::get(
+                    ctx, scatterAttr.getMemorySpace(), chunkSizeAttr);
+
+                encoding = newEncoding;
+              }
+            }
             if (layout) {
               if (layout.getLaneLayout() == nullptr)
                 layout = xegpu::LayoutAttr();
               else
                 layout = layout.dropInstData();
             }
+
             newTy = xegpu::TensorDescType::get(ctx, tileShape, elemTy, encoding,
                                                layout);
+
           } else {
             newTy = type.clone(tileShape, elemTy);
           }
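
As a concrete reading of the chunk-size adjustment above: when the scatter
encoding has chunk_size > 1, the unrolled tile type gets its chunk_size
clamped to the innermost inst_data entry and the inst_data layout dropped.
The types below are taken from the create_tdesc_step_chunk3 test and are
shown only to make the transformation concrete:

    // Original descriptor type: 8-element chunks, blocked to inst_data = [16, 2].
    !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size = 8>,
                       #xegpu.layout<inst_data = [16, 2]>>

    // Unrolled tile type produced by the pattern:
    !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>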


        

