[Mlir-commits] [mlir] [MLIR][XEGPU] Add blocking support for scatter ops (PR #144766)

Charitha Saumya llvmlistbot at llvm.org
Wed Jun 18 14:51:50 PDT 2025


https://github.com/charithaintc updated https://github.com/llvm/llvm-project/pull/144766

>From 1db8a64be353d24267d82d87f8a4a050c309bf98 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Wed, 18 Jun 2025 16:39:41 +0000
Subject: [PATCH 1/7] add blocking support for scatter ops

---
 .../XeGPU/Transforms/XeGPUBlocking.cpp        | 41 +++++++++++++++----
 mlir/test/Dialect/XeGPU/xegpu-blocking.mlir   | 18 ++++----
 2 files changed, 41 insertions(+), 18 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
index a3826c56e1f62..4d96e976ee4b4 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
@@ -134,11 +134,11 @@ XeGPUBlockingPass::getTileShape(const T &operandOrResult) const {
 
 std::optional<SmallVector<int64_t>>
 XeGPUBlockingPass::getTileShape(Operation *op) const {
-  if (isa<xegpu::CreateNdDescOp, xegpu::UpdateNdOffsetOp>(op))
+  if (isa<xegpu::CreateNdDescOp, xegpu::UpdateNdOffsetOp, xegpu::CreateDescOp, xegpu::UpdateOffsetOp>(op))
     return getTileShape(op->getOpResult(0));
-  if (isa<xegpu::PrefetchNdOp, xegpu::LoadNdOp>(op))
+  if (isa<xegpu::PrefetchNdOp, xegpu::LoadNdOp, xegpu::PrefetchOp, xegpu::LoadGatherOp>(op))
     return getTileShape(op->getOpOperand(0));
-  if (isa<xegpu::StoreNdOp>(op))
+  if (isa<xegpu::StoreNdOp, xegpu::StoreScatterOp>(op))
     return getTileShape(op->getOpOperand(1));
 
   if (isa<xegpu::DpasOp>(op)) {
@@ -295,13 +295,40 @@ void XeGPUBlockingPass::runOnOperation() {
     Type elemTy = type.getElementType();
     Type newTy;
 
-    if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(type))
+    if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(type)) {
+
+      Attribute encoding = tdescTy.getEncoding();
+      // If the encoding is a ScatterTensorDescAttr, we need to
+      // potentially adjust the chunk size based on the inst_data.
+      if (encoding && mlir::isa<xegpu::ScatterTensorDescAttr>(encoding)) {
+        auto scatterAttr =
+            mlir::dyn_cast<xegpu::ScatterTensorDescAttr>(encoding);
+        int64_t chunkSize = scatterAttr.getChunkSize().getInt();
+
+        if (chunkSize > 1) {
+          int64_t blockedChunkSize = chunkSize;
+          auto instData = tdescTy.getLayoutAttr().getInstData();
+          if (!instData.empty())
+            blockedChunkSize = instData.asArrayRef().back();
+
+          auto chunkSizeAttr = mlir::IntegerAttr::get(
+              mlir::IntegerType::get(ctx, 64), blockedChunkSize);
+
+          // To create a new attribute with a different chunk_size:
+          auto newEncoding = xegpu::ScatterTensorDescAttr::get(
+              ctx, scatterAttr.getMemorySpace(), chunkSizeAttr);
+
+          encoding = newEncoding;
+        }
+      }
+
       newTy = xegpu::TensorDescType::get(
-          ctx, tileShape, elemTy, tdescTy.getEncoding(),
+          ctx, tileShape, elemTy, encoding,
           tdescTy.getLayoutAttr().dropInstData());
-    else
+    } else {
       newTy = type.clone(tileShape, elemTy);
-
+    }
+    
     std::optional<SmallVector<int64_t>> ratio =
         computeShapeRatio(type.getShape(), tileShape);
     assert(ratio && "The shape of the type must be a multiple of tileShape.");
diff --git a/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir b/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
index 67d3bd9b393c0..ade204e588e01 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
@@ -250,8 +250,7 @@ gpu.module @test_kernel {
 // -----
 #l = #xegpu.layout<inst_data = [16, 16]>
 #r = #xegpu.layout<inst_data = [16]>
-
-gpu.module @kernel  attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
+gpu.module @test_kernel  attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
   gpu.func @reduce_dim_0(%a: memref<16x512xf32>, %b: memref<512xf32>)  kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
     %acc = arith.constant dense<0.0> : vector<64xf32>
     %c64 = arith.constant 64 : index
@@ -271,8 +270,7 @@ gpu.module @kernel  attributes {spirv.target_env = #spirv.target_env<#spirv.vce<
 // -----
 #l = #xegpu.layout<inst_data = [16, 16]>
 #r = #xegpu.layout<inst_data = [16]>
-
-gpu.module @kernel  attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
+gpu.module @test_kernel  attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
   gpu.func @reduce_dim_1(%a: memref<512x32xf32>, %b: memref<512xf32>)  kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
     %c1 = arith.constant 1 : index
     %c32 = arith.constant 32 : index
@@ -299,8 +297,7 @@ gpu.module @kernel  attributes {spirv.target_env = #spirv.target_env<#spirv.vce<
 // -----
 #r = #xegpu.layout<inst_data = [16]>
 #l = #xegpu.layout<inst_data = [16, 16]>
-
-gpu.module @kernel  attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
+gpu.module @test_kernel  attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
   gpu.func @broadcast_dim_0(%a: memref<512xf32>, %b: memref<16x512xf32>)  kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
 
     %c64 = arith.constant 64 : index
@@ -319,8 +316,7 @@ gpu.module @kernel  attributes {spirv.target_env = #spirv.target_env<#spirv.vce<
 // -----
 #r = #xegpu.layout<inst_data = [16]>
 #l = #xegpu.layout<inst_data = [16, 16]>
-
-gpu.module @kernel  attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
+gpu.module @test_kernel  attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
   gpu.func @broadcast_dim_1(%a: memref<512xf32>, %b: memref<16x512xf32>)  kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
 
     %c32 = arith.constant 32 : index
@@ -340,8 +336,7 @@ gpu.module @kernel  attributes {spirv.target_env = #spirv.target_env<#spirv.vce<
 // -----
 #l = #xegpu.layout<inst_data = [16, 8]>
 #t = #xegpu.layout<inst_data = [8, 16]>
-
-gpu.module @kernel  attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
+gpu.module @test_kernel  attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
   gpu.func @transpose(%a: memref<512x8xf32>, %b: memref<8x512xf32>)  kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
 
     %c32 = arith.constant 32 : index
@@ -355,4 +350,5 @@ gpu.module @kernel  attributes {spirv.target_env = #spirv.target_env<#spirv.vce<
     xegpu.store_nd %2, %3: vector<8x32xf32>, !xegpu.tensor_desc<8x32xf32, #t>
     gpu.return
   }
-}
\ No newline at end of file
+}
+

>From 1e6c2f3eb166efdec39c037b2b6f46c4093e7567 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Wed, 18 Jun 2025 17:37:03 +0000
Subject: [PATCH 2/7] adding test

---
 mlir/test/Dialect/XeGPU/xegpu-blocking.mlir | 105 +++++++++++++++++++-
 1 file changed, 100 insertions(+), 5 deletions(-)

diff --git a/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir b/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
index ade204e588e01..f977ba3c11bcf 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
@@ -250,7 +250,7 @@ gpu.module @test_kernel {
 // -----
 #l = #xegpu.layout<inst_data = [16, 16]>
 #r = #xegpu.layout<inst_data = [16]>
-gpu.module @test_kernel  attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
+gpu.module @test_kernel  {
   gpu.func @reduce_dim_0(%a: memref<16x512xf32>, %b: memref<512xf32>)  kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
     %acc = arith.constant dense<0.0> : vector<64xf32>
     %c64 = arith.constant 64 : index
@@ -270,7 +270,7 @@ gpu.module @test_kernel  attributes {spirv.target_env = #spirv.target_env<#spirv
 // -----
 #l = #xegpu.layout<inst_data = [16, 16]>
 #r = #xegpu.layout<inst_data = [16]>
-gpu.module @test_kernel  attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
+gpu.module @test_kernel   {
   gpu.func @reduce_dim_1(%a: memref<512x32xf32>, %b: memref<512xf32>)  kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
     %c1 = arith.constant 1 : index
     %c32 = arith.constant 32 : index
@@ -297,7 +297,7 @@ gpu.module @test_kernel  attributes {spirv.target_env = #spirv.target_env<#spirv
 // -----
 #r = #xegpu.layout<inst_data = [16]>
 #l = #xegpu.layout<inst_data = [16, 16]>
-gpu.module @test_kernel  attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
+gpu.module @test_kernel   {
   gpu.func @broadcast_dim_0(%a: memref<512xf32>, %b: memref<16x512xf32>)  kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
 
     %c64 = arith.constant 64 : index
@@ -316,7 +316,7 @@ gpu.module @test_kernel  attributes {spirv.target_env = #spirv.target_env<#spirv
 // -----
 #r = #xegpu.layout<inst_data = [16]>
 #l = #xegpu.layout<inst_data = [16, 16]>
-gpu.module @test_kernel  attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
+gpu.module @test_kernel  {
   gpu.func @broadcast_dim_1(%a: memref<512xf32>, %b: memref<16x512xf32>)  kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
 
     %c32 = arith.constant 32 : index
@@ -336,7 +336,7 @@ gpu.module @test_kernel  attributes {spirv.target_env = #spirv.target_env<#spirv
 // -----
 #l = #xegpu.layout<inst_data = [16, 8]>
 #t = #xegpu.layout<inst_data = [8, 16]>
-gpu.module @test_kernel  attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
+gpu.module @test_kernel   {
   gpu.func @transpose(%a: memref<512x8xf32>, %b: memref<8x512xf32>)  kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
 
     %c32 = arith.constant 32 : index
@@ -352,3 +352,98 @@ gpu.module @test_kernel  attributes {spirv.target_env = #spirv.target_env<#spirv
   }
 }
 
+// -----
+gpu.module @test_kernel {
+  // CHECK-LABEL: test_prefetch_load_store_update
+  // CHECK-SAME: [[arg0:%.+]]: ui64
+  // CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
+  // CHECK-COUNT-2: xegpu.prefetch {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
+   // CHECK-COUNT-2: xegpu.update_offset {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xindex>
+   // CHECK-COUNT-2: xegpu.load  {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> -> vector<16xf32>
+  // CHECK-COUNT-2: xegpu.store  {{.*}} : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1>
+
+  gpu.func @test_prefetch_load_store_update(%src: ui64)  {
+
+    %cst = arith.constant dense<[
+    0,   8,  16,  24,  32,  40,  48,  56,
+    64,  72,  80,  88,  96, 104, 112, 120,
+    128, 136, 144, 152, 160, 168, 176, 184,
+    192, 200, 208, 216, 224, 232, 240, 248 
+    ]> : vector<32xindex>
+
+    %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32xf32,  #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
+    xegpu.prefetch %tdesc: !xegpu.tensor_desc<32xf32,  #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
+   
+    %delta = arith.constant dense<[
+    32,   32,  32,  32,  32,  32,  32,  32,
+    32,   32,  32,  32,  32,  32,  32,  64,
+    128, 128, 128, 128, 128, 128, 128, 128,
+    128, 128, 128, 128, 128, 128, 128, 256 
+    ]> : vector<32xindex>
+    %new_tdesc = xegpu.update_offset %tdesc, %delta
+              : !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>, vector<32xindex>     
+ 
+    %c17 = arith.constant 17: index
+    %mask = vector.create_mask %c17: vector<32xi1>
+
+    %ld_vec = xegpu.load %new_tdesc, %mask: !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>, vector<32xi1> -> vector<32xf32>
+
+    %st_vec = arith.addf %ld_vec, %ld_vec : vector<32xf32>
+    xegpu.store %st_vec, %tdesc, %mask: 
+                 vector<32xf32>, 
+                 !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>, 
+                 vector<32xi1>
+  
+    gpu.return
+  }
+  
+}
+
+// -----
+
+gpu.module @test_kernel   {
+  // CHECK-LABEL: test_prefetch_load_store_update_chunk
+  // CHECK-SAME: [[arg0:%.+]]: ui64
+  // CHECK-COUNT-4: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
+  // CHECK-COUNT-4: xegpu.prefetch {{.*}} : !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
+   // CHECK-COUNT-4: xegpu.update_offset {{.*}} : !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>, vector<16xindex>
+   // CHECK-COUNT-4: xegpu.load  {{.*}} : !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>, vector<16xi1> -> vector<2x16xf32>
+  // CHECK-COUNT-4: xegpu.store  {{.*}} : vector<2x16xf32>, !xegpu.tensor_desc<16x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>, vector<16xi1>
+
+  gpu.func @test_prefetch_load_store_update_chunk(%src: ui64)  {
+
+    %cst = arith.constant dense<[
+      0,   8,  16,  24,  32,  40,  48,  56,
+      64,  72,  80,  88,  96, 104, 112, 120,
+      128, 136, 144, 152, 160, 168, 176, 184,
+      192, 200, 208, 216, 224, 232, 240, 248 
+    ]> : vector<32xindex>
+
+    %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32x4xf32,  #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>
+    xegpu.prefetch %tdesc: !xegpu.tensor_desc<32x4xf32,  #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>
+   
+    %delta = arith.constant dense<[
+      32,   32,  32,  32,  32,  32,  32,  32,
+      32,   32,  32,  32,  32,  32,  32,  64,
+      128, 128, 128, 128, 128, 128, 128, 128,
+      128, 128, 128, 128, 128, 128, 128, 256 
+    ]> : vector<32xindex>
+    %new_tdesc = xegpu.update_offset %tdesc, %delta
+              : !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>, vector<32xindex>     
+ 
+    %c17 = arith.constant 17: index
+    %mask = vector.create_mask %c17: vector<32xi1>
+
+    %ld_vec = xegpu.load %new_tdesc, %mask <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}>: !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>, vector<32xi1> -> vector<4x32xf32>
+
+    %st_vec = arith.addf %ld_vec, %ld_vec : vector<4x32xf32>
+    xegpu.store %st_vec, %tdesc, %mask <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}>: 
+                 vector<4x32xf32>, 
+                 !xegpu.tensor_desc<32x4xf32, #xegpu.scatter_tdesc_attr<chunk_size=4>, #xegpu.layout<inst_data = [16, 2]>>, 
+                 vector<32xi1>
+  
+    gpu.return
+  }
+}
+
+

>From 77fdcc1facbb4acc0ad85153212d7c5d8ee2d094 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Wed, 18 Jun 2025 17:40:55 +0000
Subject: [PATCH 3/7] clang format fix

---
 .../lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp | 14 ++++++++------
 1 file changed, 8 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
index 4d96e976ee4b4..02bfab8263847 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
@@ -134,9 +134,11 @@ XeGPUBlockingPass::getTileShape(const T &operandOrResult) const {
 
 std::optional<SmallVector<int64_t>>
 XeGPUBlockingPass::getTileShape(Operation *op) const {
-  if (isa<xegpu::CreateNdDescOp, xegpu::UpdateNdOffsetOp, xegpu::CreateDescOp, xegpu::UpdateOffsetOp>(op))
+  if (isa<xegpu::CreateNdDescOp, xegpu::UpdateNdOffsetOp, xegpu::CreateDescOp,
+          xegpu::UpdateOffsetOp>(op))
     return getTileShape(op->getOpResult(0));
-  if (isa<xegpu::PrefetchNdOp, xegpu::LoadNdOp, xegpu::PrefetchOp, xegpu::LoadGatherOp>(op))
+  if (isa<xegpu::PrefetchNdOp, xegpu::LoadNdOp, xegpu::PrefetchOp,
+          xegpu::LoadGatherOp>(op))
     return getTileShape(op->getOpOperand(0));
   if (isa<xegpu::StoreNdOp, xegpu::StoreScatterOp>(op))
     return getTileShape(op->getOpOperand(1));
@@ -322,13 +324,13 @@ void XeGPUBlockingPass::runOnOperation() {
         }
       }
 
-      newTy = xegpu::TensorDescType::get(
-          ctx, tileShape, elemTy, encoding,
-          tdescTy.getLayoutAttr().dropInstData());
+      newTy =
+          xegpu::TensorDescType::get(ctx, tileShape, elemTy, encoding,
+                                     tdescTy.getLayoutAttr().dropInstData());
     } else {
       newTy = type.clone(tileShape, elemTy);
     }
-    
+
     std::optional<SmallVector<int64_t>> ratio =
         computeShapeRatio(type.getShape(), tileShape);
     assert(ratio && "The shape of the type must be a multiple of tileShape.");

>From 3bb754b48be0e77f2af31ba819d84f31879e32c2 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Wed, 18 Jun 2025 19:32:31 +0000
Subject: [PATCH 4/7] addresses feedback on API use

---
 mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp | 6 +++---
 mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp | 8 +++-----
 2 files changed, 6 insertions(+), 8 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
index 02bfab8263847..bf66df06d0b6b 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
@@ -302,9 +302,9 @@ void XeGPUBlockingPass::runOnOperation() {
       Attribute encoding = tdescTy.getEncoding();
       // If the encoding is a ScatterTensorDescAttr, we need to
       // potentially adjust the chunk size based on the inst_data.
-      if (encoding && mlir::isa<xegpu::ScatterTensorDescAttr>(encoding)) {
-        auto scatterAttr =
-            mlir::dyn_cast<xegpu::ScatterTensorDescAttr>(encoding);
+      if (tdescTy.isScattered()) {
+        auto scatterAttr = tdescTy.getEncodingAsScatterTensorDescAttr();
+        //    mlir::dyn_cast<xegpu::ScatterTensorDescAttr>(encoding);
         int64_t chunkSize = scatterAttr.getChunkSize().getInt();
 
         if (chunkSize > 1) {
diff --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
index 4400d6d9625f7..d691067b9464b 100644
--- a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
+++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
@@ -102,14 +102,12 @@ struct TestXeGPUUnrollingPatterns
           // attribute
           if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(type)) {
             Attribute encoding = tdescTy.getEncoding();
-            auto layout = llvm::dyn_cast_if_present<xegpu::LayoutAttr>(
-                tdescTy.getLayout());
+            auto layout = tdescTy.getLayoutAttr();
 
             // If the encoding is a ScatterTensorDescAttr, we need to
             // potentially adjust the chunk size based on the inst_data.
-            if (encoding && mlir::isa<xegpu::ScatterTensorDescAttr>(encoding)) {
-              auto scatterAttr =
-                  mlir::dyn_cast<xegpu::ScatterTensorDescAttr>(encoding);
+            if (tdescTy.isScattered()) {
+              auto scatterAttr = tdescTy.getEncodingAsScatterTensorDescAttr();
               int64_t chunkSize = scatterAttr.getChunkSize().getInt();
 
               if (chunkSize > 1) {

>From 915830cbf174cda80c251336a3696aa033d2f846 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Wed, 18 Jun 2025 21:13:46 +0000
Subject: [PATCH 5/7] addresses feedback

---
 mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp | 8 ++------
 mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp | 5 +----
 2 files changed, 3 insertions(+), 10 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
index bf66df06d0b6b..8e38d8c73f794 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
@@ -303,8 +303,7 @@ void XeGPUBlockingPass::runOnOperation() {
       // If the encoding is a ScatterTensorDescAttr, we need to
       // potentially adjust the chunk size based on the inst_data.
       if (tdescTy.isScattered()) {
-        auto scatterAttr = tdescTy.getEncodingAsScatterTensorDescAttr();
-        //    mlir::dyn_cast<xegpu::ScatterTensorDescAttr>(encoding);
+        auto scatterAttr = llvm::dyn_cast_if_present<xegpu::ScatterTensorDescAttr>(encoding);        
         int64_t chunkSize = scatterAttr.getChunkSize().getInt();
 
         if (chunkSize > 1) {
@@ -313,12 +312,9 @@ void XeGPUBlockingPass::runOnOperation() {
           if (!instData.empty())
             blockedChunkSize = instData.asArrayRef().back();
 
-          auto chunkSizeAttr = mlir::IntegerAttr::get(
-              mlir::IntegerType::get(ctx, 64), blockedChunkSize);
-
           // To create a new attribute with a different chunk_size:
           auto newEncoding = xegpu::ScatterTensorDescAttr::get(
-              ctx, scatterAttr.getMemorySpace(), chunkSizeAttr);
+              ctx, scatterAttr.getMemorySpace().getValue(), blockedChunkSize);
 
           encoding = newEncoding;
         }
diff --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
index d691067b9464b..4aa5c4e8fdbca 100644
--- a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
+++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
@@ -116,12 +116,9 @@ struct TestXeGPUUnrollingPatterns
                 if (!instData.empty())
                   blockedChunkSize = instData.asArrayRef().back();
 
-                auto chunkSizeAttr = mlir::IntegerAttr::get(
-                    mlir::IntegerType::get(ctx, 64), blockedChunkSize);
-
                 // To create a new attribute with a different chunk_size:
                 auto newEncoding = xegpu::ScatterTensorDescAttr::get(
-                    ctx, scatterAttr.getMemorySpace(), chunkSizeAttr);
+                    ctx, scatterAttr.getMemorySpace().getValue(), blockedChunkSize);
 
                 encoding = newEncoding;
               }

>From beb5c8e047f38f6e1e5c6a5ddd62655495cd1a67 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Wed, 18 Jun 2025 21:16:38 +0000
Subject: [PATCH 6/7] minor change

---
 mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
index 4aa5c4e8fdbca..d0879340693f0 100644
--- a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
+++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
@@ -107,7 +107,7 @@ struct TestXeGPUUnrollingPatterns
             // If the encoding is a ScatterTensorDescAttr, we need to
             // potentially adjust the chunk size based on the inst_data.
             if (tdescTy.isScattered()) {
-              auto scatterAttr = tdescTy.getEncodingAsScatterTensorDescAttr();
+              auto scatterAttr = llvm::dyn_cast_if_present<xegpu::ScatterTensorDescAttr>(encoding);        
               int64_t chunkSize = scatterAttr.getChunkSize().getInt();
 
               if (chunkSize > 1) {

>From adf6358b9c46e20b49e5c8c1602fe567c06c81f8 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Wed, 18 Jun 2025 21:17:41 +0000
Subject: [PATCH 7/7] clang format fix

---
 mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp | 3 ++-
 mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp | 7 +++++--
 2 files changed, 7 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
index 8e38d8c73f794..3950e8f70d1ca 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
@@ -303,7 +303,8 @@ void XeGPUBlockingPass::runOnOperation() {
       // If the encoding is a ScatterTensorDescAttr, we need to
       // potentially adjust the chunk size based on the inst_data.
       if (tdescTy.isScattered()) {
-        auto scatterAttr = llvm::dyn_cast_if_present<xegpu::ScatterTensorDescAttr>(encoding);        
+        auto scatterAttr =
+            llvm::dyn_cast_if_present<xegpu::ScatterTensorDescAttr>(encoding);
         int64_t chunkSize = scatterAttr.getChunkSize().getInt();
 
         if (chunkSize > 1) {
diff --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
index d0879340693f0..c84eb74198544 100644
--- a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
+++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
@@ -107,7 +107,9 @@ struct TestXeGPUUnrollingPatterns
             // If the encoding is a ScatterTensorDescAttr, we need to
             // potentially adjust the chunk size based on the inst_data.
             if (tdescTy.isScattered()) {
-              auto scatterAttr = llvm::dyn_cast_if_present<xegpu::ScatterTensorDescAttr>(encoding);        
+              auto scatterAttr =
+                  llvm::dyn_cast_if_present<xegpu::ScatterTensorDescAttr>(
+                      encoding);
               int64_t chunkSize = scatterAttr.getChunkSize().getInt();
 
               if (chunkSize > 1) {
@@ -118,7 +120,8 @@ struct TestXeGPUUnrollingPatterns
 
                 // To create a new attribute with a different chunk_size:
                 auto newEncoding = xegpu::ScatterTensorDescAttr::get(
-                    ctx, scatterAttr.getMemorySpace().getValue(), blockedChunkSize);
+                    ctx, scatterAttr.getMemorySpace().getValue(),
+                    blockedChunkSize);
 
                 encoding = newEncoding;
               }



More information about the Mlir-commits mailing list