[Mlir-commits] [mlir] [MLIR][XeGPU] Enable blocking for scatter ops with offsets (PR #162896)

Nishant Patel llvmlistbot at llvm.org
Fri Oct 10 10:39:39 PDT 2025


https://github.com/nbpatel created https://github.com/llvm/llvm-project/pull/162896

The unroll patterns for these ops were added in a previous PR, but the getTileShape method was not updated to handle them, so the blocking pass was not kicking in for scatter ops that carry explicit offsets.
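Concretely, getTileShape now picks the value whose layout drives blocking based on whether the scatter op carries explicit offsets: with offsets, the data vector (the result of xegpu.load, or the value operand of xegpu.store) supplies the inst_data layout; without offsets, the tensor-descriptor operand is used as before. Condensed from the hunk below (a sketch of the dispatch, not additional behavior):

    if (auto loadGatherOp = dyn_cast<xegpu::LoadGatherOp>(op))
      return getTileShape(loadGatherOp.getOffsets()
                              ? loadGatherOp->getOpResult(0)    // loaded data vector
                              : loadGatherOp->getOpOperand(0)); // tensor descriptor
    if (auto storeScatterOp = dyn_cast<xegpu::StoreScatterOp>(op))
      return getTileShape(storeScatterOp.getOffsets()
                              ? storeScatterOp->getOpOperand(0)  // data vector to store
                              : storeScatterOp->getOpOperand(1)); // tensor descriptor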

From 8941125bae0c707a2a43990049c53563ef92c603 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Fri, 10 Oct 2025 17:07:35 +0000
Subject: [PATCH] Enable blocking for scatter ops with offsets

---
 .../XeGPU/Transforms/XeGPUBlocking.cpp        |  17 ++-
 mlir/test/Dialect/XeGPU/xegpu-blocking.mlir   | 102 ++++++++++++++++++
 2 files changed, 117 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
index 36c498e8b849d..f77784abaf0b2 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
@@ -161,11 +161,24 @@ XeGPUBlockingPass::getTileShape(Operation *op) const {
           xegpu::UpdateOffsetOp, xegpu::LoadMatrixOp>(op))
     return getTileShape(op->getOpResult(0));
   if (isa<xegpu::PrefetchNdOp, xegpu::LoadNdOp, xegpu::PrefetchOp,
-          xegpu::LoadGatherOp, xegpu::StoreMatrixOp>(op))
+          xegpu::StoreMatrixOp>(op))
     return getTileShape(op->getOpOperand(0));
-  if (isa<xegpu::StoreNdOp, xegpu::StoreScatterOp>(op))
+  if (isa<xegpu::StoreNdOp>(op))
     return getTileShape(op->getOpOperand(1));
 
+  // Handle LoadGatherOp and StoreScatterOp (with and without offset)
+  if (auto loadGatherOp = dyn_cast<xegpu::LoadGatherOp>(op)) {
+    if (loadGatherOp.getOffsets())
+      return getTileShape(loadGatherOp->getOpResult(0));
+    else
+      return getTileShape(loadGatherOp->getOpOperand(0));
+  }
+
+  if (auto storeScatterOp = dyn_cast<xegpu::StoreScatterOp>(op))
+    return getTileShape(storeScatterOp.getOffsets()
+                            ? storeScatterOp->getOpOperand(0)
+                            : storeScatterOp->getOpOperand(1));
+
   if (isa<xegpu::DpasOp>(op)) {
     std::optional<SmallVector<int64_t>> aTile =
         getTileShape(op->getOpOperand(0));
diff --git a/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir b/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
index 9d63c2ddd4895..ce4db475ab249 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
@@ -584,3 +584,105 @@ gpu.module @test_kernel {
     gpu.return
   }
 }
+
+// -----
+gpu.module @test_kernel {
+  // CHECK-LABEL: load_with_offsets
+  // CHECK-SAME: [[arg0:%.+]]: ui64
+  // CHECK-COUNT-2: xegpu.load  {{.*}}[{{.*}}], {{.*}} <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>}> : ui64, vector<16xindex>, vector<16xi1> -> vector<16xf32>
+  gpu.func @load_with_offsets(%src: ui64) -> vector<32xf32> {
+      %cst = arith.constant dense<[
+      0,   8,  16,  24,  32,  40,  48,  56,
+      64,  72,  80,  88,  96, 104, 112, 120,
+      128, 136, 144, 152, 160, 168, 176, 184,
+      192, 200, 208, 216, 224, 232, 240, 248
+      ]> : vector<32xindex>
+
+      %c17 = arith.constant 17: index
+      %mask = vector.create_mask %c17: vector<32xi1>
+      %ld = xegpu.load %src[%cst], %mask {chunk_size = 1, layout_result_0 = #xegpu.layout<inst_data = [16]>, l1_hint = #xegpu.cache_hint<cached>} : ui64, vector<32xindex>, vector<32xi1> -> vector<32xf32>
+
+      gpu.return %ld : vector<32xf32>
+  }
+}
+
+// -----
+gpu.module @test_kernel {
+  // CHECK-LABEL: store_with_offsets
+  // CHECK-SAME: [[arg0:%.+]]: ui64
+  // CHECK-COUNT-2: xegpu.store  {{.*}}[{{.*}}], {{.*}} <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>}> : vector<16xf32>, ui64, vector<16xindex>, vector<16xi1>
+  gpu.func @store_with_offsets(%src: ui64) {
+      %cst = arith.constant dense<[
+      0,   8,  16,  24,  32,  40,  48,  56,
+      64,  72,  80,  88,  96, 104, 112, 120,
+      128, 136, 144, 152, 160, 168, 176, 184,
+      192, 200, 208, 216, 224, 232, 240, 248
+      ]> : vector<32xindex>
+
+      %c17 = arith.constant 17: index
+      %mask = vector.create_mask %c17: vector<32xi1>
+
+      %st_vec = arith.constant dense<1023.0>: vector<32xf32>
+      xegpu.store %st_vec, %src[%cst], %mask {chunk_size = 1, layout_operand_0 = #xegpu.layout<inst_data = [16]>, 
+                                              layout_operand_2 = #xegpu.layout<inst_data = [16]>, 
+                                              layout_operand_3 = #xegpu.layout<inst_data = [16]>, 
+                                              l1_hint = #xegpu.cache_hint<cached>} : vector<32xf32>, ui64, vector<32xindex>, vector<32xi1>
+
+      gpu.return
+  }
+}
+
+// -----
+gpu.module @test_kernel {
+  // CHECK-LABEL: load_with_offsets_chunk
+  // CHECK-SAME: [[arg0:%.+]]: ui64
+  // CHECK: [[cst:%.+]] = arith.constant dense<0.000000e+00> : vector<32x4xf32>
+  // CHECK: [[cst0:%.+]] = arith.constant dense<[130, 138, 146, 154, 162, 170, 178, 186, 194, 202, 210, 218, 226, 234, 242, 250]> : vector<16xindex>
+  // CHECK: [[cst1:%.+]] = arith.constant dense<[2, 10, 18, 26, 34, 42, 50, 58, 66, 74, 82, 90, 98, 106, 114, 122]> : vector<16xindex>
+  // CHECK: [[cst2:%.+]] = arith.constant dense<[128, 136, 144, 152, 160, 168, 176, 184, 192, 200, 208, 216, 224, 232, 240, 248]> : vector<16xindex>
+  // CHECK: [[cst3:%.+]] = arith.constant dense<[0, 8, 16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120]> : vector<16xindex>
+  // CHECK-COUNT-4: xegpu.load  {{.*}}[{{.*}}], {{.*}} <{chunk_size = 2 : i64, l1_hint = #xegpu.cache_hint<cached>}> : ui64, vector<16xindex>, vector<16xi1> -> vector<16x2xf32>
+   gpu.func @load_with_offsets_chunk(%src: ui64) -> vector<32x4xf32> {
+    %cst = arith.constant dense<[
+        0,   8,  16,  24,  32,  40,  48,  56,
+        64,  72,  80,  88,  96, 104, 112, 120,
+        128, 136, 144, 152, 160, 168, 176, 184,
+        192, 200, 208, 216, 224, 232, 240, 248
+    ]> : vector<32xindex>
+
+    %c17 = arith.constant 17: index
+    %mask = vector.create_mask %c17: vector<32xi1>
+    %ld = xegpu.load %src[%cst], %mask {chunk_size = 4, layout_result_0 = #xegpu.layout<inst_data = [16, 2]>, l1_hint = #xegpu.cache_hint<cached>} : ui64, vector<32xindex>, vector<32xi1> -> vector<32x4xf32>
+    gpu.return %ld : vector<32x4xf32>
+   }
+}
+
+// -----
+gpu.module @test_kernel {
+  // CHECK-LABEL: store_with_offsets_chunk
+  // CHECK-SAME: [[arg0:%.+]]: ui64
+  // CHECK: [[cst:%.+]] = arith.constant dense<1.023000e+03> : vector<16x2xf32
+  // CHECK: [[cst0:%.+]] = arith.constant dense<[130, 138, 146, 154, 162, 170, 178, 186, 194, 202, 210, 218, 226, 234, 242, 250]> : vector<16xindex>
+  // CHECK: [[cst1:%.+]] = arith.constant dense<[2, 10, 18, 26, 34, 42, 50, 58, 66, 74, 82, 90, 98, 106, 114, 122]> : vector<16xindex>
+  // CHECK: [[cst2:%.+]] = arith.constant dense<[128, 136, 144, 152, 160, 168, 176, 184, 192, 200, 208, 216, 224, 232, 240, 248]> : vector<16xindex>
+  // CHECK: [[cst3:%.+]] = arith.constant dense<[0, 8, 16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120]> : vector<16xindex>
+  // CHECK-COUNT-4: xegpu.store  {{.*}}[{{.*}}], {{.*}} <{chunk_size = 2 : i64, l1_hint = #xegpu.cache_hint<cached>}> : vector<16x2xf32>, ui64, vector<16xindex>, vector<16xi1>
+  gpu.func @store_with_offsets_chunk(%src: ui64) {
+    %cst = arith.constant dense<[
+      0,   8,  16,  24,  32,  40,  48,  56,
+      64,  72,  80,  88,  96, 104, 112, 120,
+      128, 136, 144, 152, 160, 168, 176, 184,
+      192, 200, 208, 216, 224, 232, 240, 248
+    ]> : vector<32xindex>
+
+    %c17 = arith.constant 17: index
+    %mask = vector.create_mask %c17: vector<32xi1>
+
+    %st_vec = arith.constant dense<1023.>: vector<32x4xf32>
+    xegpu.store %st_vec, %src[%cst], %mask {chunk_size = 4, layout_operand_0 = #xegpu.layout<inst_data = [16, 2]>,
+                                            layout_operand_2 = #xegpu.layout<inst_data = [16, 2]>,
+                                            layout_operand_3 = #xegpu.layout<inst_data = [16, 2]>,
+                                            l1_hint = #xegpu.cache_hint<cached>} : vector<32x4xf32>, ui64, vector<32xindex>, vector<32xi1>
+    gpu.return
+  }
+}


