[Mlir-commits] [mlir] [mlir][vector] Add extra check on distribute types to avoid crashes (PR #102952)

Han-Chung Wang llvmlistbot at llvm.org
Mon Aug 12 13:13:24 PDT 2024


hanhanW wrote:

(Wait, where did my comment go? I'm sure I wrote something down...)

According to Ian's log, the crash happens in VectorReductionToGPUPass. Can we strip the IREE-specific parts (the `hal.*` ops and attributes, and `translation_info`) from the IR below and add the reduced case as a test to [vector-warp-distribute.mlir](https://github.com/llvm/llvm-project/blob/main/mlir/test/Dialect/Vector/vector-warp-distribute.mlir)?


```mlir
// -----// IR Dump Before VectorReductionToGPUPass (iree-codegen-vector-reduction-to-gpu) //----- //
// Reproducer (IR as it looks before VectorReductionToGPUPass). Each workgroup
// processes row %workgroup_id_x of a 32x262144 f16 buffer and writes f32
// results to a second buffer:
//   out = (x - mean) * rsqrt(variance + 9.99999997E-7)
// with mean and variance taken over that row's 262144 elements.
// NOTE(review): the translation_info attribute, the #hal.descriptor_type
// memory spaces, and the hal.interface.* ops are IREE-specific; they presumably
// need to be replaced (e.g. by function arguments) before this can become a
// test in vector-warp-distribute.mlir — TODO.
// NOTE(review): which op/vector type actually trips the distribution crash is
// not visible from this dump alone — confirm while reducing.
func.func @main$async_dispatch_6_generic_32x262144_f16xf32xf32xf32() attributes {translation_info = #iree_codegen.translation_info<LLVMGPUWarpReduction workgroup_size = [1024, 1, 1] subgroup_size = 64>} {
  // Constants. 2.621440e+05 == 262144 (row length); 9.99999997E-7 is the
  // epsilon added to the variance before rsqrt.
  %cst = arith.constant dense<0.000000e+00> : vector<4096xf32>
  %cst_0 = arith.constant dense<9.99999997E-7> : vector<1xf32>
  %cst_1 = arith.constant dense<2.621440e+05> : vector<1xf32>
  %cst_2 = arith.constant dense<2.621440e+05> : vector<4096xf32>
  %cst_3 = arith.constant 0.000000e+00 : f16
  %cst_4 = arith.constant dense<0.000000e+00> : vector<1xf32>
  %c262144 = arith.constant 262144 : index
  %c16384 = arith.constant 16384 : index
  %c16 = arith.constant 16 : index
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c4096 = arith.constant 4096 : index
  // Two i32 push constants, cast to index and used as the `offset` of the
  // binding subspans below.
  %0 = hal.interface.constant.load[0] : i32
  %1 = hal.interface.constant.load[1] : i32
  %2 = arith.index_castui %0 : i32 to index
  %3 = arith.index_castui %1 : i32 to index
  // %4 and %5 view the same binding(0) at the same offset %2: once as
  // 32x16x16384 (used by the two reduction passes) and once as 32x262144
  // (used by the final elementwise pass); 16 * 16384 == 262144.
  %4 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%2) flags(ReadOnly) : memref<32x16x16384xf16, strided<[262144, 16384, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
  memref.assume_alignment %4, 1 : memref<32x16x16384xf16, strided<[262144, 16384, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
  %5 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%2) flags(ReadOnly) : memref<32x262144xf16, strided<[262144, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
  memref.assume_alignment %5, 1 : memref<32x262144xf16, strided<[262144, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
  // Output: binding(1), f32, same 32x262144 shape.
  %6 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%3) : memref<32x262144xf32, strided<[262144, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
  memref.assume_alignment %6, 1 : memref<32x262144xf32, strided<[262144, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
  // Row selected by this workgroup.
  %workgroup_id_x = hal.interface.workgroup.id[0] : index
  // Pass 1: accumulate the row (16 x 4 chunks of 4096 f16, extended to f32)
  // into a 4096-wide f32 partial sum.
  %7 = scf.for %arg0 = %c0 to %c16 step %c1 iter_args(%arg1 = %cst) -> (vector<4096xf32>) {
    %19 = scf.for %arg2 = %c0 to %c16384 step %c4096 iter_args(%arg3 = %arg1) -> (vector<4096xf32>) {
      %20 = vector.transfer_read %4[%workgroup_id_x, %arg0, %arg2], %cst_3 {in_bounds = [true]} : memref<32x16x16384xf16, strided<[262144, 16384, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<4096xf16>
      %21 = arith.extf %20 : vector<4096xf16> to vector<4096xf32>
      %22 = arith.addf %21, %arg3 : vector<4096xf32>
      scf.yield %22 : vector<4096xf32>
    }
    scf.yield %19 : vector<4096xf32>
  }
  // Reduce the 4096 partial sums to a single-element vector: the row sum.
  %8 = vector.broadcast %7 : vector<4096xf32> to vector<1x1x4096xf32>
  %9 = vector.multi_reduction <add>, %8, %cst_4 [1, 2] : vector<1x1x4096xf32> to vector<1xf32>
  // mean = sum / 262144, broadcast back to 4096 lanes for pass 2.
  %10 = vector.broadcast %9 : vector<1xf32> to vector<4096xf32>
  %11 = arith.divf %10, %cst_2 : vector<4096xf32>
  // Pass 2: accumulate squared deviations (x - mean)^2 over the same row.
  %12 = scf.for %arg0 = %c0 to %c16 step %c1 iter_args(%arg1 = %cst) -> (vector<4096xf32>) {
    %19 = scf.for %arg2 = %c0 to %c16384 step %c4096 iter_args(%arg3 = %arg1) -> (vector<4096xf32>) {
      %20 = vector.transfer_read %4[%workgroup_id_x, %arg0, %arg2], %cst_3 {in_bounds = [true]} : memref<32x16x16384xf16, strided<[262144, 16384, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<4096xf16>
      %21 = arith.extf %20 : vector<4096xf16> to vector<4096xf32>
      %22 = arith.subf %21, %11 : vector<4096xf32>
      %23 = arith.mulf %22, %22 : vector<4096xf32>
      %24 = arith.addf %23, %arg3 : vector<4096xf32>
      scf.yield %24 : vector<4096xf32>
    }
    scf.yield %19 : vector<4096xf32>
  }
  // Reduce the squared deviations to a single-element vector.
  %13 = vector.broadcast %12 : vector<4096xf32> to vector<1x1x4096xf32>
  %14 = vector.multi_reduction <add>, %13, %cst_4 [1, 2] : vector<1x1x4096xf32> to vector<1xf32>
  // %15 = mean, %16 = variance, %18 = rsqrt(variance + eps).
  %15 = arith.divf %9, %cst_1 : vector<1xf32>
  %16 = arith.divf %14, %cst_1 : vector<1xf32>
  %17 = arith.addf %16, %cst_0 : vector<1xf32>
  %18 = math.rsqrt %17 : vector<1xf32>
  // Pass 3: one element per iteration (vector<1xf16>), 262144 iterations:
  // out = (x - mean) * rsqrt(variance + eps), stored as f32.
  scf.for %arg0 = %c0 to %c262144 step %c1 {
    %19 = vector.transfer_read %5[%workgroup_id_x, %arg0], %cst_3 {in_bounds = [true]} : memref<32x262144xf16, strided<[262144, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>, vector<1xf16>
    %20 = arith.extf %19 : vector<1xf16> to vector<1xf32>
    %21 = arith.subf %20, %15 : vector<1xf32>
    %22 = arith.mulf %21, %18 : vector<1xf32>
    vector.transfer_write %22, %6[%workgroup_id_x, %arg0] {in_bounds = [true]} : vector<1xf32>, memref<32x262144xf32, strided<[262144, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
  }
  return
}
```


https://github.com/llvm/llvm-project/pull/102952


More information about the Mlir-commits mailing list