[Mlir-commits] [mlir] [MLIR][XeGPU] Enable blocking for scatter ops with offsets (PR #162896)
    llvmlistbot at llvm.org 
    llvmlistbot at llvm.org
       
    Fri Oct 10 10:40:15 PDT 2025
    
    
  
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-gpu
@llvm/pr-subscribers-mlir
Author: Nishant Patel (nbpatel)
<details>
<summary>Changes</summary>
The unroll patterns for these ops were added in the previous PR, but the getTileShape method was not updated to handle them, so the blocking pass was not kicking in.
---
Full diff: https://github.com/llvm/llvm-project/pull/162896.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp (+15-2) 
- (modified) mlir/test/Dialect/XeGPU/xegpu-blocking.mlir (+102) 
``````````diff
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
index 36c498e8b849d..f77784abaf0b2 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
@@ -161,11 +161,24 @@ XeGPUBlockingPass::getTileShape(Operation *op) const {
           xegpu::UpdateOffsetOp, xegpu::LoadMatrixOp>(op))
     return getTileShape(op->getOpResult(0));
   if (isa<xegpu::PrefetchNdOp, xegpu::LoadNdOp, xegpu::PrefetchOp,
-          xegpu::LoadGatherOp, xegpu::StoreMatrixOp>(op))
+          xegpu::StoreMatrixOp>(op))
     return getTileShape(op->getOpOperand(0));
-  if (isa<xegpu::StoreNdOp, xegpu::StoreScatterOp>(op))
+  if (isa<xegpu::StoreNdOp>(op))
     return getTileShape(op->getOpOperand(1));
 
+  // Handle LoadGatherOp and StoreScatterOp (with and without offset)
+  if (auto loadGatherOp = dyn_cast<xegpu::LoadGatherOp>(op)) {
+    if (loadGatherOp.getOffsets())
+      return getTileShape(loadGatherOp->getOpResult(0));
+    else
+      return getTileShape(loadGatherOp->getOpOperand(0));
+  }
+
+  if (auto storeScatterOp = dyn_cast<xegpu::StoreScatterOp>(op))
+    return getTileShape(storeScatterOp.getOffsets()
+                            ? storeScatterOp->getOpOperand(0)
+                            : storeScatterOp->getOpOperand(1));
+
   if (isa<xegpu::DpasOp>(op)) {
     std::optional<SmallVector<int64_t>> aTile =
         getTileShape(op->getOpOperand(0));
diff --git a/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir b/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
index 9d63c2ddd4895..ce4db475ab249 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
@@ -584,3 +584,105 @@ gpu.module @test_kernel {
     gpu.return
   }
 }
