[Mlir-commits] [mlir] [MLIR][OpenMP] Improve Generic-SPMD kernel detection (PR #137307)
Sergio Afonso
llvmlistbot at llvm.org
Wed Apr 30 04:40:16 PDT 2025
skatrak wrote:
I don't think I explained the cases this patch covers very well. Having a `parallel` construct inside a loop or conditional block is already accepted as part of the Generic-SPMD pattern (`checkSingleMandatoryExec=false` on that call to `findCapturedOmpOp`). What this patch adds is support for multiple consecutive `parallel` constructs or, more generally, multiple OpenMP constructs within the region.
I absolutely agree with you that this seems counterintuitive. But I think the main point here is that we're not tagging these kernels as "SPMD", but rather as "Generic-SPMD", in contrast to plain "Generic", which is how they are currently tagged. The reason is that, in practice, a `parallel` region that appears inside a Generic kernel doesn't seem to run properly. The following tests show some cases that don't work without this patch:
```f90
! This works (it's already tagged Generic-SPMD):
! condition=true: 1 1 1 1 1 1
! condition=false: 2 2 2 2 2 2
subroutine if_cond_single(condition)
  implicit none
  logical, intent(in) :: condition
  integer, parameter :: M = 2, N = 3
  integer :: i, j
  integer :: v(M,N)

  v(:,:) = 0

  !$omp target teams distribute
  do i=1, M
    if (condition) then
      !$omp parallel do
      do j=1, N
        v(i, j) = v(i, j) + 1
      end do
    else
      do j=1, N
        v(i, j) = v(i, j) + 2
      end do
    end if
  end do

  print *, v(:,:)
end subroutine

! This doesn't work without this patch:
! condition=true: 0 0 0 0 0 0
! condition=false: 0 0 0 0 0 0
subroutine if_cond_multiple(condition)
  implicit none
  logical, intent(in) :: condition
  integer, parameter :: M = 2, N = 3
  integer :: i, j
  integer :: v(M,N)

  v(:,:) = 0

  !$omp target teams distribute
  do i=1, M
    if (condition) then
      !$omp parallel do
      do j=1, N
        v(i, j) = v(i, j) + 1
      end do
    else
      !$omp parallel do
      do j=1, N
        v(i, j) = v(i, j) + 2
      end do
    end if
  end do

  print *, v(:,:)
end subroutine

! This works (it's already tagged Generic-SPMD):
! 3 3 2 2 2 2
subroutine single_parallel()
  implicit none
  integer, parameter :: M = 2, N = 3
  integer :: i, j
  integer :: v(M,N)

  v(:,:) = 0

  !$omp target teams distribute
  do i=1, M
    !$omp parallel do
    do j=1, N
      v(i, j) = v(i, j) + 1
    end do

    v(i, 1) = v(i, 1) + 1

    do j=1, N
      v(i, j) = v(i, j) + 1
    end do
  end do

  print *, v(:,:)
end subroutine

! This doesn't work without this patch:
! 1 1 0 0 0 0
subroutine multi_parallel()
  implicit none
  integer, parameter :: M = 2, N = 3
  integer :: i, j
  integer :: v(M,N)

  v(:,:) = 0

  !$omp target teams distribute
  do i=1, M
    !$omp parallel do
    do j=1, N
      v(i, j) = v(i, j) + 1
    end do

    v(i, 1) = v(i, 1) + 1

    !$omp parallel do
    do j=1, N
      v(i, j) = v(i, j) + 1
    end do
  end do

  print *, v(:,:)
end subroutine
```
I'm no expert on the exact uses of Generic-SPMD, but making it mean roughly "a `target teams distribute` construct with at least one `parallel` region inside" is what has so far made most of the applications and tests we're looking at work. SPMD is only used for `target teams distribute parallel do` composite constructs, and Generic is everything else. I'm sure we'll have to tune this detection further, but I believe this change doesn't break anything we knew to be working already, and it does make other cases work.
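To make that rule of thumb concrete, here is a toy model of the three-way classification, written in Python purely for illustration. It is not the MLIR implementation: `Node`, `contains_parallel` and `classify` are hypothetical names, and the real detection goes through `findCapturedOmpOp` on actual operations rather than on string labels.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class Node:
    """Hypothetical stand-in for a construct in the kernel (not an MLIR class)."""
    kind: str  # e.g. "target teams distribute", "parallel do", "if", "loop"
    children: List["Node"] = field(default_factory=list)


def contains_parallel(node: Node) -> bool:
    """True if any construct nested under `node`, at any depth, is a `parallel`."""
    return any(
        child.kind.startswith("parallel") or contains_parallel(child)
        for child in node.children
    )


def classify(kernel: Node) -> str:
    """Rough rule described above; an assumption-laden sketch, not the real check."""
    if kernel.kind == "target teams distribute parallel do":
        return "SPMD"
    if kernel.kind == "target teams distribute" and contains_parallel(kernel):
        return "Generic-SPMD"
    return "Generic"


# Shapes loosely mirroring two of the Fortran tests in this message.
if_cond_multiple = Node(
    "target teams distribute",
    [Node("if", [Node("parallel do"), Node("parallel do")])],
)
multi_parallel = Node(
    "target teams distribute",
    [Node("parallel do"), Node("parallel do")],
)

print(classify(if_cond_multiple))
print(classify(multi_parallel))
```

Under this model the number and position of nested `parallel` regions does not matter, only whether at least one exists, which is the behavior the patch is aiming for.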
https://github.com/llvm/llvm-project/pull/137307