[Mlir-commits] [mlir] [MLIR][OpenMP] Improve Generic-SPMD kernel detection (PR #137307)

Johannes Doerfert llvmlistbot at llvm.org
Tue May 20 12:09:49 PDT 2025


jdoerfert wrote:

> This other test that @Meinersbur made, however, does show another case where running in Generic-SPMD mode is currently required in order to get the expected results:

Long story short, the result of O0 is correct and expected. 
The result you see with O1-O3 is also correct and expected.

What's going on:

First, note that `#pragma omp parallel` can choose fewer threads than the user requested, so 1 is always a valid option (that holds up to the introduction of the `strict` modifier, which we don't implement yet).
Next, a few design choices:

- When we run in Generic-mode we need an extra warp for the main thread. We count that warp against the thread limit, which one could argue we shouldn't. Because we do, a thread limit of 10 doesn't leave room for 2 warps, so we end up with the single "main thread warp" and no workers at all. We could lift the thread limit to 2x WarpSize, but that alone won't make a difference.
- [We choose not to run generic-mode parallel regions with partial warps](https://github.com/llvm/llvm-project/blob/7e9d9dba9cabb4cd840e5f34a07729f8fdc2112e/offload/DeviceRTL/src/Parallelism.cpp#L56). Honestly, I don't remember why, or whether I introduced this originally. My best guess is that it makes the Generic-mode barriers simpler/doable. One could look into this choice.
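
To make the interplay of these two choices concrete, here is a rough model of the thread-count arithmetic. It is a sketch of the reasoning above, not the DeviceRTL code: the function name `generic_mode_parallel_threads`, its parameters, and the convention that 0 means "no `num_threads` clause" are all made up for illustration, and the warp size is passed in rather than queried.

```c
/* Hypothetical model of the two design choices described above.
 * Not the actual DeviceRTL implementation; names are illustrative. */
static unsigned generic_mode_parallel_threads(unsigned thread_limit,
                                              unsigned num_threads_clause,
                                              unsigned warp_size) {
  /* Choice 1: one warp is reserved for the main thread and is counted
   * against the thread limit, so only the remainder can be workers. */
  unsigned workers = thread_limit > warp_size ? thread_limit - warp_size : 0;

  /* A smaller num_threads request lowers the count further
   * (0 stands for "no clause given" in this sketch). */
  if (num_threads_clause != 0 && num_threads_clause < workers)
    workers = num_threads_clause;

  /* Choice 2: no partial warps. Anything below one full warp falls
   * back to a single-thread parallel region, which is valid OpenMP. */
  if (workers < warp_size)
    return 1;

  /* Otherwise round down to a whole number of warps. */
  return workers - workers % warp_size;
}
```

Assuming a warp size of 32, `generic_mode_parallel_threads(10, 10, 32)` gives 1 because no worker warp fits under the limit, and `generic_mode_parallel_threads(64, 10, 32)` still gives 1 because 10 threads would be a partial warp. Only something like `generic_mode_parallel_threads(128, 64, 32)` yields a multi-threaded region (64 threads in this model).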

So, for those 2 reasons we see single-thread parallel regions in the example. Again, that's perfectly valid OpenMP. If you add `omp for` statements, you can see that we properly "workshare" among the available threads (1 or more):

```
#include <omp.h>
#include <stdio.h>

int main() {
  int i, j, a = 0, b = 0, c = 0, g = 21;

#pragma omp target teams distribute thread_limit(128) private(i, j)            \
    reduction(+ : a, b, c, g)
  for (i = 1; i <= 10; ++i) {
    j = i;
    if (j == 5) {
      g += 10 * omp_get_team_num() + omp_get_thread_num();
      ++c;
      j = 11;
    }
    if (j == 11) {
#pragma omp parallel num_threads(64) reduction(+ : a)
#ifdef WS1
#pragma omp for
      for (int k = 0; k < 10; ++k)
#endif
        ++a;
    } else {
#pragma omp parallel num_threads(10) reduction(+ : b)
#ifdef WS2
#pragma omp for
      for (int k = 0; k < 10; ++k)
#endif
        ++b;
    }
  }

  printf("a: %d\nb: %d\nc: %d\ng: %d", a, b, c, g);
  return 0;
}
```

The fun part is that I run into a hang when WS2 is enabled; see #140786 for more information.

https://github.com/llvm/llvm-project/pull/137307