+
+// -----
+gpu.module @test_kernel {
+  // CHECK-LABEL: load_with_offsets
+  // CHECK-SAME: [[arg0:%.+]]: ui64
+  // CHECK-COUNT-2: xegpu.load  {{.*}}[{{.*}}], {{.*}} <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>}> : ui64, vector<16xindex>, vector<16xi1> -> vector<16xf32>
+  gpu.func @load_with_offsets(%src: ui64) -> vector<32xf32> {
+      %cst = arith.constant dense<[
+      0,   8,  16,  24,  32,  40,  48,  56,
+      64,  72,  80,  88,  96, 104, 112, 120,
+      128, 136, 144, 152, 160, 168, 176, 184,
+      192, 200, 208, 216, 224, 232, 240, 248
+      ]> : vector<32xindex>
+
+      %c17 = arith.constant 17: index
+      %mask = vector.create_mask %c17: vector<32xi1>
+      %ld = xegpu.load %src[%cst], %mask {chunk_size = 1, layout_result_0 = #xegpu.layout<inst_data = [16]>, l1_hint = #xegpu.cache_hint<cached>} : ui64, vector<32xindex>, vector<32xi1> -> vector<32xf32>
+
+      gpu.return %ld : vector<32xf32>
+  }
+}
+
+// -----
+gpu.module @test_kernel {
+  // CHECK-LABEL: store_with_offsets
+  // CHECK-SAME: [[arg0:%.+]]: ui64
+  // CHECK-COUNT-2: xegpu.store  {{.*}}[{{.*}}], {{.*}} <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>}> : vector<16xf32>, ui64, vector<16xindex>, vector<16xi1>
+  gpu.func @store_with_offsets(%src: ui64) {
+      %cst = arith.constant dense<[
+      0,   8,  16,  24,  32,  40,  48,  56,
+      64,  72,  80,  88,  96, 104, 112, 120,
+      128, 136, 144, 152, 160, 168, 176, 184,
+      192, 200, 208, 216, 224, 232, 240, 248
+      ]> : vector<32xindex>
+
+      %c17 = arith.constant 17: index
+      %mask = vector.create_mask %c17: vector<32xi1>
+
+      %st_vec = arith.constant dense<1023.0>: vector<32xf32>
+      xegpu.store %st_vec, %src[%cst], %mask {chunk_size = 1, layout_operand_0 = #xegpu.layout<inst_data = [16]>, 
+                                              layout_operand_2 = #xegpu.layout<inst_data = [16]>, 
+                                              layout_operand_3 = #xegpu.layout<inst_data = [16]>, 
+                                              l1_hint = #xegpu.cache_hint<cached>} : vector<32xf32>, ui64, vector<32xindex>, vector<32xi1>
+
+      gpu.return
+  }
+}
+
+// -----
+gpu.module @test_kernel {
+  // CHECK-LABEL: load_with_offsets_chunk
+  // CHECK-SAME: [[arg0:%.+]]: ui64
+  // CHECK: [[cst:%.+]] = arith.constant dense<0.000000e+00> : vector<32x4xf32>
+  // CHECK: [[cst0:%.+]] = arith.constant dense<[130, 138, 146, 154, 162, 170, 178, 186, 194, 202, 210, 218, 226, 234, 242, 250]> : vector<16xindex>
+  // CHECK: [[cst1:%.+]] = arith.constant dense<[2, 10, 18, 26, 34, 42, 50, 58, 66, 74, 82, 90, 98, 106, 114, 122]> : vector<16xindex>
+  // CHECK: [[cst2:%.+]] = arith.constant dense<[128, 136, 144, 152, 160, 168, 176, 184, 192, 200, 208, 216, 224, 232, 240, 248]> : vector<16xindex>
+  // CHECK: [[cst3:%.+]] = arith.constant dense<[0, 8, 16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120]> : vector<16xindex>
+  // CHECK-COUNT-4: xegpu.load  {{.*}}[{{.*}}], {{.*}} <{chunk_size = 2 : i64, l1_hint = #xegpu.cache_hint<cached>}> : ui64, vector<16xindex>, vector<16xi1> -> vector<16x2xf32>
+   gpu.func @load_with_offsets_chunk(%src: ui64) -> vector<32x4xf32> {
+    %cst = arith.constant dense<[
+        0,   8,  16,  24,  32,  40,  48,  56,
+        64,  72,  80,  88,  96, 104, 112, 120,
+        128, 136, 144, 152, 160, 168, 176, 184,
+        192, 200, 208, 216, 224, 232, 240, 248
+    ]> : vector<32xindex>
+
+    %c17 = arith.constant 17: index
+    %mask = vector.create_mask %c17: vector<32xi1>
+    %ld = xegpu.load %src[%cst], %mask {chunk_size = 4, layout_result_0 = #xegpu.layout<inst_data = [16, 2]>, l1_hint = #xegpu.cache_hint<cached>} : ui64, vector<32xindex>, vector<32xi1> -> vector<32x4xf32>
+    gpu.return %ld : vector<32x4xf32>
+   }
+}
+
+// -----
+gpu.module @test_kernel {
+  // CHECK-LABEL: store_with_offsets_chunk
+  // CHECK-SAME: [[arg0:%.+]]: ui64
+  // CHECK: [[cst:%.+]] = arith.constant dense<1.023000e+03> : vector<16x2xf32
+  // CHECK: [[cst0:%.+]] = arith.constant dense<[130, 138, 146, 154, 162, 170, 178, 186, 194, 202, 210, 218, 226, 234, 242, 250]> : vector<16xindex>
+  // CHECK: [[cst1:%.+]] = arith.constant dense<[2, 10, 18, 26, 34, 42, 50, 58, 66, 74, 82, 90, 98, 106, 114, 122]> : vector<16xindex>
+  // CHECK: [[cst2:%.+]] = arith.constant dense<[128, 136, 144, 152, 160, 168, 176, 184, 192, 200, 208, 216, 224, 232, 240, 248]> : vector<16xindex>
+  // CHECK: [[cst3:%.+]] = arith.constant dense<[0, 8, 16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120]> : vector<16xindex>
+  // CHECK-COUNT-4: xegpu.store  {{.*}}[{{.*}}], {{.*}} <{chunk_size = 2 : i64, l1_hint = #xegpu.cache_hint<cached>}> : vector<16x2xf32>, ui64, vector<16xindex>, vector<16xi1>
+  gpu.func @store_with_offsets_chunk(%src: ui64) {
+    %cst = arith.constant dense<[
+      0,   8,  16,  24,  32,  40,  48,  56,
+      64,  72,  80,  88,  96, 104, 112, 120,
+      128, 136, 144, 152, 160, 168, 176, 184,
+      192, 200, 208, 216, 224, 232, 240, 248
+    ]> : vector<32xindex>
+
+    %c17 = arith.constant 17: index
+    %mask = vector.create_mask %c17: vector<32xi1>
+
+    %st_vec = arith.constant dense<1023.>: vector<32x4xf32>
+    xegpu.store %st_vec, %src[%cst], %mask {chunk_size = 4, layout_operand_0 = #xegpu.layout<inst_data = [16, 2]>,
+                                            layout_operand_2 = #xegpu.layout<inst_data = [16, 2]>,
+                                            layout_operand_3 = #xegpu.layout<inst_data = [16, 2]>,
+                                            l1_hint = #xegpu.cache_hint<cached>} : vector<32x4xf32>, ui64, vector<32xindex>, vector<32xi1>
+    gpu.return
+  }
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/162896
    
    
More information about the Mlir-commits
mailing list