[Mlir-commits] [mlir] [MLIR][XeGPU] Add handling for unit-dim expansion in ShapeCast workgroup-to-subgroup distribution (PR #171758)
Charitha Saumya
llvmlistbot at llvm.org
Tue Dec 16 10:11:39 PST 2025
================
@@ -380,4 +380,21 @@ gpu.module @test_1_1_assignment {
%cst = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>} dense<1.0> : vector<256x128xf32>
gpu.return
}
+
+ // CHECK-LABEL: distribute_shapecast_expandunitdims_broadcast
+ // CHECK: %[[CAST:.*]] = vector.shape_cast %[[REDUCE:.*]] {layout_result_0 = #xegpu.layout<inst_data = [8, 1]>} : vector<8xf32> to vector<8x1xf32>
+ // CHECK: %[[BCAST:.*]] = vector.broadcast %[[CAST]] {layout_result_0 = #xegpu.layout<inst_data = [8, 16]>} : vector<8x1xf32> to vector<8x128xf32>
+ gpu.func @distribute_shapecast_expandunitdims_broadcast(%arg0: memref<4096x128xf32>, %arg1: memref<4096x128xf32>) {
+ %cst_0 = arith.constant {layout_result_0=#xegpu.slice<#xegpu.layout<sg_layout = [32, 1], sg_data = [8, 128], inst_data = [8, 16]>, dims = [1]>} dense<0xFF800000> : vector<256xf32>
+ %block_id_x = gpu.block_id x
+ %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x128xf32> -> !xegpu.tensor_desc<256x128xf32, #xegpu.block_tdesc_attr<boundary_check = false>, #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 128], inst_data = [8, 16]>>
+ %1 = xegpu.load_nd %0[%block_id_x, 0] : !xegpu.tensor_desc<256x128xf32, #xegpu.block_tdesc_attr<boundary_check = false>, #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 128], inst_data = [8, 16]>> -> vector<256x128xf32>
+ %2 = vector.multi_reduction <maximumf>, %1, %cst_0 {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [32, 1], sg_data = [8, 128], inst_data = [8, 16]>, dims = [1]>} [1] : vector<256x128xf32> to vector<256xf32>
+ %3 = vector.shape_cast %2 {layout_result_0 = #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 128], inst_data = [8, 16]>} : vector<256xf32> to vector<256x1xf32>
----------------
charithaintc wrote:
Why not use sg_data = [8, 1], inst_data = [8, 1] here?
https://github.com/llvm/llvm-project/pull/171758