[Mlir-commits] [mlir] [MLIR][XeGPU] Extend Wg-to-Sg Distribution of Multi-Reduction Op for round-robin layout (PR #189988)

Charitha Saumya llvmlistbot at llvm.org
Mon Apr 6 13:39:39 PDT 2026


================
@@ -165,6 +165,58 @@ gpu.module @test_distribution {
     gpu.return
   }
 
+  // CHECK-LABEL: gpu.func @reduction_cross_sg_rr
+  gpu.func @reduction_cross_sg_rr(%arg0: memref<2048xf32, 1>) kernel {
+    // CHECK: %[[CST_OFFSETS0:.*]] = arith.constant dense<0> : vector<4x16xindex>
+    // CHECK: %[[CST_OFFSETS1:.*]] = arith.constant dense<0> : vector<4x16xindex>
+    // CHECK: %[[CST_ACC0:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
+    // CHECK: %[[CST_ACC1:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
+    // CHECK: %[[CST_MASK0:.*]] = arith.constant dense<true> : vector<4x16xi1>
+    // CHECK: %[[CST_MASK1:.*]] = arith.constant dense<true> : vector<4x16xi1>
+    //
+    // CHECK: %[[LOAD0:.*]] = xegpu.load %arg0[%[[CST_OFFSETS0]]], %[[CST_MASK0]]
+    // CHECK-SAME: -> vector<4x16xf32>
+    // CHECK: %[[LOAD1:.*]] = xegpu.load %arg0[%[[CST_OFFSETS1]]], %[[CST_MASK1]]
+    // CHECK-SAME: -> vector<4x16xf32>
+    //
+    // Local reductions
+    // CHECK: %[[NEUTRAL0:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
+    // CHECK: %[[LOCAL_RED0:.*]] = vector.multi_reduction <add>, %[[LOAD0]], %[[NEUTRAL0]] [1] : vector<4x16xf32> to vector<4xf32>
+    // CHECK: %[[NEUTRAL1:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
+    // CHECK: %[[LOCAL_RED1:.*]] = vector.multi_reduction <add>, %[[LOAD1]], %[[NEUTRAL1]] [1] : vector<4x16xf32> to vector<4xf32>
+    //
+    // Shape cast for SLM store
+    // CHECK: %[[SC0:.*]] = vector.shape_cast %[[LOCAL_RED0]] : vector<4xf32> to vector<4x1xf32>
+    // CHECK: %[[SC1:.*]] = vector.shape_cast %[[LOCAL_RED1]] : vector<4xf32> to vector<4x1xf32>
+    //
+    // SLM allocation and mem_desc
+    // CHECK: %[[SLM:.*]] = memref.alloca() : memref<512xi8, 3>
+    // CHECK: %[[MEMDESC:.*]] = xegpu.create_mem_desc %[[SLM]] : memref<512xi8, 3> -> !xegpu.mem_desc<8x16xf32>
+    //
+    // Store to SLM
+    // CHECK: xegpu.store_matrix %[[SC0]], %[[MEMDESC]]{{.*}} : vector<4x1xf32>, !xegpu.mem_desc<8x16xf32>
+    // CHECK: xegpu.store_matrix %[[SC1]], %[[MEMDESC]]{{.*}} : vector<4x1xf32>, !xegpu.mem_desc<8x16xf32>
+    // CHECK: gpu.barrier
+    //
+    // Load from SLM
+    // CHECK: %[[SLM_LOAD0:.*]] = xegpu.load_matrix %[[MEMDESC]]{{.*}} -> vector<4x16xf32>
+    // CHECK: %[[SLM_LOAD1:.*]] = xegpu.load_matrix %[[MEMDESC]]{{.*}} -> vector<4x16xf32>
+    //
+    // Final reduction
+    // CHECK: %[[FINAL_NEUTRAL:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
+    // CHECK: %[[FINAL_RED0:.*]] = vector.multi_reduction <add>, %[[SLM_LOAD0]], %[[FINAL_NEUTRAL]] [1] : vector<4x16xf32> to vector<4xf32>
+    // CHECK: %[[RES0:.*]] = arith.addf %[[FINAL_RED0]], %[[CST_ACC0]] : vector<4xf32>
+    // CHECK: %[[FINAL_RED1:.*]] = vector.multi_reduction <add>, %[[SLM_LOAD1]], %[[FINAL_NEUTRAL]] [1] : vector<4x16xf32> to vector<4xf32>
+    // CHECK: %[[RES1:.*]] = arith.addf %[[FINAL_RED1]], %[[CST_ACC1]] : vector<4xf32>
+
+    %offset = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [1, 16], sg_data = [4, 16]>} dense<0> : vector<8x256xindex>
+    %acc = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [1, 16], sg_data = [4, 16]>, dims = [1]>} dense<0.000000e+00> : vector<8xf32>
+    %mask = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [1, 16], sg_data = [4, 16]>} dense<true> : vector<8x256xi1>
+    %val = xegpu.load %arg0[%offset], %mask <{layout = #xegpu.layout<sg_layout = [1, 16], sg_data = [4, 16]>}> : memref<2048xf32, 1>, vector<8x256xindex>, vector<8x256xi1> -> vector<8x256xf32>
+    %reduce = vector.multi_reduction <add>, %val, %acc {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [1, 16], sg_data = [4, 16]>, dims = [1]>} [1] : vector<8x256xf32> to vector<8xf32>
----------------
charithaintc wrote:

What happens if the round-robin (RR) distribution is along the reduction dimension? Do we handle that case as well?
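The following is a minimal Python sketch of round-robin tile assignment (not XeGPU code) that makes the question concrete. It assumes that subgroup ids are linearized row-major over `sg_layout` and that tiles are dealt out modulo `sg_layout` in each dimension. `rr_tiles` and `group_by_kept_dims` are hypothetical helpers. For the test above (`vector<8x256xf32>`, `sg_layout = [1, 16]`, `sg_data = [4, 16]`), each subgroup owns two 4x16 tiles stacked along dim 0. Each row group therefore has exactly one local partial, which matches the two loads and two local reductions in the CHECK lines. The 16 partials per row also fit the SLM size (8 x 16 x 4 bytes = 512). Now take a hypothetical `vector<8x512xf32>` with the same layout. Here the round-robin wrap happens along the reduced dim 1, so each subgroup owns two tiles covering the same rows, and their partials must also be combined somewhere.

```python
from itertools import product


def rr_tiles(shape, sg_layout, sg_data):
    """Map each subgroup id to the origins of the tiles it owns.

    Hypothetical model: tile index t_d along dim d goes to subgroup
    coordinate t_d % sg_layout[d]; subgroup ids are row-major.
    """
    rank = len(shape)
    n_tiles = [shape[d] // sg_data[d] for d in range(rank)]
    assignment = {}
    for tile_idx in product(*(range(n) for n in n_tiles)):
        # Round-robin: wrap the tile index around the subgroup grid.
        sg_coord = [tile_idx[d] % sg_layout[d] for d in range(rank)]
        sg_id = 0
        for d in range(rank):
            sg_id = sg_id * sg_layout[d] + sg_coord[d]
        origin = tuple(tile_idx[d] * sg_data[d] for d in range(rank))
        assignment.setdefault(sg_id, []).append(origin)
    return assignment


def group_by_kept_dims(tiles, reduce_dim):
    """Group a subgroup's tiles by their coordinates on the non-reduced dims.

    Tiles in the same group cover the same output elements, so their
    partial reductions have to be added together.
    """
    groups = {}
    for origin in tiles:
        key = tuple(o for d, o in enumerate(origin) if d != reduce_dim)
        groups.setdefault(key, []).append(origin)
    return groups


if __name__ == "__main__":
    # Case from the test: RR wrap along dim 0 (not reduced).
    t = rr_tiles((8, 256), (1, 16), (4, 16))
    print(t[0])                         # [(0, 0), (4, 0)]
    print(group_by_kept_dims(t[0], 1))  # one tile per row group

    # Hypothetical case: RR wrap along dim 1 (the reduced dim).
    t2 = rr_tiles((8, 512), (1, 16), (4, 16))
    print(t2[0])                        # [(0, 0), (0, 256), (4, 0), (4, 256)]
    print(group_by_kept_dims(t2[0], 1)) # two tiles per row group
```

In the second case, one option is to add the two same-row partials inside the subgroup before the SLM exchange, which keeps the SLM buffer the same size as in the test. Whether the PR does this is exactly what the question above asks.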

https://github.com/llvm/llvm-project/pull/189988

