[Mlir-commits] [mlir] [MLIR][XeGPU] Add unroll patterns for scatter ops (PR #143602)

Chao Chen llvmlistbot at llvm.org
Mon Jun 16 08:32:30 PDT 2025


https://github.com/chencha3 updated https://github.com/llvm/llvm-project/pull/143602

From bac8bc607a0b5f9171472375e5f42c24a5ad429d Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Mon, 9 Jun 2025 14:40:41 +0000
Subject: [PATCH 01/10] add unroll support for reduce and broadcast

---
 .../XeGPU/Transforms/XeGPUBlocking.cpp        |   6 +
 mlir/test/Dialect/XeGPU/xegpu-blocking.mlir   | 110 ++++++++++++++++++
 2 files changed, 116 insertions(+)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
index 7cd998eed2e08..a3826c56e1f62 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
@@ -169,6 +169,12 @@ XeGPUBlockingPass::getTileShape(Operation *op) const {
   if (OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() == 1)
     return getTileShape(op->getOpResult(0));
 
+  if (isa<vector::MultiDimReductionOp>(op))
+    return getTileShape(op->getOpOperand(0));
+
+  if (isa<vector::TransposeOp, vector::BroadcastOp>(op))
+    return getTileShape(op->getOpResult(0));
+
   return std::nullopt;
 }
 
diff --git a/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir b/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
index f9114988686c8..8e3673d04eacb 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
@@ -246,3 +246,113 @@ gpu.module @test_kernel {
     gpu.return
   }
 }
+
+// -----
+#l = #xegpu.layout<inst_data = [16, 16]>
+#r = #xegpu.layout<inst_data = [16]>
+
+gpu.module @kernel  attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
+  gpu.func @reduce_dim_0(%a: memref<16x512xf32>, %b: memref<512xf32>)  kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
+    %acc = arith.constant dense<0.0> : vector<64xf32>
+    %c64 = arith.constant 64 : index
+    %block_id_x = gpu.block_id x
+    %m = arith.muli %block_id_x, %c64 : index
+    %0 = xegpu.create_nd_tdesc %a[0, %m] : memref<16x512xf32> -> !xegpu.tensor_desc<16x64xf32, #l>
+    %1 = xegpu.load_nd %0: !xegpu.tensor_desc<16x64xf32, #l> -> vector<16x64xf32>
+    // CHECK: vector.multi_reduction <add>, {{.*}}, [[ACC:%[0-9A-Za-z]+]] [0] : vector<16x16xf32> to vector<16xf32>
+    // CHECK-COUNT-3: vector.multi_reduction <add>, {{.*}}, [[ACC]] [0] : vector<16x16xf32> to vector<16xf32>
+    %2 = vector.multi_reduction <add>, %1, %acc {layout_result_0 = #r} [0]: vector<16x64xf32> to vector<64xf32>
+    %3 = xegpu.create_nd_tdesc %b[%m] : memref<512xf32> -> !xegpu.tensor_desc<64xf32, #r>
+    xegpu.store_nd %2, %3: vector<64xf32>, !xegpu.tensor_desc<64xf32, #r>
+    gpu.return
+  }
+}
+
+// -----
+#l = #xegpu.layout<inst_data = [16, 16]>
+#r = #xegpu.layout<inst_data = [16]>
+
+gpu.module @kernel  attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
+  gpu.func @reduce_dim_1(%a: memref<512x32xf32>, %b: memref<512xf32>)  kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
+    %c1 = arith.constant 1 : index
+    %c32 = arith.constant 32 : index
+    %acc = arith.constant dense<0.0> : vector<32xf32>
+
+    %block_id_x = gpu.block_id x
+    %block_id_y = gpu.block_id y
+
+    %m = arith.muli %block_id_x, %c32 : index
+    %n = arith.muli %block_id_y, %c32 : index
+    %0 = xegpu.create_nd_tdesc %a[%m, %n] : memref<512x32xf32> -> !xegpu.tensor_desc<32x128xf32, #l>
+    %1 = xegpu.load_nd %0: !xegpu.tensor_desc<32x128xf32, #l> -> vector<32x128xf32>
+
+    // CHECK: vector.multi_reduction <add>, {{.*}}, [[INIT:%[0-9A-Za-z]+]] [1] : vector<16x16xf32> to vector<16xf32>
+    // CHECK-COUNT-1: vector.multi_reduction <add>, {{.*}}, [[INIT]] [1] : vector<16x16xf32> to vector<16xf32>
+
+    %2 = vector.multi_reduction <add>, %1, %acc {layout_result_0 = #r} [1]: vector<32x128xf32> to vector<32xf32>
+    %3 = xegpu.create_nd_tdesc %b[%n] : memref<512xf32> -> !xegpu.tensor_desc<32xf32, #r>
+    xegpu.store_nd %2, %3: vector<32xf32>, !xegpu.tensor_desc<32xf32, #r>
+    gpu.return
+  }
+}
+
+// -----
+#r = #xegpu.layout<inst_data = [16]>
+#l = #xegpu.layout<inst_data = [16, 16]>
+
+gpu.module @kernel  attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
+  gpu.func @broadcast_dim_0(%a: memref<512xf32>, %b: memref<16x512xf32>)  kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
+
+    %c64 = arith.constant 64 : index
+    %block_id_x = gpu.block_id x
+    %m = arith.muli %block_id_x, %c64 : index
+    %0 = xegpu.create_nd_tdesc %a[%m] : memref<512xf32> -> !xegpu.tensor_desc<64xf32, #r>
+    %1 = xegpu.load_nd %0: !xegpu.tensor_desc<64xf32, #r> -> vector<64xf32>
+    // CHECK-COUNT-4: vector.broadcast {{.*}} : vector<16xf32> to vector<16x16xf32>
+    %2 = vector.broadcast  %1 {layout_result_0 = #l} : vector<64xf32> to vector<16x64xf32>
+    %3 = xegpu.create_nd_tdesc %b[0, %m] : memref<16x512xf32> -> !xegpu.tensor_desc<16x64xf32, #l>
+    xegpu.store_nd %2, %3: vector<16x64xf32>, !xegpu.tensor_desc<16x64xf32, #l>
+    gpu.return
+  }
+}
+
+// -----
+#r = #xegpu.layout<inst_data = [16]>
+#l = #xegpu.layout<inst_data = [16, 16]>
+
+gpu.module @kernel  attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
+  gpu.func @broadcast_dim_1(%a: memref<512xf32>, %b: memref<16x512xf32>)  kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
+
+    %c32 = arith.constant 32 : index
+    %block_id_x = gpu.block_id x
+    %m = arith.muli %block_id_x, %c32 : index
+    %0 = xegpu.create_nd_tdesc %a[%m] : memref<512xf32> -> !xegpu.tensor_desc<32xf32, #r>
+    %1 = xegpu.load_nd %0: !xegpu.tensor_desc<32xf32, #r> -> vector<32xf32>
+    %11 = vector.shape_cast %1 :  vector<32xf32> to vector<32x1xf32>
+    // CHECK-COUNT-8: vector.broadcast {{.*}}: vector<16x1xf32> to vector<16x16xf32>
+    %2 = vector.broadcast  %11 {layout_result_0 = #l} : vector<32x1xf32> to vector<32x64xf32>
+    %3 = xegpu.create_nd_tdesc %b[0, %m] : memref<16x512xf32> -> !xegpu.tensor_desc<32x64xf32, #l>
+    xegpu.store_nd %2, %3: vector<32x64xf32>, !xegpu.tensor_desc<32x64xf32, #l>
+    gpu.return
+  }
+}
+
+// -----
+#l = #xegpu.layout<inst_data = [16, 8]>
+#t = #xegpu.layout<inst_data = [8, 16]>
+
+gpu.module @kernel  attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
+  gpu.func @transpose(%a: memref<512x8xf32>, %b: memref<8x512xf32>)  kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
+
+    %c32 = arith.constant 32 : index
+    %block_id_x = gpu.block_id x
+    %m = arith.muli %block_id_x, %c32 : index
+    %0 = xegpu.create_nd_tdesc %a[%m, 0] : memref<512x8xf32> -> !xegpu.tensor_desc<32x8xf32, #l>
+    %1 = xegpu.load_nd %0: !xegpu.tensor_desc<32x8xf32, #l> -> vector<32x8xf32>
+    // CHECK-COUNT-2: vector.transpose {{.*}}  [1, 0] : vector<16x8xf32> to vector<8x16xf32>
+    %2 = vector.transpose  %1, [1, 0] {layout_result_0 = #t} : vector<32x8xf32> to vector<8x32xf32>
+    %3 = xegpu.create_nd_tdesc %b[0, %m] : memref<8x512xf32> -> !xegpu.tensor_desc<8x32xf32, #t>
+    xegpu.store_nd %2, %3: vector<8x32xf32>, !xegpu.tensor_desc<8x32xf32, #t>
+    gpu.return
+  }
+}
\ No newline at end of file

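A quick way to sanity-check the CHECK-COUNT values in the new tests: when inst_data evenly divides the vector shape, the number of unrolled ops is the product over dims of shape[i] / inst_data[i]. A minimal standalone sketch of that arithmetic (plain C++; the names are illustrative and not part of the pass):

#include <cassert>
#include <cstdint>
#include <vector>

// Number of unrolled ops = product over dims of shape[i] / inst_data[i],
// assuming inst_data evenly divides the shape.
static int64_t numUnrolledOps(const std::vector<int64_t> &shape,
                              const std::vector<int64_t> &instData) {
  assert(shape.size() == instData.size());
  int64_t count = 1;
  for (size_t i = 0; i < shape.size(); ++i)
    count *= shape[i] / instData[i];
  return count;
}

int main() {
  assert(numUnrolledOps({16, 64}, {16, 16}) == 4); // broadcast_dim_0: CHECK-COUNT-4
  assert(numUnrolledOps({32, 8}, {16, 8}) == 2);   // transpose: CHECK-COUNT-2
  return 0;
}
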
From 30b099ef6d1f8d036290590c73786bd45ea51b57 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Mon, 9 Jun 2025 23:14:58 +0000
Subject: [PATCH 02/10] add create_desc unrolling and test

---
 .../Dialect/XeGPU/Transforms/XeGPUUnroll.cpp  | 41 ++++++++++++++++++-
 .../lib/Dialect/XeGPU/TestXeGPUTransforms.cpp | 24 +++++++++++
 2 files changed, 64 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index 885477fe4cbd5..672b0fb731f31 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -396,11 +396,50 @@ struct UnrollDpasOp : public UnrollPattern<xegpu::DpasOp> {
   }
 };
 
+struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
+  using UnrollPattern<xegpu::CreateDescOp>::UnrollPattern;
+  LogicalResult matchAndRewrite(xegpu::CreateDescOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    xegpu::TensorDescType tdescTy = op.getType();
+
+    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
+    if (!targetShape)
+      return failure();
+
+    auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];
+
+    
+    TypedValue<::mlir::VectorType> indiceVec = op.getOffsets();
+
+    VectorType indiceVecTy = indiceVec.getType();
+
+    SmallVector<Type> convertedIndiceTypes =
+        getUnrolledTypes(indiceVecTy, *targetShape);
+
+    SmallVector<Value> convertedIndiceVec =
+        pack(indiceVec, convertedIndiceTypes, *targetShape, loc, rewriter);
+
+    SmallVector<Value> newOps;
+
+    for (auto indice : convertedIndiceVec) {
+      auto newOp = rewriter.create<xegpu::CreateDescOp>(loc, newTdescTy, op.getSource(), indice);
+      newOps.push_back(newOp);
+    }
+
+    Value castOp = unpack(newOps, tdescTy, *targetShape, loc, rewriter);
+    rewriter.replaceOp(op, castOp);
+
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::xegpu::populateXeGPUUnrollPatterns(
     RewritePatternSet &patterns, const xegpu::UnrollOptions &options) {
   patterns.add<UnrollCreateNdOp, UnrollUpdateNdOffsetOp, UnrollPrefetchNdOp,
-               UnrollLoadNdOp, UnrollStoreNdOp, UnrollDpasOp>(
+               UnrollLoadNdOp, UnrollStoreNdOp, UnrollDpasOp, 
+               UnrollCreateDescOp>(
       patterns.getContext(), options);
 }
diff --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
index 3f3461e92bc08..abdee098ab430 100644
--- a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
+++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
@@ -71,6 +71,30 @@ struct TestXeGPUUnrollingPatterns
             }
           }
 
+          if (isa<xegpu::CreateDescOp, xegpu::UpdateOffsetOp,
+                  xegpu::PrefetchOp, xegpu::LoadGatherOp, xegpu::StoreScatterOp>(op)) {
+            xegpu::TensorDescType tdescTy;
+            if (auto createOp = dyn_cast<xegpu::CreateDescOp>(op)) {
+              tdescTy = createOp.getType();
+            } else if (auto updateOp =
+                           dyn_cast<xegpu::UpdateOffsetOp>(op)) {
+              tdescTy = updateOp.getTensorDescType();
+            } else if (auto prefetchOp = dyn_cast<xegpu::PrefetchOp>(op)) {
+              tdescTy = prefetchOp.getTensorDescType();
+            } else if (auto loadOp = dyn_cast<xegpu::LoadGatherOp>(op)) {
+              tdescTy = loadOp.getTensorDescType();
+            } else if (auto storeOp = dyn_cast<xegpu::StoreScatterOp>(op)) {
+              tdescTy = storeOp.getTensorDescType();
+            }
+
+            if (auto layout = tdescTy.getLayoutAttr()) {
+              auto inst_data = layout.getInstData();
+              if (inst_data && layout.isSgLayout())
+                return SmallVector<int64_t>(inst_data.asArrayRef().begin(),
+                                            inst_data.asArrayRef().end());
+            }
+          }
+
           if (isa<xegpu::DpasOp>(op))
             return SmallVector<int64_t>{8, 16, 16};
 

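For readers unfamiliar with the unroll helpers used above: the pattern relies on pack() to split one SSA value into per-tile values and on unpack() to stitch the per-tile results back into a value of the original type, with the target tile shape taken from the layout's inst_data. A rough stand-in in plain C++ (the real helpers operate on MLIR Values; this only models the 1-D index split UnrollCreateDescOp performs, e.g. vector<32xindex> into two vector<16xindex>):

#include <cassert>
#include <cstdint>
#include <vector>

// Stand-in for pack(): split a 1-D value into equally sized tiles.
static std::vector<std::vector<int64_t>> pack(const std::vector<int64_t> &v,
                                              size_t tile) {
  std::vector<std::vector<int64_t>> tiles;
  for (size_t i = 0; i + tile <= v.size(); i += tile)
    tiles.emplace_back(v.begin() + i, v.begin() + i + tile);
  return tiles;
}

// Stand-in for unpack(): reassemble per-tile results into one value.
static std::vector<int64_t>
unpack(const std::vector<std::vector<int64_t>> &tiles) {
  std::vector<int64_t> out;
  for (const auto &t : tiles)
    out.insert(out.end(), t.begin(), t.end());
  return out;
}

int main() {
  std::vector<int64_t> offsets(32);
  for (int i = 0; i < 32; ++i)
    offsets[i] = i * 8; // matches the dense<[0, 8, 16, ...]> constant in the tests

  auto tiles = pack(offsets, 16);   // one CreateDescOp is built per tile
  assert(tiles.size() == 2);        // vector<32xindex> -> 2 x vector<16xindex>
  assert(unpack(tiles) == offsets); // the round-trip preserves the original value
  return 0;
}
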
From 8a0e1455683d2a04744bb653348e39c353060704 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Tue, 10 Jun 2025 19:49:39 +0000
Subject: [PATCH 03/10] add unrolling support for scatter operations

---
 .../Dialect/XeGPU/Transforms/XeGPUUnroll.cpp  | 159 +++++++++++++++++-
 .../Dialect/XeGPU/xegpu-unroll-patterns.mlir  | 143 ++++++++++++++++
 2 files changed, 296 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index 672b0fb731f31..96b86f6509419 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -409,19 +409,14 @@ struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
 
     auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];
 
-    
     TypedValue<::mlir::VectorType> indiceVec = op.getOffsets();
-
     VectorType indiceVecTy = indiceVec.getType();
-
     SmallVector<Type> convertedIndiceTypes =
         getUnrolledTypes(indiceVecTy, *targetShape);
-
     SmallVector<Value> convertedIndiceVec =
         pack(indiceVec, convertedIndiceTypes, *targetShape, loc, rewriter);
 
     SmallVector<Value> newOps;
-
     for (auto indice : convertedIndiceVec) {
       auto newOp = rewriter.create<xegpu::CreateDescOp>(loc, newTdescTy, op.getSource(), indice);
       newOps.push_back(newOp);
@@ -434,12 +429,164 @@ struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
   }
 };
 
+struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
+  using UnrollPattern<xegpu::LoadGatherOp>::UnrollPattern;
+  LogicalResult matchAndRewrite(xegpu::LoadGatherOp op,
+                                PatternRewriter &rewriter) const override {
+
+    Location loc = op.getLoc();
+    VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
+    xegpu::TensorDescType tdescTy = op.getTensorDescType();
+
+    VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());
+
+    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
+    if (!targetShape)
+      return failure();
+
+    Type elemTy = tdescTy.getElementType();
+    VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);
+
+    SmallVector<Type> convertedTdescTypes =
+        getUnrolledTypes(tdescTy, *targetShape);
+    SmallVector<Value> convertedTdescs = pack(
+        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
+
+    SmallVector<Type> convertedMaskTypes =
+        getUnrolledTypes(maskTy, *targetShape);
+    SmallVector<Value> convertedMasks = pack(
+        op.getMask(), convertedMaskTypes, *targetShape, loc, rewriter);
+
+    SmallVector<Value> newOps;
+    for (auto [t, m] : llvm::zip(convertedTdescs, convertedMasks)) {
+      auto newOp =
+          rewriter.create<xegpu::LoadGatherOp>(loc, newValueTy, t, m,  
+            op.getTransposeAttr(), op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr());
+      newOps.push_back(newOp);
+    }
+
+    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
+
+    rewriter.replaceOp(op, castOp);
+    return success();
+  }
+};
+
+struct UnrollPrefetchOp : public UnrollPattern<xegpu::PrefetchOp> {
+  using UnrollPattern<xegpu::PrefetchOp>::UnrollPattern;
+  LogicalResult matchAndRewrite(xegpu::PrefetchOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    xegpu::TensorDescType tdescTy = op.getTensorDescType();
+
+    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
+    if (!targetShape)
+      return failure();
+
+    SmallVector<Type> convertedTdescTypes =
+        getUnrolledTypes(tdescTy, *targetShape);
+    SmallVector<Value> convertedTdesc = pack(
+        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
+
+    for (auto t : convertedTdesc)
+      rewriter.create<xegpu::PrefetchOp>(loc, TypeRange(), t, op->getAttrs());
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+struct UnrollStoreScatterOp : public UnrollPattern<xegpu::StoreScatterOp> {
+  using UnrollPattern<xegpu::StoreScatterOp>::UnrollPattern;
+  LogicalResult matchAndRewrite(xegpu::StoreScatterOp op,
+                                PatternRewriter &rewriter) const override {
+
+    Location loc = op.getLoc();
+    VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
+    xegpu::TensorDescType tdescTy = op.getTensorDescType();
+
+    VectorType maskTy;
+    if (op.getMask())
+      maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());
+
+    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
+    if (!targetShape)
+      return failure();
+
+    SmallVector<Type> convertedValTypes =
+        getUnrolledTypes(valueTy, *targetShape);
+    SmallVector<Type> convertedTdescTypes =
+        getUnrolledTypes(tdescTy, *targetShape);
+
+    SmallVector<Value> convertedValues =
+        pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
+    SmallVector<Value> convertedTdescs =
+        pack(op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
+
+    SmallVector<Value> convertedMasks;
+    if (op.getMask()) {
+      SmallVector<Type> convertedMaskTypes =
+          getUnrolledTypes(maskTy, *targetShape);
+      convertedMasks =
+          pack(op.getMask(), convertedMaskTypes, *targetShape, loc, rewriter);
+    }
+
+    for (size_t i = 0; i < convertedValues.size(); ++i) {
+      Value v = convertedValues[i];
+      Value t = convertedTdescs[i];
+      Value m = op.getMask() ? convertedMasks[i] : nullptr;
+      rewriter.create<xegpu::StoreScatterOp>(
+          loc, v, t, m, op.getTransposeAttr(), op.getL1HintAttr(),
+          op.getL2HintAttr(), op.getL3HintAttr());
+    }
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+struct UnrollUpdateOffsetOp : public UnrollPattern<xegpu::UpdateOffsetOp> {
+  using UnrollPattern<xegpu::UpdateOffsetOp>::UnrollPattern;
+  LogicalResult matchAndRewrite(xegpu::UpdateOffsetOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    xegpu::TensorDescType tdescTy = op.getTensorDescType();
+
+    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
+    if (!targetShape)
+      return failure();
+
+    SmallVector<Type> convertedTdescTypes =
+        getUnrolledTypes(tdescTy, *targetShape);
+    SmallVector<Value> convertedTdesc = pack(
+        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
+
+    TypedValue<::mlir::VectorType> offsetVec = op.getOffsets();
+    VectorType offsetVecTy = offsetVec.getType();
+    SmallVector<Type> convertedOffsetTypes =
+        getUnrolledTypes(offsetVecTy, *targetShape);
+    SmallVector<Value> convertedOffsetVec =
+        pack(offsetVec, convertedOffsetTypes, *targetShape, loc, rewriter);        
+
+    SmallVector<Value> newOps;
+    for (auto [t, o] : llvm::zip(convertedTdesc, convertedOffsetVec)) {
+      auto newOp = rewriter.create<xegpu::UpdateOffsetOp>(
+          loc, t.getType(), t, o);
+      newOps.push_back(newOp);
+    }
+    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
+    rewriter.replaceOp(op, castOp);
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::xegpu::populateXeGPUUnrollPatterns(
     RewritePatternSet &patterns, const xegpu::UnrollOptions &options) {
   patterns.add<UnrollCreateNdOp, UnrollUpdateNdOffsetOp, UnrollPrefetchNdOp,
                UnrollLoadNdOp, UnrollStoreNdOp, UnrollDpasOp, 
-               UnrollCreateDescOp>(
+               UnrollCreateDescOp, UnrollLoadGatherOp, 
+               UnrollStoreScatterOp, UnrollPrefetchOp, UnrollUpdateOffsetOp>(
       patterns.getContext(), options);
 }
diff --git a/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir b/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
index b911bb3bbdc1c..47c54bfcb89d0 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
@@ -158,4 +158,147 @@ gpu.module @test {
     %c = xegpu.dpas %a, %b : vector<32x32xf16>, vector<32x32xf16> -> vector<32x32xf32>
     gpu.return %c : vector<32x32xf32>
   }
+
+//-----
+
+  // CHECK-LABEL: test_create_tdesc_vec
+  // CHECK-SAME: [[arg0:%.+]]: ui64
+  // CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
+  gpu.func @test_create_tdesc_vec(%src: ui64) -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>> {
+    %cst = arith.constant dense<[
+    0,   8,  16,  24,  32,  40,  48,  56,
+    64,  72,  80,  88,  96, 104, 112, 120,
+    128, 136, 144, 152, 160, 168, 176, 184,
+    192, 200, 208, 216, 224, 232, 240, 248 
+    ]> : vector<32xindex>
+    %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32xf32,  #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
+    gpu.return %tdesc : !xegpu.tensor_desc<32xf32,  #xegpu.scatter_tdesc_attr<>,  #xegpu.layout<inst_data = [16]>>
+  }
+
+//-----
+
+  // CHECK-LABEL: test_create_tdesc_step
+  // CHECK-SAME: [[arg0:%.+]]: ui64
+  // CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
+  gpu.func @test_create_tdesc_step(%src: ui64) -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>> {
+    %step = arith.constant dense<8> : vector<32xindex>
+    %seq = vector.step  : vector<32xindex>
+    %cst = arith.muli %seq, %step : vector<32xindex>
+    %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32xf32,  #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
+    gpu.return %tdesc : !xegpu.tensor_desc<32xf32,  #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
+  }
+
+//-----
+
+    // CHECK-LABEL: test_load
+    // CHECK-SAME: [[arg0:%.+]]: ui64
+    // CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
+    // CHECK-COUNT-2: xegpu.load  {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> -> vector<16xf32>
+    gpu.func @test_load(%src: ui64) -> vector<32xf32> {
+      %cst = arith.constant dense<[
+      0,   8,  16,  24,  32,  40,  48,  56,
+      64,  72,  80,  88,  96, 104, 112, 120,
+      128, 136, 144, 152, 160, 168, 176, 184,
+      192, 200, 208, 216, 224, 232, 240, 248 
+      ]> : vector<32xindex>
+      
+      %c17 = arith.constant 17: index
+      %mask = vector.create_mask %c17: vector<32xi1>
+
+      %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32xf32,  #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
+      %ld = xegpu.load %tdesc, %mask: !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>, vector<32xi1> -> vector<32xf32>
+      
+      gpu.return %ld : vector<32xf32> 
+    }
+
+//-----
+
+  // CHECK-LABEL: test_prefetch
+  // CHECK-SAME: [[arg0:%.+]]: ui64
+  // CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
+  // CHECK-COUNT-2: xegpu.prefetch {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
+  gpu.func @test_prefetch(%src: ui64)  {
+
+      %cst = arith.constant dense<[
+      0,   8,  16,  24,  32,  40,  48,  56,
+      64,  72,  80,  88,  96, 104, 112, 120,
+      128, 136, 144, 152, 160, 168, 176, 184,
+      192, 200, 208, 216, 224, 232, 240, 248 
+      ]> : vector<32xindex>
+
+      %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32xf32,  #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
+
+      xegpu.prefetch %tdesc: !xegpu.tensor_desc<32xf32,  #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
+      gpu.return
+  }
+
+//-----
+
+  // CHECK-LABEL: test_store
+  // CHECK-SAME: [[arg0:%.+]]: ui64
+  // CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
+  // CHECK-COUNT-2: xegpu.store  {{.*}} : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1>
+  gpu.func @test_store(%src: ui64) {
+    %cst = arith.constant dense<[
+    0,   8,  16,  24,  32,  40,  48,  56,
+    64,  72,  80,  88,  96, 104, 112, 120,
+    128, 136, 144, 152, 160, 168, 176, 184,
+    192, 200, 208, 216, 224, 232, 240, 248 
+    ]> : vector<32xindex>
+    
+    %c17 = arith.constant 17: index
+    %mask = vector.create_mask %c17: vector<32xi1>
+
+    %st_vec = arith.constant dense<1023.>: vector<32xf32>
+    %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32xf32,  #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
+    xegpu.store %st_vec, %tdesc, %mask: vector<32xf32>, !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>, vector<32xi1>
+    
+    gpu.return
+  }
+
+//-----
+
+  // CHECK-LABEL: test_prefetch_load_store_update
+  // CHECK-SAME: [[arg0:%.+]]: ui64
+  // CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
+  // CHECK-COUNT-2: xegpu.prefetch {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
+   // CHECK-COUNT-2: xegpu.update_offset {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xindex>
+   // CHECK-COUNT-2: xegpu.load  {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> -> vector<16xf32>
+  // CHECK-COUNT-2: xegpu.store  {{.*}} : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1>
+
+  gpu.func @test_prefetch_load_store_update(%src: ui64)  {
+
+      %cst = arith.constant dense<[
+      0,   8,  16,  24,  32,  40,  48,  56,
+      64,  72,  80,  88,  96, 104, 112, 120,
+      128, 136, 144, 152, 160, 168, 176, 184,
+      192, 200, 208, 216, 224, 232, 240, 248 
+      ]> : vector<32xindex>
+
+      %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32xf32,  #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
+
+      xegpu.prefetch %tdesc: !xegpu.tensor_desc<32xf32,  #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
+   
+      %delta = arith.constant dense<[
+      32,   32,  32,  32,  32,  32,  32,  32,
+      32,   32,  32,  32,  32,  32,  32,  64,
+      128, 128, 128, 128, 128, 128, 128, 128,
+      128, 128, 128, 128, 128, 128, 128, 256 
+      ]> : vector<32xindex>
+      %new_tdesc = xegpu.update_offset %tdesc, %delta
+           : !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>, vector<32xindex>     
+ 
+      %c17 = arith.constant 17: index
+      %mask = vector.create_mask %c17: vector<32xi1>
+
+      %ld_vec = xegpu.load %new_tdesc, %mask: !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>, vector<32xi1> -> vector<32xf32>
+
+      %st_vec = arith.addf %ld_vec, %ld_vec : vector<32xf32>
+      xegpu.store %st_vec, %tdesc, %mask: 
+                 vector<32xf32>, 
+                 !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>, 
+                 vector<32xi1>
+  
+      gpu.return
+    }
 }

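One detail worth spelling out from the gather/scatter tests above: the masks split right along with the tensor descriptors, so each unrolled load/store sees a plain per-tile mask. A small worked example, assuming the same inst_data = [16] tiling as the tests (plain C++, not pass code):

#include <array>
#include <cassert>

int main() {
  // vector.create_mask %c17 : vector<32xi1> sets lanes [0, 17) to true.
  std::array<bool, 32> mask{};
  for (int i = 0; i < 17; ++i)
    mask[i] = true;

  // After unrolling with inst_data = [16], each xegpu.load / xegpu.store
  // receives its own vector<16xi1> slice of the original mask.
  int activeLo = 0, activeHi = 0;
  for (int i = 0; i < 16; ++i)
    activeLo += mask[i];
  for (int i = 16; i < 32; ++i)
    activeHi += mask[i];

  assert(activeLo == 16); // first tile: fully active
  assert(activeHi == 1);  // second tile: only lane 16 active
  return 0;
}
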
From c91156c7fc789548676f367919ddd577ba06c5ed Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Tue, 10 Jun 2025 19:51:32 +0000
Subject: [PATCH 04/10] clang format

---
 .../Dialect/XeGPU/Transforms/XeGPUUnroll.cpp  | 31 ++++++++++---------
 .../lib/Dialect/XeGPU/TestXeGPUTransforms.cpp |  7 ++---
 2 files changed, 19 insertions(+), 19 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index 96b86f6509419..900ade8c171d5 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -418,7 +418,8 @@ struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
 
     SmallVector<Value> newOps;
     for (auto indice : convertedIndiceVec) {
-      auto newOp = rewriter.create<xegpu::CreateDescOp>(loc, newTdescTy, op.getSource(), indice);
+      auto newOp = rewriter.create<xegpu::CreateDescOp>(loc, newTdescTy,
+                                                        op.getSource(), indice);
       newOps.push_back(newOp);
     }
 
@@ -454,14 +455,14 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
 
     SmallVector<Type> convertedMaskTypes =
         getUnrolledTypes(maskTy, *targetShape);
-    SmallVector<Value> convertedMasks = pack(
-        op.getMask(), convertedMaskTypes, *targetShape, loc, rewriter);
+    SmallVector<Value> convertedMasks =
+        pack(op.getMask(), convertedMaskTypes, *targetShape, loc, rewriter);
 
     SmallVector<Value> newOps;
     for (auto [t, m] : llvm::zip(convertedTdescs, convertedMasks)) {
-      auto newOp =
-          rewriter.create<xegpu::LoadGatherOp>(loc, newValueTy, t, m,  
-            op.getTransposeAttr(), op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr());
+      auto newOp = rewriter.create<xegpu::LoadGatherOp>(
+          loc, newValueTy, t, m, op.getTransposeAttr(), op.getL1HintAttr(),
+          op.getL2HintAttr(), op.getL3HintAttr());
       newOps.push_back(newOp);
     }
 
@@ -520,8 +521,8 @@ struct UnrollStoreScatterOp : public UnrollPattern<xegpu::StoreScatterOp> {
 
     SmallVector<Value> convertedValues =
         pack(op.getValue(), convertedValTypes, *targetShape, loc, rewriter);
-    SmallVector<Value> convertedTdescs =
-        pack(op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
+    SmallVector<Value> convertedTdescs = pack(
+        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
 
     SmallVector<Value> convertedMasks;
     if (op.getMask()) {
@@ -566,12 +567,12 @@ struct UnrollUpdateOffsetOp : public UnrollPattern<xegpu::UpdateOffsetOp> {
     SmallVector<Type> convertedOffsetTypes =
         getUnrolledTypes(offsetVecTy, *targetShape);
     SmallVector<Value> convertedOffsetVec =
-        pack(offsetVec, convertedOffsetTypes, *targetShape, loc, rewriter);        
+        pack(offsetVec, convertedOffsetTypes, *targetShape, loc, rewriter);
 
     SmallVector<Value> newOps;
     for (auto [t, o] : llvm::zip(convertedTdesc, convertedOffsetVec)) {
-      auto newOp = rewriter.create<xegpu::UpdateOffsetOp>(
-          loc, t.getType(), t, o);
+      auto newOp =
+          rewriter.create<xegpu::UpdateOffsetOp>(loc, t.getType(), t, o);
       newOps.push_back(newOp);
     }
     Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
@@ -585,8 +586,8 @@ struct UnrollUpdateOffsetOp : public UnrollPattern<xegpu::UpdateOffsetOp> {
 void mlir::xegpu::populateXeGPUUnrollPatterns(
     RewritePatternSet &patterns, const xegpu::UnrollOptions &options) {
   patterns.add<UnrollCreateNdOp, UnrollUpdateNdOffsetOp, UnrollPrefetchNdOp,
-               UnrollLoadNdOp, UnrollStoreNdOp, UnrollDpasOp, 
-               UnrollCreateDescOp, UnrollLoadGatherOp, 
-               UnrollStoreScatterOp, UnrollPrefetchOp, UnrollUpdateOffsetOp>(
-      patterns.getContext(), options);
+               UnrollLoadNdOp, UnrollStoreNdOp, UnrollDpasOp,
+               UnrollCreateDescOp, UnrollLoadGatherOp, UnrollStoreScatterOp,
+               UnrollPrefetchOp, UnrollUpdateOffsetOp>(patterns.getContext(),
+                                                       options);
 }
diff --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
index abdee098ab430..57aaecbd7962f 100644
--- a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
+++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
@@ -71,13 +71,12 @@ struct TestXeGPUUnrollingPatterns
             }
           }
 
-          if (isa<xegpu::CreateDescOp, xegpu::UpdateOffsetOp,
-                  xegpu::PrefetchOp, xegpu::LoadGatherOp, xegpu::StoreScatterOp>(op)) {
+          if (isa<xegpu::CreateDescOp, xegpu::UpdateOffsetOp, xegpu::PrefetchOp,
+                  xegpu::LoadGatherOp, xegpu::StoreScatterOp>(op)) {
             xegpu::TensorDescType tdescTy;
             if (auto createOp = dyn_cast<xegpu::CreateDescOp>(op)) {
               tdescTy = createOp.getType();
-            } else if (auto updateOp =
-                           dyn_cast<xegpu::UpdateOffsetOp>(op)) {
+            } else if (auto updateOp = dyn_cast<xegpu::UpdateOffsetOp>(op)) {
               tdescTy = updateOp.getTensorDescType();
             } else if (auto prefetchOp = dyn_cast<xegpu::PrefetchOp>(op)) {
               tdescTy = prefetchOp.getTensorDescType();

From 30cb8d816352b0d29578f408777a151c94b5f3be Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Thu, 12 Jun 2025 20:59:47 -0700
Subject: [PATCH 05/10] Update
 mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir

Co-authored-by: Adam Siemieniuk <adam.siemieniuk at intel.com>
---
 mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir b/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
index 47c54bfcb89d0..cca5339b086e4 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
@@ -249,7 +249,7 @@ gpu.module @test {
     %c17 = arith.constant 17: index
     %mask = vector.create_mask %c17: vector<32xi1>
 
-    %st_vec = arith.constant dense<1023.>: vector<32xf32>
+    %st_vec = arith.constant dense<1023.0>: vector<32xf32>
     %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32xf32,  #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
     xegpu.store %st_vec, %tdesc, %mask: vector<32xf32>, !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>, vector<32xi1>
     

From f493e52fefa594aa5918471bf0aa57b2cc0331a4 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Thu, 12 Jun 2025 21:06:13 -0700
Subject: [PATCH 06/10] Update xegpu-unroll-patterns.mlir

correct indentation
---
 .../Dialect/XeGPU/xegpu-unroll-patterns.mlir  | 100 +++++++++---------
 1 file changed, 49 insertions(+), 51 deletions(-)

diff --git a/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir b/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
index cca5339b086e4..52ec3b856da49 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir
@@ -190,26 +190,25 @@ gpu.module @test {
 
 //-----
 
-    // CHECK-LABEL: test_load
-    // CHECK-SAME: [[arg0:%.+]]: ui64
-    // CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
-    // CHECK-COUNT-2: xegpu.load  {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> -> vector<16xf32>
-    gpu.func @test_load(%src: ui64) -> vector<32xf32> {
-      %cst = arith.constant dense<[
-      0,   8,  16,  24,  32,  40,  48,  56,
-      64,  72,  80,  88,  96, 104, 112, 120,
-      128, 136, 144, 152, 160, 168, 176, 184,
-      192, 200, 208, 216, 224, 232, 240, 248 
-      ]> : vector<32xindex>
+  // CHECK-LABEL: test_load
+  // CHECK-SAME: [[arg0:%.+]]: ui64
+  // CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
+  // CHECK-COUNT-2: xegpu.load  {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> -> vector<16xf32>
+  gpu.func @test_load(%src: ui64) -> vector<32xf32> {
+    %cst = arith.constant dense<[
+    0,   8,  16,  24,  32,  40,  48,  56,
+    64,  72,  80,  88,  96, 104, 112, 120,
+    128, 136, 144, 152, 160, 168, 176, 184,
+    192, 200, 208, 216, 224, 232, 240, 248 
+    ]> : vector<32xindex>
       
-      %c17 = arith.constant 17: index
-      %mask = vector.create_mask %c17: vector<32xi1>
-
-      %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32xf32,  #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
-      %ld = xegpu.load %tdesc, %mask: !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>, vector<32xi1> -> vector<32xf32>
+    %c17 = arith.constant 17: index
+    %mask = vector.create_mask %c17: vector<32xi1>
+    %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32xf32,  #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
+    %ld = xegpu.load %tdesc, %mask: !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>, vector<32xi1> -> vector<32xf32>
       
-      gpu.return %ld : vector<32xf32> 
-    }
+    gpu.return %ld : vector<32xf32> 
+  }
 
 //-----
 
@@ -219,17 +218,17 @@ gpu.module @test {
   // CHECK-COUNT-2: xegpu.prefetch {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
   gpu.func @test_prefetch(%src: ui64)  {
 
-      %cst = arith.constant dense<[
-      0,   8,  16,  24,  32,  40,  48,  56,
-      64,  72,  80,  88,  96, 104, 112, 120,
-      128, 136, 144, 152, 160, 168, 176, 184,
-      192, 200, 208, 216, 224, 232, 240, 248 
-      ]> : vector<32xindex>
+    %cst = arith.constant dense<[
+    0,   8,  16,  24,  32,  40,  48,  56,
+    64,  72,  80,  88,  96, 104, 112, 120,
+    128, 136, 144, 152, 160, 168, 176, 184,
+    192, 200, 208, 216, 224, 232, 240, 248 
+    ]> : vector<32xindex>
 
-      %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32xf32,  #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
+    %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32xf32,  #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
 
-      xegpu.prefetch %tdesc: !xegpu.tensor_desc<32xf32,  #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
-      gpu.return
+    xegpu.prefetch %tdesc: !xegpu.tensor_desc<32xf32,  #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
+    gpu.return
   }
 
 //-----
@@ -268,37 +267,36 @@ gpu.module @test {
 
   gpu.func @test_prefetch_load_store_update(%src: ui64)  {
 
-      %cst = arith.constant dense<[
-      0,   8,  16,  24,  32,  40,  48,  56,
-      64,  72,  80,  88,  96, 104, 112, 120,
-      128, 136, 144, 152, 160, 168, 176, 184,
-      192, 200, 208, 216, 224, 232, 240, 248 
-      ]> : vector<32xindex>
-
-      %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32xf32,  #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
+    %cst = arith.constant dense<[
+    0,   8,  16,  24,  32,  40,  48,  56,
+    64,  72,  80,  88,  96, 104, 112, 120,
+    128, 136, 144, 152, 160, 168, 176, 184,
+    192, 200, 208, 216, 224, 232, 240, 248 
+    ]> : vector<32xindex>
 
-      xegpu.prefetch %tdesc: !xegpu.tensor_desc<32xf32,  #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
+    %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32xf32,  #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
+    xegpu.prefetch %tdesc: !xegpu.tensor_desc<32xf32,  #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
    
-      %delta = arith.constant dense<[
-      32,   32,  32,  32,  32,  32,  32,  32,
-      32,   32,  32,  32,  32,  32,  32,  64,
-      128, 128, 128, 128, 128, 128, 128, 128,
-      128, 128, 128, 128, 128, 128, 128, 256 
-      ]> : vector<32xindex>
-      %new_tdesc = xegpu.update_offset %tdesc, %delta
-           : !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>, vector<32xindex>     
+    %delta = arith.constant dense<[
+    32,   32,  32,  32,  32,  32,  32,  32,
+    32,   32,  32,  32,  32,  32,  32,  64,
+    128, 128, 128, 128, 128, 128, 128, 128,
+    128, 128, 128, 128, 128, 128, 128, 256 
+    ]> : vector<32xindex>
+    %new_tdesc = xegpu.update_offset %tdesc, %delta
+              : !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>, vector<32xindex>     
  
-      %c17 = arith.constant 17: index
-      %mask = vector.create_mask %c17: vector<32xi1>
+    %c17 = arith.constant 17: index
+    %mask = vector.create_mask %c17: vector<32xi1>
 
-      %ld_vec = xegpu.load %new_tdesc, %mask: !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>, vector<32xi1> -> vector<32xf32>
+    %ld_vec = xegpu.load %new_tdesc, %mask: !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>, vector<32xi1> -> vector<32xf32>
 
-      %st_vec = arith.addf %ld_vec, %ld_vec : vector<32xf32>
-      xegpu.store %st_vec, %tdesc, %mask: 
+    %st_vec = arith.addf %ld_vec, %ld_vec : vector<32xf32>
+    xegpu.store %st_vec, %tdesc, %mask: 
                  vector<32xf32>, 
                  !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>, 
                  vector<32xi1>
   
-      gpu.return
-    }
+    gpu.return
+  }
 }

From 2606a4bf20ab35ba9b46054c4dc4255d30619e75 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Fri, 13 Jun 2025 21:21:42 +0000
Subject: [PATCH 07/10] reject unsupported chunk_size descriptors

---
 .../Dialect/XeGPU/Transforms/XeGPUUnroll.cpp  | 36 ++++++++++++++-----
 1 file changed, 28 insertions(+), 8 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index 900ade8c171d5..c8e1479f49f11 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -403,6 +403,11 @@ struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
     Location loc = op.getLoc();
     xegpu::TensorDescType tdescTy = op.getType();
 
+    //check if the tensor descriptor type is a 1d vector type
+    if (tdescTy.getRank() > 1) {
+      return failure();
+    }
+
     std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
     if (!targetShape)
       return failure();
@@ -439,6 +444,11 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
     VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
     xegpu::TensorDescType tdescTy = op.getTensorDescType();
 
+    //check if the tensor descriptor type is a 1d vector type
+    if (tdescTy.getRank() > 1) {
+      return failure();
+    }
+
     VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());
 
     std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
@@ -480,6 +490,11 @@ struct UnrollPrefetchOp : public UnrollPattern<xegpu::PrefetchOp> {
     Location loc = op.getLoc();
     xegpu::TensorDescType tdescTy = op.getTensorDescType();
 
+    //check if the tensor descriptor type is a 1d vector type
+    if (tdescTy.getRank() > 1) {
+      return failure();
+    }
+
     std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
     if (!targetShape)
       return failure();
@@ -506,9 +521,12 @@ struct UnrollStoreScatterOp : public UnrollPattern<xegpu::StoreScatterOp> {
     VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
     xegpu::TensorDescType tdescTy = op.getTensorDescType();
 
-    VectorType maskTy;
-    if (op.getMask())
-      maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());
+    //check if the tensor descriptor type is a 1d vector type
+    if (tdescTy.getRank() > 1) {
+      return failure();
+    }
+
+    VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());
 
     std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
     if (!targetShape)
@@ -524,13 +542,10 @@ struct UnrollStoreScatterOp : public UnrollPattern<xegpu::StoreScatterOp> {
     SmallVector<Value> convertedTdescs = pack(
         op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
 
-    SmallVector<Value> convertedMasks;
-    if (op.getMask()) {
-      SmallVector<Type> convertedMaskTypes =
+    SmallVector<Type> convertedMaskTypes =
           getUnrolledTypes(maskTy, *targetShape);
-      convertedMasks =
+    SmallVector<Value> convertedMasks =
           pack(op.getMask(), convertedMaskTypes, *targetShape, loc, rewriter);
-    }
 
     for (size_t i = 0; i < convertedValues.size(); ++i) {
       Value v = convertedValues[i];
@@ -553,6 +568,11 @@ struct UnrollUpdateOffsetOp : public UnrollPattern<xegpu::UpdateOffsetOp> {
     Location loc = op.getLoc();
     xegpu::TensorDescType tdescTy = op.getTensorDescType();
 
+    //check if the tensor descriptor type is a 1d vector type
+    if (tdescTy.getRank() > 1) {
+      return failure();
+    }
+
     std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
     if (!targetShape)
       return failure();

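The new early-outs encode the current limitation that scattered descriptors created with a chunk_size are rank-2 and are left un-unrolled. A minimal sketch of that guard (illustrative only; the real check is tdescTy.getRank() > 1 inside each matchAndRewrite):

#include <cassert>

// Illustrative stand-in for the patterns' guard: only rank-1 scattered
// tensor descriptors are unrolled; chunk_size descriptors (rank 2) fail
// to match and are left untouched.
static bool matchesScatterUnroll(int tdescRank) {
  return tdescRank <= 1;
}

int main() {
  assert(matchesScatterUnroll(1));  // plain scatter tdesc: pattern applies
  assert(!matchesScatterUnroll(2)); // chunk_size tdesc: pattern bails out
  return 0;
}
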
From a3e064d4f051a46dbee8960d86685aca0700bbcc Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Fri, 13 Jun 2025 21:30:40 +0000
Subject: [PATCH 08/10] small fixes

---
 .../Dialect/XeGPU/Transforms/XeGPUUnroll.cpp  | 29 ++++++++-----------
 1 file changed, 12 insertions(+), 17 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index c8e1479f49f11..8a0048fbca8ed 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -403,10 +403,9 @@ struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
     Location loc = op.getLoc();
     xegpu::TensorDescType tdescTy = op.getType();
 
-    //check if the tensor descriptor type is a 1d vector type
-    if (tdescTy.getRank() > 1) {
+    // check if the tensor descriptor type is a 1d vector type
+    if (tdescTy.getRank() > 1)
       return failure();
-    }
 
     std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
     if (!targetShape)
@@ -444,10 +443,9 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
     VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
     xegpu::TensorDescType tdescTy = op.getTensorDescType();
 
-    //check if the tensor descriptor type is a 1d vector type
-    if (tdescTy.getRank() > 1) {
+    // check if the tensor descriptor type is a 1d vector type
+    if (tdescTy.getRank() > 1)
       return failure();
-    }
 
     VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());
 
@@ -490,10 +488,9 @@ struct UnrollPrefetchOp : public UnrollPattern<xegpu::PrefetchOp> {
     Location loc = op.getLoc();
     xegpu::TensorDescType tdescTy = op.getTensorDescType();
 
-    //check if the tensor descriptor type is a 1d vector type
-    if (tdescTy.getRank() > 1) {
+    // check if the tensor descriptor type is a 1d vector type
+    if (tdescTy.getRank() > 1)
       return failure();
-    }
 
     std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
     if (!targetShape)
@@ -521,10 +518,9 @@ struct UnrollStoreScatterOp : public UnrollPattern<xegpu::StoreScatterOp> {
     VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
     xegpu::TensorDescType tdescTy = op.getTensorDescType();
 
-    //check if the tensor descriptor type is a 1d vector type
-    if (tdescTy.getRank() > 1) {
+    // check if the tensor descriptor type is a 1d vector type
+    if (tdescTy.getRank() > 1)
       return failure();
-    }
 
     VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());
 
@@ -543,9 +539,9 @@ struct UnrollStoreScatterOp : public UnrollPattern<xegpu::StoreScatterOp> {
         op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
 
     SmallVector<Type> convertedMaskTypes =
-          getUnrolledTypes(maskTy, *targetShape);
+        getUnrolledTypes(maskTy, *targetShape);
     SmallVector<Value> convertedMasks =
-          pack(op.getMask(), convertedMaskTypes, *targetShape, loc, rewriter);
+        pack(op.getMask(), convertedMaskTypes, *targetShape, loc, rewriter);
 
     for (size_t i = 0; i < convertedValues.size(); ++i) {
       Value v = convertedValues[i];
@@ -568,10 +564,9 @@ struct UnrollUpdateOffsetOp : public UnrollPattern<xegpu::UpdateOffsetOp> {
     Location loc = op.getLoc();
     xegpu::TensorDescType tdescTy = op.getTensorDescType();
 
-    //check if the tensor descriptor type is a 1d vector type
-    if (tdescTy.getRank() > 1) {
+    // check if the tensor descriptor type is a 1d vector type
+    if (tdescTy.getRank() > 1)
       return failure();
-    }
 
     std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
     if (!targetShape)

From 75a65ead12c0103de730a29873d5c3d5ef6b2445 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Sat, 14 Jun 2025 00:55:52 +0000
Subject: [PATCH 09/10] small fix

---
 mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp | 1 +
 1 file changed, 1 insertion(+)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index 8a0048fbca8ed..11cb2ed478de5 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -415,6 +415,7 @@ struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
 
     TypedValue<::mlir::VectorType> indiceVec = op.getOffsets();
     VectorType indiceVecTy = indiceVec.getType();
+    
     SmallVector<Type> convertedIndiceTypes =
         getUnrolledTypes(indiceVecTy, *targetShape);
     SmallVector<Value> convertedIndiceVec =

From 402c015122d0cd2de8db895f0b052f387976a7cb Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Sat, 14 Jun 2025 01:32:46 +0000
Subject: [PATCH 10/10] clang format fix

---
 mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index 11cb2ed478de5..9c234c1e866b9 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -415,7 +415,7 @@ struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
 
     TypedValue<::mlir::VectorType> indiceVec = op.getOffsets();
     VectorType indiceVecTy = indiceVec.getType();
-    
+
     SmallVector<Type> convertedIndiceTypes =
         getUnrolledTypes(indiceVecTy, *targetShape);
     SmallVector<Value> convertedIndiceVec =


