[Mlir-commits] [mlir] [MLIR][GPU-LLVM] Add in-pass signature update option for opencl kernels (PR #105664)
Petr Kurapov
llvmlistbot at llvm.org
Thu Sep 19 09:07:07 PDT 2024
kurapov-peter wrote:
@victor-eds, how about commit a1666d65a95d639deca7e0514076b6cca025027f? It removes the external patterns and prevents the `gpu.address_space` attribute from surviving the lowering.
Here's what it looks like in an example. Input:
```mlir
gpu.module @kernels {
gpu.func @no_address_spaces_complex(%arg0: memref<2x2xf32>, %arg1: memref<4xf32>) kernel {
func.call @no_address_spaces_callee(%arg0, %arg1) : (memref<2x2xf32>, memref<4xf32>) -> ()
gpu.return
}
func.func @no_address_spaces_callee(%arg0: memref<2x2xf32>, %arg1: memref<4xf32>) {
%block_id = gpu.block_id x
%0 = memref.load %arg0[%block_id, %block_id] : memref<2x2xf32>
memref.store %0, %arg1[%block_id] : memref<4xf32>
func.return
}
}
```
After running the pass, we get:
```mlir
module {
gpu.module @kernels {
llvm.func spir_funccc @_Z12get_group_idj(i32) -> i64 attributes {memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>, no_unwind, will_return}
llvm.func spir_kernelcc @no_address_spaces_complex(%arg0: !llvm.ptr<1>, %arg1: !llvm.ptr<1>, %arg2: i64, %arg3: i64, %arg4: i64, %arg5: i64, %arg6: i64, %arg7: !llvm.ptr<1>, %arg8: !llvm.ptr<1>, %arg9: i64, %arg10: i64, %arg11: i64) attributes {gpu.kernel} {
%0 = llvm.mlir.undef : !llvm.struct<(ptr<1>, ptr<1>, i64, array<1 x i64>, array<1 x i64>)>
%1 = llvm.insertvalue %arg7, %0[0] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<1 x i64>, array<1 x i64>)>
%2 = llvm.insertvalue %arg8, %1[1] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<1 x i64>, array<1 x i64>)>
%3 = llvm.insertvalue %arg9, %2[2] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<1 x i64>, array<1 x i64>)>
%4 = llvm.insertvalue %arg10, %3[3, 0] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<1 x i64>, array<1 x i64>)>
%5 = llvm.insertvalue %arg11, %4[4, 0] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<1 x i64>, array<1 x i64>)>
%6 = builtin.unrealized_conversion_cast %5 : !llvm.struct<(ptr<1>, ptr<1>, i64, array<1 x i64>, array<1 x i64>)> to memref<4xf32, 1>
%7 = builtin.unrealized_conversion_cast %6 : memref<4xf32, 1> to !llvm.struct<(ptr<1>, ptr<1>, i64, array<1 x i64>, array<1 x i64>)>
%8 = llvm.mlir.undef : !llvm.struct<(ptr<1>, ptr<1>, i64, array<2 x i64>, array<2 x i64>)>
%9 = llvm.insertvalue %arg0, %8[0] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<2 x i64>, array<2 x i64>)>
%10 = llvm.insertvalue %arg1, %9[1] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<2 x i64>, array<2 x i64>)>
%11 = llvm.insertvalue %arg2, %10[2] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<2 x i64>, array<2 x i64>)>
%12 = llvm.insertvalue %arg3, %11[3, 0] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<2 x i64>, array<2 x i64>)>
%13 = llvm.insertvalue %arg5, %12[4, 0] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<2 x i64>, array<2 x i64>)>
%14 = llvm.insertvalue %arg4, %13[3, 1] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<2 x i64>, array<2 x i64>)>
%15 = llvm.insertvalue %arg6, %14[4, 1] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<2 x i64>, array<2 x i64>)>
%16 = builtin.unrealized_conversion_cast %15 : !llvm.struct<(ptr<1>, ptr<1>, i64, array<2 x i64>, array<2 x i64>)> to memref<2x2xf32, 1>
%17 = builtin.unrealized_conversion_cast %16 : memref<2x2xf32, 1> to !llvm.struct<(ptr<1>, ptr<1>, i64, array<2 x i64>, array<2 x i64>)>
func.call @no_address_spaces_callee(%16, %6) : (memref<2x2xf32, 1>, memref<4xf32, 1>) -> ()
llvm.return
}
func.func @no_address_spaces_callee(%arg0: memref<2x2xf32, 1>, %arg1: memref<4xf32, 1>) {
%0 = llvm.mlir.constant(0 : i32) : i32
%1 = llvm.call spir_funccc @_Z12get_group_idj(%0) {memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>, no_unwind, will_return} : (i32) -> i64
%2 = builtin.unrealized_conversion_cast %1 : i64 to index
%3 = memref.load %arg0[%2, %2] : memref<2x2xf32, 1>
memref.store %3, %arg1[%2] : memref<4xf32, 1>
return
}
}
}
```
So there are no GPU ops left, and the address spaces are correct, although they are expressed as plain integer memory-space attributes (e.g. `memref<4xf32, 1>`).
If I lower that further to LLVM, all the unrealized casts get resolved (`mlir-opt -pass-pipeline="builtin.module(gpu.module(convert-gpu-to-llvm-spv),convert-func-to-llvm,finalize-memref-to-llvm,reconcile-unrealized-casts)"`):
```mlir
module {
gpu.module @kernels {
llvm.func spir_funccc @_Z12get_group_idj(i32) -> i64 attributes {memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>, no_unwind, will_return}
llvm.func spir_kernelcc @no_address_spaces_complex(%arg0: !llvm.ptr<1>, %arg1: !llvm.ptr<1>, %arg2: i64, %arg3: i64, %arg4: i64, %arg5: i64, %arg6: i64, %arg7: !llvm.ptr<1>, %arg8: !llvm.ptr<1>, %arg9: i64, %arg10: i64, %arg11: i64) attributes {gpu.kernel} {
%0 = llvm.mlir.undef : !llvm.struct<(ptr<1>, ptr<1>, i64, array<1 x i64>, array<1 x i64>)>
%1 = llvm.insertvalue %arg7, %0[0] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<1 x i64>, array<1 x i64>)>
%2 = llvm.insertvalue %arg8, %1[1] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<1 x i64>, array<1 x i64>)>
%3 = llvm.insertvalue %arg9, %2[2] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<1 x i64>, array<1 x i64>)>
%4 = llvm.insertvalue %arg10, %3[3, 0] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<1 x i64>, array<1 x i64>)>
%5 = llvm.insertvalue %arg11, %4[4, 0] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<1 x i64>, array<1 x i64>)>
%6 = llvm.mlir.undef : !llvm.struct<(ptr<1>, ptr<1>, i64, array<2 x i64>, array<2 x i64>)>
%7 = llvm.insertvalue %arg0, %6[0] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<2 x i64>, array<2 x i64>)>
%8 = llvm.insertvalue %arg1, %7[1] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<2 x i64>, array<2 x i64>)>
%9 = llvm.insertvalue %arg2, %8[2] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<2 x i64>, array<2 x i64>)>
%10 = llvm.insertvalue %arg3, %9[3, 0] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<2 x i64>, array<2 x i64>)>
%11 = llvm.insertvalue %arg5, %10[4, 0] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<2 x i64>, array<2 x i64>)>
%12 = llvm.insertvalue %arg4, %11[3, 1] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<2 x i64>, array<2 x i64>)>
%13 = llvm.insertvalue %arg6, %12[4, 1] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<2 x i64>, array<2 x i64>)>
%14 = llvm.extractvalue %13[0] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<2 x i64>, array<2 x i64>)>
%15 = llvm.extractvalue %13[1] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<2 x i64>, array<2 x i64>)>
%16 = llvm.extractvalue %13[2] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<2 x i64>, array<2 x i64>)>
%17 = llvm.extractvalue %13[3, 0] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<2 x i64>, array<2 x i64>)>
%18 = llvm.extractvalue %13[3, 1] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<2 x i64>, array<2 x i64>)>
%19 = llvm.extractvalue %13[4, 0] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<2 x i64>, array<2 x i64>)>
%20 = llvm.extractvalue %13[4, 1] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<2 x i64>, array<2 x i64>)>
%21 = llvm.extractvalue %5[0] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<1 x i64>, array<1 x i64>)>
%22 = llvm.extractvalue %5[1] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<1 x i64>, array<1 x i64>)>
%23 = llvm.extractvalue %5[2] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<1 x i64>, array<1 x i64>)>
%24 = llvm.extractvalue %5[3, 0] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<1 x i64>, array<1 x i64>)>
%25 = llvm.extractvalue %5[4, 0] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<1 x i64>, array<1 x i64>)>
llvm.call @no_address_spaces_callee(%14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25) : (!llvm.ptr<1>, !llvm.ptr<1>, i64, i64, i64, i64, i64, !llvm.ptr<1>, !llvm.ptr<1>, i64, i64, i64) -> ()
llvm.return
}
llvm.func @no_address_spaces_callee(%arg0: !llvm.ptr<1>, %arg1: !llvm.ptr<1>, %arg2: i64, %arg3: i64, %arg4: i64, %arg5: i64, %arg6: i64, %arg7: !llvm.ptr<1>, %arg8: !llvm.ptr<1>, %arg9: i64, %arg10: i64, %arg11: i64) {
%0 = llvm.mlir.undef : !llvm.struct<(ptr<1>, ptr<1>, i64, array<1 x i64>, array<1 x i64>)>
%1 = llvm.insertvalue %arg7, %0[0] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<1 x i64>, array<1 x i64>)>
%2 = llvm.insertvalue %arg8, %1[1] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<1 x i64>, array<1 x i64>)>
%3 = llvm.insertvalue %arg9, %2[2] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<1 x i64>, array<1 x i64>)>
%4 = llvm.insertvalue %arg10, %3[3, 0] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<1 x i64>, array<1 x i64>)>
%5 = llvm.insertvalue %arg11, %4[4, 0] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<1 x i64>, array<1 x i64>)>
%6 = llvm.mlir.undef : !llvm.struct<(ptr<1>, ptr<1>, i64, array<2 x i64>, array<2 x i64>)>
%7 = llvm.insertvalue %arg0, %6[0] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<2 x i64>, array<2 x i64>)>
%8 = llvm.insertvalue %arg1, %7[1] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<2 x i64>, array<2 x i64>)>
%9 = llvm.insertvalue %arg2, %8[2] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<2 x i64>, array<2 x i64>)>
%10 = llvm.insertvalue %arg3, %9[3, 0] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<2 x i64>, array<2 x i64>)>
%11 = llvm.insertvalue %arg5, %10[4, 0] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<2 x i64>, array<2 x i64>)>
%12 = llvm.insertvalue %arg4, %11[3, 1] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<2 x i64>, array<2 x i64>)>
%13 = llvm.insertvalue %arg6, %12[4, 1] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<2 x i64>, array<2 x i64>)>
%14 = llvm.mlir.constant(0 : i32) : i32
%15 = llvm.call spir_funccc @_Z12get_group_idj(%14) {memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>, no_unwind, will_return} : (i32) -> i64
%16 = llvm.extractvalue %13[1] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<2 x i64>, array<2 x i64>)>
%17 = llvm.mlir.constant(2 : index) : i64
%18 = llvm.mul %15, %17 : i64
%19 = llvm.add %18, %15 : i64
%20 = llvm.getelementptr %16[%19] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32
%21 = llvm.load %20 : !llvm.ptr<1> -> f32
%22 = llvm.extractvalue %5[1] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<1 x i64>, array<1 x i64>)>
%23 = llvm.getelementptr %22[%15] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32
llvm.store %21, %23 : f32, !llvm.ptr<1>
llvm.return
}
}
}
```
So now we can use the default passes without requiring any additional converters.
And if I also add canonicalization, we get a nice, simple kernel:
```mlir
module {
gpu.module @kernels {
llvm.func spir_funccc @_Z12get_group_idj(i32) -> i64 attributes {memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>, no_unwind, will_return}
llvm.func spir_kernelcc @no_address_spaces_complex(%arg0: !llvm.ptr<1>, %arg1: !llvm.ptr<1>, %arg2: i64, %arg3: i64, %arg4: i64, %arg5: i64, %arg6: i64, %arg7: !llvm.ptr<1>, %arg8: !llvm.ptr<1>, %arg9: i64, %arg10: i64, %arg11: i64) attributes {gpu.kernel} {
llvm.call @no_address_spaces_callee(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11) : (!llvm.ptr<1>, !llvm.ptr<1>, i64, i64, i64, i64, i64, !llvm.ptr<1>, !llvm.ptr<1>, i64, i64, i64) -> ()
llvm.return
}
llvm.func @no_address_spaces_callee(%arg0: !llvm.ptr<1>, %arg1: !llvm.ptr<1>, %arg2: i64, %arg3: i64, %arg4: i64, %arg5: i64, %arg6: i64, %arg7: !llvm.ptr<1>, %arg8: !llvm.ptr<1>, %arg9: i64, %arg10: i64, %arg11: i64) {
%0 = llvm.mlir.constant(2 : index) : i64
%1 = llvm.mlir.constant(0 : i32) : i32
%2 = llvm.call spir_funccc @_Z12get_group_idj(%1) {memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>, no_unwind, will_return} : (i32) -> i64
%3 = llvm.mul %2, %0 : i64
%4 = llvm.add %3, %2 : i64
%5 = llvm.getelementptr %arg1[%4] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32
%6 = llvm.load %5 : !llvm.ptr<1> -> f32
%7 = llvm.getelementptr %arg8[%2] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32
llvm.store %6, %7 : f32, !llvm.ptr<1>
llvm.return
}
}
}
```
https://github.com/llvm/llvm-project/pull/105664
More information about the Mlir-commits
mailing list