[Mlir-commits] [mlir] [MLIR][OpenMP] Improve Generic-SPMD kernel detection (PR #137307)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Apr 25 03:27:39 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Sergio Afonso (skatrak)

<details>
<summary>Changes</summary>

The previous implementation assumed that, for a target region to be tagged as Generic-SPMD, it had to contain a single `teams distribute` loop with a single `parallel wsloop` nested in it. However, this set of conditions was overly restrictive and caused a range of kernels to behave incorrectly.

This patch updates the logic that identifies kernel execution flags so that a single `teams distribute` loop containing one or more nested `parallel` regions (possibly as part of conditional or loop control flow) is tagged as Generic-SPMD.

---
Full diff: https://github.com/llvm/llvm-project/pull/137307.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp (+37-51) 
- (modified) mlir/test/Target/LLVMIR/openmp-target-generic-spmd.mlir (+97-46) 


``````````diff
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index dd701da507fc6..3afb374381bdf 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -1954,7 +1954,7 @@ LogicalResult TargetOp::verifyRegions() {
 }
 
 static Operation *
-findCapturedOmpOp(Operation *rootOp, bool checkSingleMandatoryExec,
+findCapturedOmpOp(Operation *rootOp,
                   llvm::function_ref<bool(Operation *)> siblingAllowedFn) {
   assert(rootOp && "expected valid operation");
 
@@ -1982,19 +1982,17 @@ findCapturedOmpOp(Operation *rootOp, bool checkSingleMandatoryExec,
     // (i.e. its block's successors can reach it) or if it's not guaranteed to
     // be executed before all exits of the region (i.e. it doesn't dominate all
     // blocks with no successors reachable from the entry block).
-    if (checkSingleMandatoryExec) {
-      Region *parentRegion = op->getParentRegion();
-      Block *parentBlock = op->getBlock();
-
-      for (Block *successor : parentBlock->getSuccessors())
-        if (successor->isReachable(parentBlock))
-          return WalkResult::interrupt();
-
-      for (Block &block : *parentRegion)
-        if (domInfo.isReachableFromEntry(&block) && block.hasNoSuccessors() &&
-            !domInfo.dominates(parentBlock, &block))
-          return WalkResult::interrupt();
-    }
+    Region *parentRegion = op->getParentRegion();
+    Block *parentBlock = op->getBlock();
+
+    for (Block *successor : parentBlock->getSuccessors())
+      if (successor->isReachable(parentBlock))
+        return WalkResult::interrupt();
+
+    for (Block &block : *parentRegion)
+      if (domInfo.isReachableFromEntry(&block) && block.hasNoSuccessors() &&
+          !domInfo.dominates(parentBlock, &block))
+        return WalkResult::interrupt();
 
     // Don't capture this op if it has a not-allowed sibling, and stop recursing
     // into nested operations.
@@ -2017,27 +2015,25 @@ Operation *TargetOp::getInnermostCapturedOmpOp() {
 
   // Only allow OpenMP terminators and non-OpenMP ops that have known memory
   // effects, but don't include a memory write effect.
-  return findCapturedOmpOp(
-      *this, /*checkSingleMandatoryExec=*/true, [&](Operation *sibling) {
-        if (!sibling)
-          return false;
-
-        if (ompDialect == sibling->getDialect())
-          return sibling->hasTrait<OpTrait::IsTerminator>();
-
-        if (auto memOp = dyn_cast<MemoryEffectOpInterface>(sibling)) {
-          SmallVector<SideEffects::EffectInstance<MemoryEffects::Effect>, 4>
-              effects;
-          memOp.getEffects(effects);
-          return !llvm::any_of(
-              effects, [&](MemoryEffects::EffectInstance &effect) {
-                return isa<MemoryEffects::Write>(effect.getEffect()) &&
-                       isa<SideEffects::AutomaticAllocationScopeResource>(
-                           effect.getResource());
-              });
-        }
-        return true;
+  return findCapturedOmpOp(*this, [&](Operation *sibling) {
+    if (!sibling)
+      return false;
+
+    if (ompDialect == sibling->getDialect())
+      return sibling->hasTrait<OpTrait::IsTerminator>();
+
+    if (auto memOp = dyn_cast<MemoryEffectOpInterface>(sibling)) {
+      SmallVector<SideEffects::EffectInstance<MemoryEffects::Effect>, 4>
+          effects;
+      memOp.getEffects(effects);
+      return !llvm::any_of(effects, [&](MemoryEffects::EffectInstance &effect) {
+        return isa<MemoryEffects::Write>(effect.getEffect()) &&
+               isa<SideEffects::AutomaticAllocationScopeResource>(
+                   effect.getResource());
       });
+    }
+    return true;
+  });
 }
 
 TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) {
@@ -2098,33 +2094,23 @@ TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) {
     if (isa<LoopOp>(innermostWrapper))
       return TargetRegionFlags::spmd | TargetRegionFlags::trip_count;
 
-    // Find single immediately nested captured omp.parallel and add spmd flag
-    // (generic-spmd case).
+    // Add spmd flag if there's a nested omp.parallel (generic-spmd case).
     //
     // TODO: This shouldn't have to be done here, as it is too easy to break.
     // The openmp-opt pass should be updated to be able to promote kernels like
     // this from "Generic" to "Generic-SPMD". However, the use of the
     // `kmpc_distribute_static_loop` family of functions produced by the
     // OMPIRBuilder for these kernels prevents that from working.
-    Dialect *ompDialect = targetOp->getDialect();
-    Operation *nestedCapture = findCapturedOmpOp(
-        capturedOp, /*checkSingleMandatoryExec=*/false,
-        [&](Operation *sibling) {
-          return sibling && (ompDialect != sibling->getDialect() ||
-                             sibling->hasTrait<OpTrait::IsTerminator>());
-        });
+    bool hasParallel = capturedOp
+                           ->walk<WalkOrder::PreOrder>([](ParallelOp) {
+                             return WalkResult::interrupt();
+                           })
+                           .wasInterrupted();
 
     TargetRegionFlags result =
         TargetRegionFlags::generic | TargetRegionFlags::trip_count;
 
-    if (!nestedCapture)
-      return result;
-
-    while (nestedCapture->getParentOp() != capturedOp)
-      nestedCapture = nestedCapture->getParentOp();
-
-    return isa<ParallelOp>(nestedCapture) ? result | TargetRegionFlags::spmd
-                                          : result;
+    return hasParallel ? result | TargetRegionFlags::spmd : result;
   }
   // Detect target-parallel-wsloop[-simd].
   else if (isa<WsloopOp>(innermostWrapper)) {
diff --git a/mlir/test/Target/LLVMIR/openmp-target-generic-spmd.mlir b/mlir/test/Target/LLVMIR/openmp-target-generic-spmd.mlir
index 8101660e571e4..3273de0c26d27 100644
--- a/mlir/test/Target/LLVMIR/openmp-target-generic-spmd.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-target-generic-spmd.mlir
@@ -1,11 +1,7 @@
-// RUN: split-file %s %t
-// RUN: mlir-translate -mlir-to-llvmir %t/host.mlir | FileCheck %s --check-prefix=HOST
-// RUN: mlir-translate -mlir-to-llvmir %t/device.mlir | FileCheck %s --check-prefix=DEVICE
-
-//--- host.mlir
+// RUN: mlir-translate -mlir-to-llvmir -split-input-file %s | FileCheck %s
 
 module attributes {omp.is_target_device = false, omp.target_triples = ["amdgcn-amd-amdhsa"]} {
-  llvm.func @main(%arg0 : !llvm.ptr) {
+  llvm.func @host(%arg0 : !llvm.ptr) {
     %x = llvm.load %arg0 : !llvm.ptr -> i32
     %0 = omp.map.info var_ptr(%arg0 : !llvm.ptr, i32) map_clauses(to) capture(ByCopy) -> !llvm.ptr
     omp.target host_eval(%x -> %lb, %x -> %ub, %x -> %step : i32, i32, i32) map_entries(%0 -> %ptr : !llvm.ptr) {
@@ -32,36 +28,36 @@ module attributes {omp.is_target_device = false, omp.target_triples = ["amdgcn-a
   }
 }
 
-// HOST-LABEL: define void @main
-// HOST:         %omp_loop.tripcount = {{.*}}
-// HOST-NEXT:    br label %[[ENTRY:.*]]
-// HOST:       [[ENTRY]]:
-// HOST:         %[[TRIPCOUNT:.*]] = zext i32 %omp_loop.tripcount to i64
-// HOST:         %[[TRIPCOUNT_KARG:.*]] = getelementptr inbounds nuw %struct.__tgt_kernel_arguments, ptr %[[KARGS:.*]], i32 0, i32 8
-// HOST-NEXT:    store i64 %[[TRIPCOUNT]], ptr %[[TRIPCOUNT_KARG]]
-// HOST:         %[[RESULT:.*]] = call i32 @__tgt_target_kernel({{.*}}, ptr %[[KARGS]])
-// HOST-NEXT:    %[[CMP:.*]] = icmp ne i32 %[[RESULT]], 0
-// HOST-NEXT:    br i1 %[[CMP]], label %[[OFFLOAD_FAILED:.*]], label %{{.*}}
-// HOST:       [[OFFLOAD_FAILED]]:
-// HOST:         call void @[[TARGET_OUTLINE:.*]]({{.*}})
+// CHECK-LABEL: define void @host
+// CHECK:         %omp_loop.tripcount = {{.*}}
+// CHECK-NEXT:    br label %[[ENTRY:.*]]
+// CHECK:       [[ENTRY]]:
+// CHECK:         %[[TRIPCOUNT:.*]] = zext i32 %omp_loop.tripcount to i64
+// CHECK:         %[[TRIPCOUNT_KARG:.*]] = getelementptr inbounds nuw %struct.__tgt_kernel_arguments, ptr %[[KARGS:.*]], i32 0, i32 8
+// CHECK-NEXT:    store i64 %[[TRIPCOUNT]], ptr %[[TRIPCOUNT_KARG]]
+// CHECK:         %[[RESULT:.*]] = call i32 @__tgt_target_kernel({{.*}}, ptr %[[KARGS]])
+// CHECK-NEXT:    %[[CMP:.*]] = icmp ne i32 %[[RESULT]], 0
+// CHECK-NEXT:    br i1 %[[CMP]], label %[[OFFLOAD_FAILED:.*]], label %{{.*}}
+// CHECK:       [[OFFLOAD_FAILED]]:
+// CHECK:         call void @[[TARGET_OUTLINE:.*]]({{.*}})
 
-// HOST:       define internal void @[[TARGET_OUTLINE]]
-// HOST:         call void{{.*}}@__kmpc_fork_teams({{.*}}, ptr @[[TEAMS_OUTLINE:.*]], {{.*}})
+// CHECK:       define internal void @[[TARGET_OUTLINE]]
+// CHECK:         call void{{.*}}@__kmpc_fork_teams({{.*}}, ptr @[[TEAMS_OUTLINE:.*]], {{.*}})
 
-// HOST:       define internal void @[[TEAMS_OUTLINE]]
-// HOST:         call void @[[DISTRIBUTE_OUTLINE:.*]]({{.*}})
+// CHECK:       define internal void @[[TEAMS_OUTLINE]]
+// CHECK:         call void @[[DISTRIBUTE_OUTLINE:.*]]({{.*}})
 
-// HOST:       define internal void @[[DISTRIBUTE_OUTLINE]]
-// HOST:         call void @__kmpc_for_static_init{{.*}}(ptr {{.*}}, i32 {{.*}}, i32 92, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, i32 {{.*}}, i32 {{.*}})
-// HOST:         call void (ptr, i32, ptr, ...) @__kmpc_fork_call({{.*}}, ptr @[[PARALLEL_OUTLINE:.*]], {{.*}})
+// CHECK:       define internal void @[[DISTRIBUTE_OUTLINE]]
+// CHECK:         call void @__kmpc_for_static_init{{.*}}(ptr {{.*}}, i32 {{.*}}, i32 92, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, i32 {{.*}}, i32 {{.*}})
+// CHECK:         call void (ptr, i32, ptr, ...) @__kmpc_fork_call({{.*}}, ptr @[[PARALLEL_OUTLINE:.*]], {{.*}})
 
-// HOST:       define internal void @[[PARALLEL_OUTLINE]]
-// HOST:         call void @__kmpc_for_static_init{{.*}}(ptr {{.*}}, i32 {{.*}}, i32 34, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, i32 {{.*}}, i32 {{.*}})
+// CHECK:       define internal void @[[PARALLEL_OUTLINE]]
+// CHECK:         call void @__kmpc_for_static_init{{.*}}(ptr {{.*}}, i32 {{.*}}, i32 34, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, i32 {{.*}}, i32 {{.*}})
 
-//--- device.mlir
+// -----
 
 module attributes {llvm.target_triple = "amdgcn-amd-amdhsa", omp.is_target_device = true, omp.is_gpu = true} {
-  llvm.func @main(%arg0 : !llvm.ptr) {
+  llvm.func @device(%arg0 : !llvm.ptr) {
     %0 = omp.map.info var_ptr(%arg0 : !llvm.ptr, i32) map_clauses(to) capture(ByCopy) -> !llvm.ptr
     omp.target map_entries(%0 -> %ptr : !llvm.ptr) {
       %x = llvm.load %ptr : !llvm.ptr -> i32
@@ -87,25 +83,80 @@ module attributes {llvm.target_triple = "amdgcn-amd-amdhsa", omp.is_target_devic
   }
 }
 
-// DEVICE:      @[[KERNEL_NAME:.*]]_exec_mode = weak protected constant i8 [[EXEC_MODE:3]]
-// DEVICE:      @llvm.compiler.used = appending global [1 x ptr] [ptr @[[KERNEL_NAME]]_exec_mode], section "llvm.metadata"
-// DEVICE:      @[[KERNEL_NAME]]_kernel_environment = weak_odr protected constant %struct.KernelEnvironmentTy {
-// DEVICE-SAME: %struct.ConfigurationEnvironmentTy { i8 1, i8 1, i8 [[EXEC_MODE]], {{.*}}},
-// DEVICE-SAME: ptr @{{.*}}, ptr @{{.*}} }
+// CHECK:      @[[KERNEL_NAME:.*]]_exec_mode = weak protected constant i8 [[EXEC_MODE:3]]
+// CHECK:      @llvm.compiler.used = appending global [1 x ptr] [ptr @[[KERNEL_NAME]]_exec_mode], section "llvm.metadata"
+// CHECK:      @[[KERNEL_NAME]]_kernel_environment = weak_odr protected constant %struct.KernelEnvironmentTy {
+// CHECK-SAME: %struct.ConfigurationEnvironmentTy { i8 1, i8 1, i8 [[EXEC_MODE]], {{.*}}},
+// CHECK-SAME: ptr @{{.*}}, ptr @{{.*}} }
+
+// CHECK:      define weak_odr protected amdgpu_kernel void @[[KERNEL_NAME]]({{.*}})
+// CHECK:        %{{.*}} = call i32 @__kmpc_target_init(ptr @[[KERNEL_NAME]]_kernel_environment, {{.*}})
+// CHECK:        call void @[[TARGET_OUTLINE:.*]]({{.*}})
+// CHECK:        call void @__kmpc_target_deinit()
+
+// CHECK:      define internal void @[[TARGET_OUTLINE]]({{.*}})
+// CHECK:        call void @[[TEAMS_OUTLINE:.*]]({{.*}})
+
+// CHECK:      define internal void @[[TEAMS_OUTLINE]]({{.*}})
+// CHECK:        call void @__kmpc_distribute_static_loop{{.*}}({{.*}}, ptr @[[DISTRIBUTE_OUTLINE:[^,]*]], {{.*}})
+
+// CHECK:      define internal void @[[DISTRIBUTE_OUTLINE]]({{.*}})
+// CHECK:        call void @__kmpc_parallel_51(ptr {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, ptr @[[PARALLEL_OUTLINE:.*]], ptr {{.*}}, ptr {{.*}}, i64 {{.*}})
+
+// CHECK:      define internal void @[[PARALLEL_OUTLINE]]({{.*}})
+// CHECK:        call void @__kmpc_for_static_loop{{.*}}({{.*}})
+
+// -----
+
+module attributes {llvm.target_triple = "amdgcn-amd-amdhsa", omp.is_target_device = true, omp.is_gpu = true} {
+  llvm.func @device2(%arg0 : !llvm.ptr) {
+    %0 = omp.map.info var_ptr(%arg0 : !llvm.ptr, i32) map_clauses(to) capture(ByCopy) -> !llvm.ptr
+    omp.target map_entries(%0 -> %ptr : !llvm.ptr) {
+      %x = llvm.load %ptr : !llvm.ptr -> i32
+      omp.teams {
+        omp.distribute {
+          omp.loop_nest (%iv1) : i32 = (%x) to (%x) step (%x) {
+            omp.parallel {
+              omp.terminator
+            }
+            llvm.br ^bb2
+          ^bb1:
+            omp.parallel {
+              omp.terminator
+            }
+            omp.yield
+          ^bb2:
+            llvm.br ^bb1
+          }
+        }
+        omp.terminator
+      }
+      omp.terminator
+    }
+    llvm.return
+  }
+}
+
+// CHECK:      @[[KERNEL_NAME:.*]]_exec_mode = weak protected constant i8 [[EXEC_MODE:3]]
+// CHECK:      @llvm.compiler.used = appending global [1 x ptr] [ptr @[[KERNEL_NAME]]_exec_mode], section "llvm.metadata"
+// CHECK:      @[[KERNEL_NAME]]_kernel_environment = weak_odr protected constant %struct.KernelEnvironmentTy {
+// CHECK-SAME: %struct.ConfigurationEnvironmentTy { i8 1, i8 1, i8 [[EXEC_MODE]], {{.*}}},
+// CHECK-SAME: ptr @{{.*}}, ptr @{{.*}} }
 
-// DEVICE:      define weak_odr protected amdgpu_kernel void @[[KERNEL_NAME]]({{.*}})
-// DEVICE:        %{{.*}} = call i32 @__kmpc_target_init(ptr @[[KERNEL_NAME]]_kernel_environment, {{.*}})
-// DEVICE:        call void @[[TARGET_OUTLINE:.*]]({{.*}})
-// DEVICE:        call void @__kmpc_target_deinit()
+// CHECK:      define weak_odr protected amdgpu_kernel void @[[KERNEL_NAME]]({{.*}})
+// CHECK:        %{{.*}} = call i32 @__kmpc_target_init(ptr @[[KERNEL_NAME]]_kernel_environment, {{.*}})
+// CHECK:        call void @[[TARGET_OUTLINE:.*]]({{.*}})
+// CHECK:        call void @__kmpc_target_deinit()
 
-// DEVICE:      define internal void @[[TARGET_OUTLINE]]({{.*}})
-// DEVICE:        call void @[[TEAMS_OUTLINE:.*]]({{.*}})
+// CHECK:      define internal void @[[TARGET_OUTLINE]]({{.*}})
+// CHECK:        call void @[[TEAMS_OUTLINE:.*]]({{.*}})
 
-// DEVICE:      define internal void @[[TEAMS_OUTLINE]]({{.*}})
-// DEVICE:        call void @__kmpc_distribute_static_loop{{.*}}({{.*}}, ptr @[[DISTRIBUTE_OUTLINE:[^,]*]], {{.*}})
+// CHECK:      define internal void @[[TEAMS_OUTLINE]]({{.*}})
+// CHECK:        call void @__kmpc_distribute_static_loop{{.*}}({{.*}}, ptr @[[DISTRIBUTE_OUTLINE:[^,]*]], {{.*}})
 
-// DEVICE:      define internal void @[[DISTRIBUTE_OUTLINE]]({{.*}})
-// DEVICE:        call void @__kmpc_parallel_51(ptr {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, ptr @[[PARALLEL_OUTLINE:.*]], ptr {{.*}}, ptr {{.*}}, i64 {{.*}})
+// CHECK:      define internal void @[[DISTRIBUTE_OUTLINE]]({{.*}})
+// CHECK:        call void @__kmpc_parallel_51(ptr {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, ptr @[[PARALLEL_OUTLINE0:.*]], ptr {{.*}}, ptr {{.*}}, i64 {{.*}})
+// CHECK:        call void @__kmpc_parallel_51(ptr {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, ptr @[[PARALLEL_OUTLINE1:.*]], ptr {{.*}}, ptr {{.*}}, i64 {{.*}})
 
-// DEVICE:      define internal void @[[PARALLEL_OUTLINE]]({{.*}})
-// DEVICE:        call void @__kmpc_for_static_loop{{.*}}({{.*}})
+// CHECK:      define internal void @[[PARALLEL_OUTLINE1]]({{.*}})
+// CHECK:      define internal void @[[PARALLEL_OUTLINE0]]({{.*}})

``````````

</details>


https://github.com/llvm/llvm-project/pull/137307


More information about the Mlir-commits mailing list