[llvm] [mlir] [WIP][MLIR][OpenMP] Support Generic-SPMD optimization for OpenMP dialect (PR #140918)

Sergio Afonso via llvm-commits llvm-commits at lists.llvm.org
Wed May 21 08:57:13 PDT 2025


https://github.com/skatrak created https://github.com/llvm/llvm-project/pull/140918

This patch attempts to remove the logic currently implemented at the MLIR level for the OpenMP dialect to tag kernels as Generic-SPMD, and to rely instead on the OpenMPOpt pass to do so, as Clang already does.

This requires two changes:
1. Modify `TargetOp::getKernelExecFlags` and associated codegen to only tag kernels as SPMD or Generic.
2. Update the OpenMPOpt pass to be able to support the `__kmpc_distribute_static_loop_*`, `__kmpc_distribute_for_static_loop_*` and `__kmpc_for_static_loop_*` families of DeviceRTL functions when transforming Generic kernels to Generic-SPMD.

This work-in-progress patch implements the first change and makes an attempt at the second, but it does not work yet. In addition to marking the previously mentioned functions as compatible with SPMD mode, it attempts to analyze the outlined loop body as if there were a direct call to it (the body is actually called indirectly, so that call should still be taken into account), but this does not currently result in the right code changes.

>From b05150d10c91673ef814cc90f274d04832d21a7c Mon Sep 17 00:00:00 2001
From: Sergio Afonso <safonsof at amd.com>
Date: Wed, 21 May 2025 16:29:23 +0100
Subject: [PATCH 1/2] Remove Generic-SPMD handling from MLIR

---
 .../mlir/Dialect/OpenMP/OpenMPEnums.td        |  6 +--
 mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp  | 48 +++++--------------
 .../OpenMP/OpenMPToLLVMIRTranslation.cpp      | 12 ++---
 3 files changed, 16 insertions(+), 50 deletions(-)

diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
index 9dbe6897a3304..2ba712d6de250 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
@@ -227,15 +227,13 @@ def ScheduleModifierAttr : OpenMP_EnumAttr<ScheduleModifier, "sched_mod">;
 //===----------------------------------------------------------------------===//
 
 def TargetRegionFlagsNone : I32BitEnumAttrCaseNone<"none">;
-def TargetRegionFlagsGeneric : I32BitEnumAttrCaseBit<"generic", 0>;
-def TargetRegionFlagsSpmd : I32BitEnumAttrCaseBit<"spmd", 1>;
-def TargetRegionFlagsTripCount : I32BitEnumAttrCaseBit<"trip_count", 2>;
+def TargetRegionFlagsSpmd : I32BitEnumAttrCaseBit<"spmd", 0>;
+def TargetRegionFlagsTripCount : I32BitEnumAttrCaseBit<"trip_count", 1>;
 
 def TargetRegionFlags : OpenMP_BitEnumAttr<
     "TargetRegionFlags",
     "target region property flags", [
       TargetRegionFlagsNone,
-      TargetRegionFlagsGeneric,
       TargetRegionFlagsSpmd,
       TargetRegionFlagsTripCount
     ]>;
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index deff86d5c5ecb..5a9687fdd195e 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -2083,7 +2083,7 @@ TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) {
 
   // If it's not capturing a loop, it's a default target region.
   if (!isa_and_present<LoopNestOp>(capturedOp))
-    return TargetRegionFlags::generic;
+    return TargetRegionFlags::none;
 
   // Get the innermost non-simd loop wrapper.
   SmallVector<LoopWrapperInterface> loopWrappers;
@@ -2096,24 +2096,24 @@ TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) {
 
   auto numWrappers = std::distance(innermostWrapper, loopWrappers.end());
   if (numWrappers != 1 && numWrappers != 2)
-    return TargetRegionFlags::generic;
+    return TargetRegionFlags::none;
 
   // Detect target-teams-distribute-parallel-wsloop[-simd].
   if (numWrappers == 2) {
     if (!isa<WsloopOp>(innermostWrapper))
-      return TargetRegionFlags::generic;
+      return TargetRegionFlags::none;
 
     innermostWrapper = std::next(innermostWrapper);
     if (!isa<DistributeOp>(innermostWrapper))
-      return TargetRegionFlags::generic;
+      return TargetRegionFlags::none;
 
     Operation *parallelOp = (*innermostWrapper)->getParentOp();
     if (!isa_and_present<ParallelOp>(parallelOp))
-      return TargetRegionFlags::generic;
+      return TargetRegionFlags::none;
 
     Operation *teamsOp = parallelOp->getParentOp();
     if (!isa_and_present<TeamsOp>(teamsOp))
-      return TargetRegionFlags::generic;
+      return TargetRegionFlags::none;
 
     if (teamsOp->getParentOp() == targetOp.getOperation())
       return TargetRegionFlags::spmd | TargetRegionFlags::trip_count;
@@ -2122,53 +2122,27 @@ TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) {
   else if (isa<DistributeOp, LoopOp>(innermostWrapper)) {
     Operation *teamsOp = (*innermostWrapper)->getParentOp();
     if (!isa_and_present<TeamsOp>(teamsOp))
-      return TargetRegionFlags::generic;
+      return TargetRegionFlags::none;
 
     if (teamsOp->getParentOp() != targetOp.getOperation())
-      return TargetRegionFlags::generic;
+      return TargetRegionFlags::none;
 
     if (isa<LoopOp>(innermostWrapper))
       return TargetRegionFlags::spmd | TargetRegionFlags::trip_count;
 
-    // Find single immediately nested captured omp.parallel and add spmd flag
-    // (generic-spmd case).
-    //
-    // TODO: This shouldn't have to be done here, as it is too easy to break.
-    // The openmp-opt pass should be updated to be able to promote kernels like
-    // this from "Generic" to "Generic-SPMD". However, the use of the
-    // `kmpc_distribute_static_loop` family of functions produced by the
-    // OMPIRBuilder for these kernels prevents that from working.
-    Dialect *ompDialect = targetOp->getDialect();
-    Operation *nestedCapture = findCapturedOmpOp(
-        capturedOp, /*checkSingleMandatoryExec=*/false,
-        [&](Operation *sibling) {
-          return sibling && (ompDialect != sibling->getDialect() ||
-                             sibling->hasTrait<OpTrait::IsTerminator>());
-        });
-
-    TargetRegionFlags result =
-        TargetRegionFlags::generic | TargetRegionFlags::trip_count;
-
-    if (!nestedCapture)
-      return result;
-
-    while (nestedCapture->getParentOp() != capturedOp)
-      nestedCapture = nestedCapture->getParentOp();
-
-    return isa<ParallelOp>(nestedCapture) ? result | TargetRegionFlags::spmd
-                                          : result;
+    return TargetRegionFlags::trip_count;
   }
   // Detect target-parallel-wsloop[-simd].
   else if (isa<WsloopOp>(innermostWrapper)) {
     Operation *parallelOp = (*innermostWrapper)->getParentOp();
     if (!isa_and_present<ParallelOp>(parallelOp))
-      return TargetRegionFlags::generic;
+      return TargetRegionFlags::none;
 
     if (parallelOp->getParentOp() == targetOp.getOperation())
       return TargetRegionFlags::spmd;
   }
 
-  return TargetRegionFlags::generic;
+  return TargetRegionFlags::none;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index d4b2f4154ae53..87e29d054793b 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -5192,16 +5192,10 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
 
   // Update kernel bounds structure for the `OpenMPIRBuilder` to use.
   omp::TargetRegionFlags kernelFlags = targetOp.getKernelExecFlags(capturedOp);
-  assert(
-      omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic |
-                                               omp::TargetRegionFlags::spmd) &&
-      "invalid kernel flags");
   attrs.ExecFlags =
-      omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic)
-          ? omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::spmd)
-                ? llvm::omp::OMP_TGT_EXEC_MODE_GENERIC_SPMD
-                : llvm::omp::OMP_TGT_EXEC_MODE_GENERIC
-          : llvm::omp::OMP_TGT_EXEC_MODE_SPMD;
+      omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::spmd)
+          ? llvm::omp::OMP_TGT_EXEC_MODE_SPMD
+          : llvm::omp::OMP_TGT_EXEC_MODE_GENERIC;
   attrs.MinTeams = minTeamsVal;
   attrs.MaxTeams.front() = maxTeamsVal;
   attrs.MinThreads = 1;

