[Mlir-commits] [mlir] [MLIR][XeGPU] Add support for vector.multi_reduction in wg to sg pass [1/N] (PR #157554)
Jianhui Li
llvmlistbot at llvm.org
Fri Sep 12 14:00:25 PDT 2025
================
@@ -365,4 +365,32 @@ gpu.module @test_distribution {
xegpu.store_matrix %cst, %mdesc[0, 0] {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [32, 32]>} : vector<64x128xf32>, !xegpu.mem_desc<64x128xf32>
gpu.return
}
+
+ // CHECK-LABEL: @vector_reduce_dim_0
+ gpu.func @vector_reduce_dim_0(%src: memref<1x128xf32>) {
+ %cst = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [1, 32], sg_data = [1, 4]>, dims = [0]>} dense<1.0> : vector<128xf32>
+ %tdesc = xegpu.create_nd_tdesc %src : memref<1x128xf32>
+   -> !xegpu.tensor_desc<1x128xf32, #xegpu.layout<sg_layout = [1, 32], sg_data = [1, 4]>>
+ %load = xegpu.load_nd %tdesc[0, 0]
+   : !xegpu.tensor_desc<1x128xf32, #xegpu.layout<sg_layout = [1, 32], sg_data = [1, 4]>>
+   -> vector<1x128xf32>
+ // CHECK: vector.multi_reduction <add>, {{.*}}, {{.*}} [0] : vector<1x4xf32> to vector<4xf32>
+ %reduce = vector.multi_reduction <add>, %load, %cst {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [1, 32], sg_data = [1, 4]>, dims = [0]>} [0]
+   : vector<1x128xf32> to vector<128xf32>
+ gpu.return
+ }
+
+ // CHECK-LABEL: @vector_reduce_dim_1
+ gpu.func @vector_reduce_dim_1(%src: memref<256x1xf32>) {
+ %cst = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [32, 1], sg_data = [8, 1]>, dims = [1]>} dense<1.0> : vector<256xf32>
+ %tdesc = xegpu.create_nd_tdesc %src : memref<256x1xf32>
+   -> !xegpu.tensor_desc<256x1xf32, #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 1]>>
+ %load = xegpu.load_nd %tdesc[0, 0]
+   : !xegpu.tensor_desc<256x1xf32, #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 1]>>
+   -> vector<256x1xf32>
+ // CHECK: vector.multi_reduction <add>, {{.*}}, {{.*}} [1] : vector<8x1xf32> to vector<8xf32>
+ %reduce = vector.multi_reduction <add>, %load, %cst {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [32, 1], sg_data = [8, 1]>, dims = [1]>} [1]
+   : vector<256x1xf32> to vector<256xf32>
----------------
Jianhui-Li wrote:
I think that the shape should be <256x32>.
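For reference, the per-subgroup shape in the CHECK line follows from the layout attributes in the patch: sg_layout = [32, 1] tiles the 256x1 workgroup vector across 32 subgroups, each holding an sg_data = [8, 1] tile, and the sliced accumulator layout (dims = [1]) leaves each subgroup a vector<8xf32> accumulator. A minimal sketch of that per-subgroup view (not part of the patch; the function name and operand are illustrative, and it would sit inside a gpu.module):

    // Illustrative per-subgroup view after wg-to-sg distribution.
    gpu.func @per_sg_reduce_dim_1(%load_sg: vector<8x1xf32>) {
      // Per-subgroup accumulator: the wg-level vector<256xf32> sliced by sg_layout/sg_data.
      %acc_sg = arith.constant dense<1.0> : vector<8xf32>
      // Matches the CHECK line: reduce the 8x1 tile along dim 1 into 8 partial sums.
      %red_sg = vector.multi_reduction <add>, %load_sg, %acc_sg [1] : vector<8x1xf32> to vector<8xf32>
      gpu.return
    }

The shapes above reflect the patch as posted; a different source shape such as 256x32 would change sg_data and the per-subgroup tile accordingly.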
https://github.com/llvm/llvm-project/pull/157554