[Mlir-commits] [mlir] [mlir][vector] Add extra check on distribute types to avoid crashes (PR #102952)
Han-Chung Wang
llvmlistbot at llvm.org
Mon Aug 12 13:13:24 PDT 2024
hanhanW wrote:
(Wait, where is my comment? I swear I wrote something down...)
According to Ian's log, it failed in VectorReductionToGPUPass. Can we trim IREE specifics from the IR and add a test to [vector-warp-distribute.mlir](https://github.com/llvm/llvm-project/blob/main/mlir/test/Dialect/Vector/vector-warp-distribute.mlir)?
```mlir
// -----// IR Dump Before VectorReductionToGPUPass (iree-codegen-vector-reduction-to-gpu) //----- //
func.func @main$async_dispatch_6_generic_32x262144_f16xf32xf32xf32() attributes {translation_info = #iree_codegen.translation_info<LLVMGPUWarpReduction workgroup_size = [1024, 1, 1] subgroup_size = 64>} {
%cst = arith.constant dense<0.000000e+00> : vector<4096xf32>
%cst_0 = arith.constant dense<9.99999997E-7> : vector<1xf32>
%cst_1 = arith.constant dense<2.621440e+05> : vector<1xf32>
%cst_2 = arith.constant dense<2.621440e+05> : vector<4096xf32>
%cst_3 = arith.constant 0.000000e+00 : f16
%cst_4 = arith.constant dense<0.000000e+00> : vector<1xf32>
%c262144 = arith.constant 262144 : index
%c16384 = arith.constant 16384 : index
%c16 = arith.constant 16 : index
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c4096 = arith.constant 4096 : index
%0 = hal.interface.constant.load[0] : i32
%1 = hal.interface.constant.load[1] : i32
%2 = arith.index_castui %0 : i32 to index
%3 = arith.index_castui %1 : i32 to index
%4 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%2) flags(ReadOnly) : memref<32x16x16384xf16, strided<[262144, 16384, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %4, 1 : memref<32x16x16384xf16, strided<[262144, 16384, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%5 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%2) flags(ReadOnly) : memref<32x262144xf16, strided<[262144, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %5, 1 : memref<32x262144xf16, strided<[262144, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%6 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%3) : memref<32x262144xf32, strided<[262144, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %6, 1 : memref<32x262144xf32, strided<[262144, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%7 = scf.for %arg0 = %c0 to %c16 step %c1 iter_args(%arg1 = %cst) -> (vector<4096xf32>) {
%19 = scf.for %arg2 = %c0 to %c16384 step %c4096 iter_args(%arg3 = %arg1) -> (vector<4096xf32>) {
%20 = vector.transfer_read %4[%workgroup_id_x, %arg0, %arg2], %cst_3 {in_bounds = [true]} : memref<32x16x16384xf16, strided<[262144, 16384, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<4096xf16>
%21 = arith.extf %20 : vector<4096xf16> to vector<4096xf32>
%22 = arith.addf %21, %arg3 : vector<4096xf32>
scf.yield %22 : vector<4096xf32>
}
scf.yield %19 : vector<4096xf32>
}
%8 = vector.broadcast %7 : vector<4096xf32> to vector<1x1x4096xf32>
%9 = vector.multi_reduction <add>, %8, %cst_4 [1, 2] : vector<1x1x4096xf32> to vector<1xf32>
%10 = vector.broadcast %9 : vector<1xf32> to vector<4096xf32>
%11 = arith.divf %10, %cst_2 : vector<4096xf32>
%12 = scf.for %arg0 = %c0 to %c16 step %c1 iter_args(%arg1 = %cst) -> (vector<4096xf32>) {
%19 = scf.for %arg2 = %c0 to %c16384 step %c4096 iter_args(%arg3 = %arg1) -> (vector<4096xf32>) {
%20 = vector.transfer_read %4[%workgroup_id_x, %arg0, %arg2], %cst_3 {in_bounds = [true]} : memref<32x16x16384xf16, strided<[262144, 16384, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<4096xf16>
%21 = arith.extf %20 : vector<4096xf16> to vector<4096xf32>
%22 = arith.subf %21, %11 : vector<4096xf32>
%23 = arith.mulf %22, %22 : vector<4096xf32>
%24 = arith.addf %23, %arg3 : vector<4096xf32>
scf.yield %24 : vector<4096xf32>
}
scf.yield %19 : vector<4096xf32>
}
%13 = vector.broadcast %12 : vector<4096xf32> to vector<1x1x4096xf32>
%14 = vector.multi_reduction <add>, %13, %cst_4 [1, 2] : vector<1x1x4096xf32> to vector<1xf32>
%15 = arith.divf %9, %cst_1 : vector<1xf32>
%16 = arith.divf %14, %cst_1 : vector<1xf32>
%17 = arith.addf %16, %cst_0 : vector<1xf32>
%18 = math.rsqrt %17 : vector<1xf32>
scf.for %arg0 = %c0 to %c262144 step %c1 {
%19 = vector.transfer_read %5[%workgroup_id_x, %arg0], %cst_3 {in_bounds = [true]} : memref<32x262144xf16, strided<[262144, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1xf16>
%20 = arith.extf %19 : vector<1xf16> to vector<1xf32>
%21 = arith.subf %20, %15 : vector<1xf32>
%22 = arith.mulf %21, %18 : vector<1xf32>
vector.transfer_write %22, %6[%workgroup_id_x, %arg0] {in_bounds = [true]} : vector<1xf32>, memref<32x262144xf32, strided<[262144, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
}
return
}
```
https://github.com/llvm/llvm-project/pull/102952
More information about the Mlir-commits mailing list