>From 93f029a2d899925e220fa656836cd0d9d495b154 Mon Sep 17 00:00:00 2001
From: Sergio Afonso <safonsof at amd.com>
Date: Wed, 21 May 2025 16:32:28 +0100
Subject: [PATCH 2/2] Attempt at supporting new DeviceRTL loop functions in the
 OpenMPOpt pass

---
 llvm/lib/Transforms/IPO/OpenMPOpt.cpp | 110 ++++++++++++++++++--------
 1 file changed, 79 insertions(+), 31 deletions(-)

diff --git a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
index 562c5fcd05386..cc2063ced846d 100644
--- a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
+++ b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
@@ -67,6 +67,21 @@ using namespace omp;
 
 #define DEBUG_TYPE "openmp-opt"
 
+static bool isOMPRTLLoopFunc(RuntimeFunction RF) {
+  return RF == OMPRTL___kmpc_distribute_static_loop_4 ||
+         RF == OMPRTL___kmpc_distribute_static_loop_4u ||
+         RF == OMPRTL___kmpc_distribute_static_loop_8 ||
+         RF == OMPRTL___kmpc_distribute_static_loop_8u ||
+         RF == OMPRTL___kmpc_distribute_for_static_loop_4 ||
+         RF == OMPRTL___kmpc_distribute_for_static_loop_4u ||
+         RF == OMPRTL___kmpc_distribute_for_static_loop_8 ||
+         RF == OMPRTL___kmpc_distribute_for_static_loop_8u ||
+         RF == OMPRTL___kmpc_for_static_loop_4 ||
+         RF == OMPRTL___kmpc_for_static_loop_4u ||
+         RF == OMPRTL___kmpc_for_static_loop_8 ||
+         RF == OMPRTL___kmpc_for_static_loop_8u;
+}
+
 static cl::opt<bool> DisableOpenMPOptimizations(
     "openmp-opt-disable", cl::desc("Disable OpenMP specific optimizations."),
     cl::Hidden, cl::init(false));
