[Mlir-commits] [mlir] [MLIR][GPU-LLVM] Add in-pass signature update option for opencl kernels (PR #105664)

Petr Kurapov llvmlistbot at llvm.org
Thu Sep 19 09:07:07 PDT 2024


kurapov-peter wrote:

@victor-eds, how about this commit, a1666d65a95d639deca7e0514076b6cca025027f, which removes the external patterns and prevents the `gpu.address_space` attribute from surviving the lowering?

Here's what it looks like in an example. Input:
```mlir
gpu.module @kernels {
  gpu.func @no_address_spaces_complex(%arg0: memref<2x2xf32>, %arg1: memref<4xf32>) kernel {
    func.call @no_address_spaces_callee(%arg0, %arg1) : (memref<2x2xf32>, memref<4xf32>) -> ()
    gpu.return
  }

  func.func @no_address_spaces_callee(%arg0: memref<2x2xf32>, %arg1: memref<4xf32>) {
    %block_id = gpu.block_id x
    %0 = memref.load %arg0[%block_id, %block_id] : memref<2x2xf32>
    memref.store %0, %arg1[%block_id] : memref<4xf32>
    func.return
  }
}
```

After running the pass we get:
```mlir
module {
  gpu.module @kernels {
    llvm.func spir_funccc @_Z12get_group_idj(i32) -> i64 attributes {memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>, no_unwind, will_return}
    llvm.func spir_kernelcc @no_address_spaces_complex(%arg0: !llvm.ptr<1>, %arg1: !llvm.ptr<1>, %arg2: i64, %arg3: i64, %arg4: i64, %arg5: i64, %arg6: i64, %arg7: !llvm.ptr<1>, %arg8: !llvm.ptr<1>, %arg9: i64, %arg10: i64, %arg11: i64) attributes {gpu.kernel} {
      %0 = llvm.mlir.undef : !llvm.struct<(ptr<1>, ptr<1>, i64, array<1 x i64>, array<1 x i64>)>
      %1 = llvm.insertvalue %arg7, %0[0] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<1 x i64>, array<1 x i64>)> 
      %2 = llvm.insertvalue %arg8, %1[1] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<1 x i64>, array<1 x i64>)> 
      %3 = llvm.insertvalue %arg9, %2[2] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<1 x i64>, array<1 x i64>)> 
      %4 = llvm.insertvalue %arg10, %3[3, 0] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<1 x i64>, array<1 x i64>)> 
      %5 = llvm.insertvalue %arg11, %4[4, 0] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<1 x i64>, array<1 x i64>)> 
      %6 = builtin.unrealized_conversion_cast %5 : !llvm.struct<(ptr<1>, ptr<1>, i64, array<1 x i64>, array<1 x i64>)> to memref<4xf32, 1>
      %7 = builtin.unrealized_conversion_cast %6 : memref<4xf32, 1> to !llvm.struct<(ptr<1>, ptr<1>, i64, array<1 x i64>, array<1 x i64>)>
      %8 = llvm.mlir.undef : !llvm.struct<(ptr<1>, ptr<1>, i64, array<2 x i64>, array<2 x i64>)>
      %9 = llvm.insertvalue %arg0, %8[0] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<2 x i64>, array<2 x i64>)> 
      %10 = llvm.insertvalue %arg1, %9[1] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<2 x i64>, array<2 x i64>)> 
      %11 = llvm.insertvalue %arg2, %10[2] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<2 x i64>, array<2 x i64>)> 
      %12 = llvm.insertvalue %arg3, %11[3, 0] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<2 x i64>, array<2 x i64>)> 
      %13 = llvm.insertvalue %arg5, %12[4, 0] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<2 x i64>, array<2 x i64>)> 
      %14 = llvm.insertvalue %arg4, %13[3, 1] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<2 x i64>, array<2 x i64>)> 
      %15 = llvm.insertvalue %arg6, %14[4, 1] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<2 x i64>, array<2 x i64>)> 
      %16 = builtin.unrealized_conversion_cast %15 : !llvm.struct<(ptr<1>, ptr<1>, i64, array<2 x i64>, array<2 x i64>)> to memref<2x2xf32, 1>
      %17 = builtin.unrealized_conversion_cast %16 : memref<2x2xf32, 1> to !llvm.struct<(ptr<1>, ptr<1>, i64, array<2 x i64>, array<2 x i64>)>
      func.call @no_address_spaces_callee(%16, %6) : (memref<2x2xf32, 1>, memref<4xf32, 1>) -> ()
      llvm.return
    }
    func.func @no_address_spaces_callee(%arg0: memref<2x2xf32, 1>, %arg1: memref<4xf32, 1>) {
      %0 = llvm.mlir.constant(0 : i32) : i32
      %1 = llvm.call spir_funccc @_Z12get_group_idj(%0) {memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>, no_unwind, will_return} : (i32) -> i64
      %2 = builtin.unrealized_conversion_cast %1 : i64 to index
      %3 = memref.load %arg0[%2, %2] : memref<2x2xf32, 1>
      memref.store %3, %arg1[%2] : memref<4xf32, 1>
      return
    }
  }
}
```
So, there are no gpu ops left and the address spaces are correct, although they are expressed as plain integer attributes.

If I lower that further to LLVM, all the unrealized casts get resolved (`mlir-opt -pass-pipeline="builtin.module(gpu.module(convert-gpu-to-llvm-spv),convert-func-to-llvm,finalize-memref-to-llvm,reconcile-unrealized-casts)"`):
```mlir
module {
  gpu.module @kernels {
    llvm.func spir_funccc @_Z12get_group_idj(i32) -> i64 attributes {memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>, no_unwind, will_return}
    llvm.func spir_kernelcc @no_address_spaces_complex(%arg0: !llvm.ptr<1>, %arg1: !llvm.ptr<1>, %arg2: i64, %arg3: i64, %arg4: i64, %arg5: i64, %arg6: i64, %arg7: !llvm.ptr<1>, %arg8: !llvm.ptr<1>, %arg9: i64, %arg10: i64, %arg11: i64) attributes {gpu.kernel} {
      %0 = llvm.mlir.undef : !llvm.struct<(ptr<1>, ptr<1>, i64, array<1 x i64>, array<1 x i64>)>
      %1 = llvm.insertvalue %arg7, %0[0] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<1 x i64>, array<1 x i64>)> 
      %2 = llvm.insertvalue %arg8, %1[1] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<1 x i64>, array<1 x i64>)> 
      %3 = llvm.insertvalue %arg9, %2[2] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<1 x i64>, array<1 x i64>)> 
      %4 = llvm.insertvalue %arg10, %3[3, 0] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<1 x i64>, array<1 x i64>)> 
      %5 = llvm.insertvalue %arg11, %4[4, 0] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<1 x i64>, array<1 x i64>)> 
      %6 = llvm.mlir.undef : !llvm.struct<(ptr<1>, ptr<1>, i64, array<2 x i64>, array<2 x i64>)>
      %7 = llvm.insertvalue %arg0, %6[0] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<2 x i64>, array<2 x i64>)> 
      %8 = llvm.insertvalue %arg1, %7[1] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<2 x i64>, array<2 x i64>)> 
      %9 = llvm.insertvalue %arg2, %8[2] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<2 x i64>, array<2 x i64>)> 
      %10 = llvm.insertvalue %arg3, %9[3, 0] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<2 x i64>, array<2 x i64>)> 
      %11 = llvm.insertvalue %arg5, %10[4, 0] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<2 x i64>, array<2 x i64>)> 
      %12 = llvm.insertvalue %arg4, %11[3, 1] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<2 x i64>, array<2 x i64>)> 
      %13 = llvm.insertvalue %arg6, %12[4, 1] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<2 x i64>, array<2 x i64>)> 
      %14 = llvm.extractvalue %13[0] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<2 x i64>, array<2 x i64>)> 
      %15 = llvm.extractvalue %13[1] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<2 x i64>, array<2 x i64>)> 
      %16 = llvm.extractvalue %13[2] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<2 x i64>, array<2 x i64>)> 
      %17 = llvm.extractvalue %13[3, 0] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<2 x i64>, array<2 x i64>)> 
      %18 = llvm.extractvalue %13[3, 1] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<2 x i64>, array<2 x i64>)> 
      %19 = llvm.extractvalue %13[4, 0] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<2 x i64>, array<2 x i64>)> 
      %20 = llvm.extractvalue %13[4, 1] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<2 x i64>, array<2 x i64>)> 
      %21 = llvm.extractvalue %5[0] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<1 x i64>, array<1 x i64>)> 
      %22 = llvm.extractvalue %5[1] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<1 x i64>, array<1 x i64>)> 
      %23 = llvm.extractvalue %5[2] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<1 x i64>, array<1 x i64>)> 
      %24 = llvm.extractvalue %5[3, 0] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<1 x i64>, array<1 x i64>)> 
      %25 = llvm.extractvalue %5[4, 0] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<1 x i64>, array<1 x i64>)> 
      llvm.call @no_address_spaces_callee(%14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25) : (!llvm.ptr<1>, !llvm.ptr<1>, i64, i64, i64, i64, i64, !llvm.ptr<1>, !llvm.ptr<1>, i64, i64, i64) -> ()
      llvm.return
    }
    llvm.func @no_address_spaces_callee(%arg0: !llvm.ptr<1>, %arg1: !llvm.ptr<1>, %arg2: i64, %arg3: i64, %arg4: i64, %arg5: i64, %arg6: i64, %arg7: !llvm.ptr<1>, %arg8: !llvm.ptr<1>, %arg9: i64, %arg10: i64, %arg11: i64) {
      %0 = llvm.mlir.undef : !llvm.struct<(ptr<1>, ptr<1>, i64, array<1 x i64>, array<1 x i64>)>
      %1 = llvm.insertvalue %arg7, %0[0] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<1 x i64>, array<1 x i64>)> 
      %2 = llvm.insertvalue %arg8, %1[1] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<1 x i64>, array<1 x i64>)> 
      %3 = llvm.insertvalue %arg9, %2[2] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<1 x i64>, array<1 x i64>)> 
      %4 = llvm.insertvalue %arg10, %3[3, 0] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<1 x i64>, array<1 x i64>)> 
      %5 = llvm.insertvalue %arg11, %4[4, 0] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<1 x i64>, array<1 x i64>)> 
      %6 = llvm.mlir.undef : !llvm.struct<(ptr<1>, ptr<1>, i64, array<2 x i64>, array<2 x i64>)>
      %7 = llvm.insertvalue %arg0, %6[0] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<2 x i64>, array<2 x i64>)> 
      %8 = llvm.insertvalue %arg1, %7[1] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<2 x i64>, array<2 x i64>)> 
      %9 = llvm.insertvalue %arg2, %8[2] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<2 x i64>, array<2 x i64>)> 
      %10 = llvm.insertvalue %arg3, %9[3, 0] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<2 x i64>, array<2 x i64>)> 
      %11 = llvm.insertvalue %arg5, %10[4, 0] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<2 x i64>, array<2 x i64>)> 
      %12 = llvm.insertvalue %arg4, %11[3, 1] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<2 x i64>, array<2 x i64>)> 
      %13 = llvm.insertvalue %arg6, %12[4, 1] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<2 x i64>, array<2 x i64>)> 
      %14 = llvm.mlir.constant(0 : i32) : i32
      %15 = llvm.call spir_funccc @_Z12get_group_idj(%14) {memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>, no_unwind, will_return} : (i32) -> i64
      %16 = llvm.extractvalue %13[1] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<2 x i64>, array<2 x i64>)> 
      %17 = llvm.mlir.constant(2 : index) : i64
      %18 = llvm.mul %15, %17 : i64
      %19 = llvm.add %18, %15 : i64
      %20 = llvm.getelementptr %16[%19] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32
      %21 = llvm.load %20 : !llvm.ptr<1> -> f32
      %22 = llvm.extractvalue %5[1] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<1 x i64>, array<1 x i64>)> 
      %23 = llvm.getelementptr %22[%15] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32
      llvm.store %21, %23 : f32, !llvm.ptr<1>
      llvm.return
    }
  }
}
```
So now we can use default passes and not require any additional converters.

And if I add canonicalization we get a nice simple kernel:
```mlir
module {
  gpu.module @kernels {
    llvm.func spir_funccc @_Z12get_group_idj(i32) -> i64 attributes {memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>, no_unwind, will_return}
    llvm.func spir_kernelcc @no_address_spaces_complex(%arg0: !llvm.ptr<1>, %arg1: !llvm.ptr<1>, %arg2: i64, %arg3: i64, %arg4: i64, %arg5: i64, %arg6: i64, %arg7: !llvm.ptr<1>, %arg8: !llvm.ptr<1>, %arg9: i64, %arg10: i64, %arg11: i64) attributes {gpu.kernel} {
      llvm.call @no_address_spaces_callee(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11) : (!llvm.ptr<1>, !llvm.ptr<1>, i64, i64, i64, i64, i64, !llvm.ptr<1>, !llvm.ptr<1>, i64, i64, i64) -> ()
      llvm.return
    }
    llvm.func @no_address_spaces_callee(%arg0: !llvm.ptr<1>, %arg1: !llvm.ptr<1>, %arg2: i64, %arg3: i64, %arg4: i64, %arg5: i64, %arg6: i64, %arg7: !llvm.ptr<1>, %arg8: !llvm.ptr<1>, %arg9: i64, %arg10: i64, %arg11: i64) {
      %0 = llvm.mlir.constant(2 : index) : i64
      %1 = llvm.mlir.constant(0 : i32) : i32
      %2 = llvm.call spir_funccc @_Z12get_group_idj(%1) {memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>, no_unwind, will_return} : (i32) -> i64
      %3 = llvm.mul %2, %0 : i64
      %4 = llvm.add %3, %2 : i64
      %5 = llvm.getelementptr %arg1[%4] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32
      %6 = llvm.load %5 : !llvm.ptr<1> -> f32
      %7 = llvm.getelementptr %arg8[%2] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32
      llvm.store %6, %7 : f32, !llvm.ptr<1>
      llvm.return
    }
  }
}
```

https://github.com/llvm/llvm-project/pull/105664


More information about the Mlir-commits mailing list