[Mlir-commits] [mlir] [MLIR][XeGPU] Add support for cross-subgroup reduction from wg to sg (PR #170936)
Jianhui Li
llvmlistbot at llvm.org
Sat Dec 20 10:26:36 PST 2025
================
@@ -634,6 +642,137 @@ gpu.module @test_distribution {
gpu.return
}
+ // CHECK-LABEL: gpu.func @vector_reduce_cross_sg_dim_1
+ // CHECK-SAME: (%[[ARG0:.*]]: memref<?xf32>)
+ gpu.func @vector_reduce_cross_sg_dim_1(%src: memref<?xf32>) {
+ // CHECK-DAG: %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<1x32xf32>
+ // CHECK-DAG: %[[CST_0:.*]] = arith.constant dense<0> : vector<1x1x32xindex>
+ // CHECK-DAG: %[[CST_1:.*]] = arith.constant dense<true> : vector<1x1x32xi1>
+ // CHECK-DAG: %[[LOAD:.*]] = xegpu.load %{{.*}}[%[[CST_0]]], %[[CST_1]] <{chunk_size = 1 : i64}> : memref<?xf32>, vector<1x1x32xindex>, vector<1x1x32xi1> -> vector<1x1x32xf32>
+ // CHECK-DAG: %[[CST_2:.*]] = arith.constant dense<0.000000e+00> : vector<1x32xf32>
+ // CHECK-DAG: %[[LOCAL_REDUCE:.*]] = vector.multi_reduction <add>, %[[LOAD]], %[[CST_2]] [1] : vector<1x1x32xf32> to vector<1x32xf32>
+ // CHECK-DAG: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[LOCAL_REDUCE]] : vector<1x32xf32> to vector<1x32xf32>
+ // CHECK-DAG: %[[ALLOCA:.*]] = memref.alloca() : memref<4096xi8, 3>
+ // CHECK-DAG: %[[MEM_DESC:.*]] = xegpu.create_mem_desc %[[ALLOCA]] : memref<4096xi8, 3> -> !xegpu.mem_desc<32x32xf32>
+ // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
+ // CHECK-DAG: %[[AFFINE1:.*]] = affine.apply #map()[%[[SGID]]]
+ // CHECK-DAG: %[[AFFINE2:.*]] = affine.apply #map1()[%[[SGID]]]
+ // CHECK-DAG: %[[AFFINE3:.*]] = affine.apply #map2()[%[[SGID]]]
+ // CHECK-DAG: %[[MUL1:.*]] = arith.muli {{.*}}, %[[C1:.*]] : index
+ // CHECK-DAG: %[[ROW_OFFSET:.*]] = arith.addi %[[C0:.*]], %[[MUL1]] : index
+ // CHECK-DAG: %[[MUL2:.*]] = arith.muli %[[AFFINE1]], %[[C1:.*]] : index
+ // CHECK-DAG: %[[ADD1:.*]] = arith.addi %[[C0:.*]], %[[MUL2]] : index
+ // CHECK-DAG: %[[MUL3:.*]] = arith.muli %[[AFFINE3]], %[[C1:.*]] : index
+ // CHECK-DAG: %[[ADD2:.*]] = arith.addi %[[ADD1]], %[[MUL3]] : index
+ // CHECK-DAG: %[[COL_OFFSET:.*]] = arith.muli %[[ADD2]], %[[C32:.*]] : index
+ // CHECK-DAG: xegpu.store_matrix %[[SHAPE_CAST]], %[[MEM_DESC]][%[[ROW_OFFSET]], %[[COL_OFFSET]]] : vector<1x32xf32>, !xegpu.mem_desc<32x32xf32>, index, index
+ // CHECK-DAG: gpu.barrier
+ // CHECK-DAG: %[[LOAD_SLM:.*]] = xegpu.load_matrix %[[MEM_DESC]][%[[C0:.*]], %[[COL_OFFSET]]] : !xegpu.mem_desc<32x32xf32>, index, index -> vector<32x32xf32>
+ // CHECK-DAG: %[[CST_3:.*]] = arith.constant dense<0.000000e+00> : vector<32xf32>
+ // CHECK-DAG: %[[FINAL_REDUCE:.*]] = vector.multi_reduction <add>, %[[LOAD_SLM]], %[[CST_3]] [0] : vector<32x32xf32> to vector<32xf32>
+ // CHECK-DAG: %[[SHAPE_CAST_FINAL:.*]] = vector.shape_cast %[[CST]] : vector<1x32xf32> to vector<32xf32>
+ // CHECK-DAG: %{{.*}} = arith.addf %[[FINAL_REDUCE]], %[[SHAPE_CAST_FINAL]] : vector<32xf32>
+ // CHECK-DAG: gpu.return
+ %cst_3 = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [1, 32, 1], sg_data = [1, 1, 32]>, dims = [1]>} dense<1.0> : vector<1x32xf32>
+ %offset = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [1, 32, 1], sg_data = [1, 1, 32]>} dense<0> : vector<1x32x32xindex>
+ %mask = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [1, 32, 1], sg_data = [1, 1, 32]>} dense<true> : vector<1x32x32xi1>
+ %14 = xegpu.load %src[%offset], %mask {chunk_size = 1, layout_result_0 = #xegpu.layout<sg_layout = [1, 32, 1], sg_data = [1, 1, 32]>} : memref<?xf32>, vector<1x32x32xindex>, vector<1x32x32xi1> -> vector<1x32x32xf32>
+ %15 = vector.multi_reduction <add>, %14, %cst_3 {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [1, 32, 1], sg_data = [1, 1, 32]>, dims = [1]>} [1] : vector<1x32x32xf32> to vector<1x32xf32>
+ gpu.return
+ }
+
+ // CHECK-LABEL: gpu.func @vector_reduce_cross_sg_dim_0
+ // CHECK-SAME: (%[[ARG0:.*]]: memref<256x128xf32>)
+ gpu.func @vector_reduce_cross_sg_dim_0(%src: memref<256x128xf32>) {
+ // CHECK-DAG: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<32xf32>
+ // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
+ // CHECK-DAG: %[[REM1:.*]] = arith.remui %[[SGID]], %[[C4:.*]] : index
+ // CHECK-DAG: %[[DIV1:.*]] = arith.divui %[[SGID]], %[[C4:.*]] : index
+ // CHECK-DAG: %[[REM2:.*]] = arith.remui %[[DIV1]], %[[C8:.*]] : index
+ // CHECK-DAG: %[[MUL1:.*]] = arith.muli %[[REM2]], %[[C32:.*]] : index
+ // CHECK-DAG: %[[MUL2:.*]] = arith.muli %[[REM1]], %[[C32:.*]] : index
+ // CHECK-DAG: %[[REM3:.*]] = arith.remui %[[MUL1]], %[[C256:.*]] : index
+ // CHECK-DAG: %[[REM4:.*]] = arith.remui %[[MUL2]], %[[C128:.*]] : index
+ // CHECK-DAG: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%[[REM3]], %[[REM4]]] : memref<256x128xf32> -> !xegpu.tensor_desc<32x32xf32>
+ // CHECK-DAG: %[[LOAD_ND:.*]] = xegpu.load_nd %[[TDESC]] : !xegpu.tensor_desc<32x32xf32> -> vector<32x32xf32>
+ // CHECK-DAG: %[[CST_LOCAL:.*]] = arith.constant dense<0.000000e+00> : vector<32xf32>
+ // CHECK-DAG: %[[LOCAL_REDUCE:.*]] = vector.multi_reduction <add>, %[[LOAD_ND]], %[[CST_LOCAL]] [0] : vector<32x32xf32> to vector<32xf32>
+ // CHECK-DAG: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[LOCAL_REDUCE]] : vector<32xf32> to vector<1x32xf32>
+ // CHECK-DAG: %[[ALLOCA:.*]] = memref.alloca() : memref<4096xi8, 3>
+ // CHECK-DAG: %[[MEM_DESC:.*]] = xegpu.create_mem_desc %[[ALLOCA]] : memref<4096xi8, 3> -> !xegpu.mem_desc<8x128xf32>
+ // CHECK-DAG: %[[SGID2:.*]] = gpu.subgroup_id : index
+ // CHECK-DAG: %[[AFFINE1:.*]] = affine.apply #map3()[%[[SGID2]]]
+ // CHECK-DAG: %[[AFFINE2:.*]] = affine.apply #map4()[%[[SGID2]]]
+ // CHECK-DAG: %[[MUL3:.*]] = arith.muli {{.*}}, %[[C1:.*]] : index
+ // CHECK-DAG: %[[ROW_OFFSET:.*]] = arith.addi %[[C0:.*]], %[[MUL3]] : index
+ // CHECK-DAG: %[[MUL4:.*]] = arith.muli {{.*}}, %[[C1:.*]] : index
+ // CHECK-DAG: %[[ADD1:.*]] = arith.addi %[[C0:.*]], %[[MUL4]] : index
+ // CHECK-DAG: %[[COL_OFFSET:.*]] = arith.muli %[[ADD1]], %[[C32:.*]] : index
+ // CHECK-DAG: xegpu.store_matrix %[[SHAPE_CAST]], %[[MEM_DESC]][%[[ROW_OFFSET]], %[[COL_OFFSET]]] : vector<1x32xf32>, !xegpu.mem_desc<8x128xf32>, index, index
+ // CHECK-DAG: gpu.barrier
+ // CHECK-DAG: %[[LOAD_SLM:.*]] = xegpu.load_matrix %[[MEM_DESC]][%[[C0:.*]], %[[COL_OFFSET]]] : !xegpu.mem_desc<8x128xf32>, index, index -> vector<8x32xf32>
+ // CHECK-DAG: %[[CST_CROSS_SG_1:.*]] = arith.constant dense<0.000000e+00> : vector<32xf32>
+ // CHECK-DAG: %[[FINAL_REDUCE:.*]] = vector.multi_reduction <add>, %[[LOAD_SLM]], %[[CST_CROSS_SG_1]] [0] : vector<8x32xf32> to vector<32xf32>
+ // CHECK-DAG: arith.addf %[[FINAL_REDUCE]], %[[CST:.*]] : vector<32xf32>
+ %cst = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>, dims = [0]>} dense<0.0> : vector<128xf32>
+ %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
+ -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>>
+ %load = xegpu.load_nd %tdesc
+ : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>>
+ -> vector<256x128xf32>
+ %reduce = vector.multi_reduction <add>, %load, %cst {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>, dims = [0]>} [0]
+ : vector<256x128xf32> to vector<128xf32>
+ gpu.return
+ }
+
+ // CHECK-LABEL: gpu.func @vector_reduce_multi_dim
+ // CHECK-SAME: (%[[ARG0:.*]]: memref<?xf32>)
+ gpu.func @vector_reduce_multi_dim(%src: memref<?xf32>) {
+ // CHECK-DAG: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<1x1xf32>
+ // CHECK-DAG: %[[CST_0:.*]] = arith.constant dense<0> : vector<1x1x32x32xindex>
+ // CHECK-DAG: %[[CST_1:.*]] = arith.constant dense<true> : vector<1x1x32x32xi1>
+ // CHECK-DAG: %[[LOAD:.*]] = xegpu.load %{{.*}}[%[[CST_0]]], %[[CST_1]] <{chunk_size = 1 : i64}> : memref<?xf32>, vector<1x1x32x32xindex>, vector<1x1x32x32xi1> -> vector<1x1x32x32xf32>
+ // CHECK-DAG: %[[CST_2:.*]] = arith.constant dense<0.000000e+00> : vector<1x1xf32>
+ // CHECK-DAG: %[[LOCAL_REDUCE:.*]] = vector.multi_reduction <add>, %[[LOAD]], %[[CST_2]] [2, 3] : vector<1x1x32x32xf32> to vector<1x1xf32>
+ // CHECK-DAG: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[LOCAL_REDUCE]] : vector<1x1xf32> to vector<1x1xf32>
+ // CHECK-DAG: %[[ALLOCA:.*]] = memref.alloca() : memref<256xi8, 3>
+ // CHECK-DAG: %[[MEM_DESC:.*]] = xegpu.create_mem_desc %[[ALLOCA]] : memref<256xi8, 3> -> !xegpu.mem_desc<16x4xf32>
+ // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
+ // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+ // CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+ // CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index
+ // CHECK-DAG: %[[C32:.*]] = arith.constant 32 : index
+ // CHECK-DAG: %[[AFFINE1:.*]] = affine.apply #map()[%[[SGID]]]
+ // CHECK-DAG: %[[AFFINE2:.*]] = affine.apply #map1()[%[[SGID]]]
+ // CHECK-DAG: %[[AFFINE3:.*]] = affine.apply #map5()[%[[SGID]]]
+ // CHECK-DAG: %[[AFFINE4:.*]] = affine.apply #map6()[%[[SGID]]]
+ // CHECK-DAG: %[[AFFINE5:.*]] = affine.apply #map7()[%[[SGID]]]
+ // CHECK-DAG: %[[AFFINE6:.*]] = affine.apply #map4()[%[[SGID]]]
+ // CHECK-DAG: %[[MUL1:.*]] = arith.muli {{.*}}, %[[C1:.*]] : index
+ // CHECK-DAG: %[[ADD1:.*]] = arith.addi %[[C0:.*]], %[[MUL1]] : index
+ // CHECK-DAG: %[[MUL2:.*]] = arith.muli {{.*}}, %[[C4:.*]] : index
+ // CHECK-DAG: %[[ROW_OFFSET:.*]] = arith.addi %[[ADD1]], %[[MUL2]] : index
+ // CHECK-DAG: %[[MUL3:.*]] = arith.muli {{.*}}, %[[C1:.*]] : index
+ // CHECK-DAG: %[[ADD2:.*]] = arith.addi %[[C0:.*]], %[[MUL3]] : index
+ // CHECK-DAG: %[[MUL4:.*]] = arith.muli {{.*}}, %[[C2:.*]] : index
+ // CHECK-DAG: %[[ADD3:.*]] = arith.addi %[[ADD2]], %[[MUL4]] : index
+ // CHECK-DAG: %[[COL_OFFSET:.*]] = arith.muli %[[ADD3]], %[[C1:.*]] : index
+ // CHECK-DAG: xegpu.store_matrix %[[SHAPE_CAST]], %[[MEM_DESC]][%[[ROW_OFFSET]], %[[COL_OFFSET]]] : vector<1x1xf32>, !xegpu.mem_desc<16x4xf32>, index, index
----------------
Jianhui-Li wrote:
store_matrix should be associated with a slice layout. In this case, it should be
#slice{#xegpu_layout={}, dims=[2, 3]}. It doesn't have to decide the inst_data inside the parent layout, but it is important to keep the slice structure.
The slice attribute ensures that the layout used by store_matrix matches the sg local reduction result. When the sg local reduction result is reduced across lanes, and thus implicitly broadcast to all lanes, the result is associated with the slice layout to indicate this implicit broadcast. In that case, there is no need for every lane to write the reduced data to shared local memory. Instead, store_matrix should use the slice layout to direct the subgroup distribution process, which then elects only one lane to save the local reduction result to shared local memory.
https://github.com/llvm/llvm-project/pull/170936
More information about the Mlir-commits
mailing list