@@ -4560,13 +4575,15 @@ struct AAKernelInfoFunction : AAKernelInfo {
     // parallel regions we expect, if there are any.
     for (int I = 0, E = ReachedKnownParallelRegions.size(); I < E; ++I) {
       auto *CB = ReachedKnownParallelRegions[I];
-      auto *ParallelRegion = dyn_cast<Function>(
-          CB->getArgOperand(WrapperFunctionArgNo)->stripPointerCasts());
+      auto *ParallelRegion =
+          CB->getArgOperand(WrapperFunctionArgNo)->stripPointerCasts();
       BasicBlock *PRExecuteBB = BasicBlock::Create(
           Ctx, "worker_state_machine.parallel_region.execute", Kernel,
           StateMachineEndParallelBB);
-      CallInst::Create(ParallelRegion, {ZeroArg, GTid}, "", PRExecuteBB)
-          ->setDebugLoc(DLoc);
+      if (auto *ParallelRegionFn = dyn_cast<Function>(ParallelRegion)) {
+        CallInst::Create(ParallelRegionFn, {ZeroArg, GTid}, "", PRExecuteBB)
+            ->setDebugLoc(DLoc);
+      }
       BranchInst::Create(StateMachineEndParallelBB, PRExecuteBB)
           ->setDebugLoc(DLoc);
 
@@ -4892,37 +4909,41 @@ struct AAKernelInfoCallSite : AAKernelInfo {
       return;
     }
 
-    // Next we check if we know the callee. If it is a known OpenMP function
-    // we will handle them explicitly in the switch below. If it is not, we
-    // will use an AAKernelInfo object on the callee to gather information and
-    // merge that into the current state. The latter happens in the updateImpl.
+    // Next we check if we know the callee.
+    // If it is not a known OpenMP function we use an AAKernelInfo object on the
+    // callee to gather information and merge that into the current state. This
+    // will happen later in the updateImpl.
+    auto CheckNonOmpRTLCallee = [&](Function *Callee) {
+      // Unknown callees or declarations are not analyzable, we give up.
+      if (!Callee || !A.isFunctionIPOAmendable(*Callee)) {
+        // Unknown callees might contain parallel regions, except if they have
+        // an appropriate assumption attached.
+        if (!AssumptionAA ||
+            !(AssumptionAA->hasAssumption("omp_no_openmp") ||
+              AssumptionAA->hasAssumption("omp_no_parallelism")))
+          ReachedUnknownParallelRegions.insert(&CB);
+
+        // If SPMDCompatibilityTracker is not fixed, we need to give up on
+        // the idea we can run something unknown in SPMD-mode.
+        if (!SPMDCompatibilityTracker.isAtFixpoint()) {
+          SPMDCompatibilityTracker.indicatePessimisticFixpoint();
+          SPMDCompatibilityTracker.insert(&CB);
+        }
+
+        // We have updated the state for this unknown call properly, there
+        // won't be any change so we indicate a fixpoint.
+        indicateOptimisticFixpoint();
+      }
+    };
+
+    // Known OpenMP functions will be handled explicitly in the switch below.
     auto CheckCallee = [&](Function *Callee, unsigned NumCallees) {
       auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
       const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee);
       if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {
-        // Unknown caller or declarations are not analyzable, we give up.
-        if (!Callee || !A.isFunctionIPOAmendable(*Callee)) {
-
-          // Unknown callees might contain parallel regions, except if they have
-          // an appropriate assumption attached.
-          if (!AssumptionAA ||
-              !(AssumptionAA->hasAssumption("omp_no_openmp") ||
-                AssumptionAA->hasAssumption("omp_no_parallelism")))
-            ReachedUnknownParallelRegions.insert(&CB);
-
-          // If SPMDCompatibilityTracker is not fixed, we need to give up on the
-          // idea we can run something unknown in SPMD-mode.
-          if (!SPMDCompatibilityTracker.isAtFixpoint()) {
-            SPMDCompatibilityTracker.indicatePessimisticFixpoint();
-            SPMDCompatibilityTracker.insert(&CB);
-          }
-
-          // We have updated the state for this unknown call properly, there
-          // won't be any change so we indicate a fixpoint.
-          indicateOptimisticFixpoint();
-        }
         // If the callee is known and can be used in IPO, we will update the
         // state based on the callee state in updateImpl.
+        CheckNonOmpRTLCallee(Callee);
         return;
       }
       if (NumCallees > 1) {
@@ -5001,6 +5022,27 @@ struct AAKernelInfoCallSite : AAKernelInfo {
           break;
         };
       } break;
+      case OMPRTL___kmpc_distribute_static_loop_4:
+      case OMPRTL___kmpc_distribute_static_loop_4u:
+      case OMPRTL___kmpc_distribute_static_loop_8:
+      case OMPRTL___kmpc_distribute_static_loop_8u:
+      case OMPRTL___kmpc_distribute_for_static_loop_4:
+      case OMPRTL___kmpc_distribute_for_static_loop_4u:
+      case OMPRTL___kmpc_distribute_for_static_loop_8:
+      case OMPRTL___kmpc_distribute_for_static_loop_8u:
+      case OMPRTL___kmpc_for_static_loop_4:
+      case OMPRTL___kmpc_for_static_loop_4u:
+      case OMPRTL___kmpc_for_static_loop_8:
+      case OMPRTL___kmpc_for_static_loop_8u: {
+        // Check the contents of the callback, which contains the body executed
+        // in a loop by these functions.
+        unsigned CallBackArgOpNo = 1;
+        auto *CallBackFunc = cast<Function>(CB.getArgOperand(CallBackArgOpNo));
+        // If the callee is known and can be used in IPO, we will update the
+        // state based on the callee state in updateImpl.
+        CheckNonOmpRTLCallee(CallBackFunc);
+        return;
+      }
       case OMPRTL___kmpc_target_init:
         KernelInitCB = &CB;
         break;
@@ -5057,11 +5099,18 @@ struct AAKernelInfoCallSite : AAKernelInfo {
     KernelInfoState StateBefore = getState();
 
     auto CheckCallee = [&](Function *F, int NumCallees) {
+      CallBase &CB = cast<CallBase>(getAssociatedValue());
       const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(F);
+      bool IsStaticLoopFn = It != OMPInfoCache.RuntimeFunctionIDMap.end() &&
+                            isOMPRTLLoopFunc(It->getSecond());
 
       // If F is not a runtime function, propagate the AAKernelInfo of the
       // callee.
-      if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {
+      if (It == OMPInfoCache.RuntimeFunctionIDMap.end() || IsStaticLoopFn) {
+        if (IsStaticLoopFn) {
+          unsigned CallBackArgOpNo = 1;
+          F = cast<Function>(CB.getArgOperand(CallBackArgOpNo));
+        }
         const IRPosition &FnPos = IRPosition::function(*F);
         auto *FnAA =
             A.getAAFor<AAKernelInfo>(*this, FnPos, DepClassTy::REQUIRED);
@@ -5075,7 +5124,6 @@ struct AAKernelInfoCallSite : AAKernelInfo {
       if (NumCallees > 1)
         return indicatePessimisticFixpoint();
 
-      CallBase &CB = cast<CallBase>(getAssociatedValue());
       if (It->getSecond() == OMPRTL___kmpc_parallel_51) {
         if (!handleParallel51(A, CB))
           return indicatePessimisticFixpoint();



More information about the llvm-commits mailing list