[Mlir-commits] [mlir] [MLIR][OpenMP] Reduce overhead of target compilation (PR #130945)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Mar 12 04:45:58 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-openmp

Author: Sergio Afonso (skatrak)

<details>
<summary>Changes</summary>

This patch avoids calling `TargetOp::getInnermostCapturedOmpOp` multiple times while initializing the default and runtime target attributes during the MLIR-to-LLVM IR translation of `omp.target` operations. That function is potentially expensive, so reusing its result should help keep compile times lower.

---
Full diff: https://github.com/llvm/llvm-project/pull/130945.diff


3 Files Affected:

- (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td (+10-1) 
- (modified) mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp (+11-4) 
- (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+11-9) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index c1dae8d543eef..e2253d16f58af 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -1318,11 +1318,20 @@ def TargetOp : OpenMP_Op<"target", traits = [
     ///
     /// If there are omp.loop_nest operations in the sequence of nested
     /// operations, the top level one will be the one captured.
+    ///
+    /// This is a relatively expensive operation, so if it is needed at the same
+    /// time as the result of `getKernelExecFlags` it is preferable to cache the
+    /// result of this function and pass it along.
     Operation *getInnermostCapturedOmpOp();
 
     /// Infers the kernel type (Generic, SPMD or Generic-SPMD) based on the
     /// contents of the target region.
-    llvm::omp::OMPTgtExecModeFlags getKernelExecFlags();
+    ///
+    /// \param capturedOp result of a still valid (no modifications made to any
+    /// nested operations) previous call to `getInnermostCapturedOmpOp()`. If
+    /// not specified, this will call that function itself instead.
+    llvm::omp::OMPTgtExecModeFlags
+    getKernelExecFlags(std::optional<Operation *> capturedOp = std::nullopt);
   }] # clausesExtraClassDeclaration;
 
   let assemblyFormat = clausesAssemblyFormat # [{
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 65a296c5d4829..de1dc00909b40 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -2025,17 +2025,24 @@ Operation *TargetOp::getInnermostCapturedOmpOp() {
   return capturedOp;
 }
 
-llvm::omp::OMPTgtExecModeFlags TargetOp::getKernelExecFlags() {
+llvm::omp::OMPTgtExecModeFlags
+TargetOp::getKernelExecFlags(std::optional<Operation *> capturedOp) {
   using namespace llvm::omp;
 
+  // Use a cached operation, if passed in. Otherwise, find the innermost
+  // captured operation.
+  if (!capturedOp)
+    capturedOp = getInnermostCapturedOmpOp();
+  assert(*capturedOp == getInnermostCapturedOmpOp() &&
+         "unexpected captured op");
+
   // Make sure this region is capturing a loop. Otherwise, it's a generic
   // kernel.
-  Operation *capturedOp = getInnermostCapturedOmpOp();
-  if (!isa_and_present<LoopNestOp>(capturedOp))
+  if (!isa_and_present<LoopNestOp>(*capturedOp))
     return OMP_TGT_EXEC_MODE_GENERIC;
 
   SmallVector<LoopWrapperInterface> wrappers;
-  cast<LoopNestOp>(capturedOp).gatherWrappers(wrappers);
+  cast<LoopNestOp>(*capturedOp).gatherWrappers(wrappers);
   assert(!wrappers.empty());
 
   // Ignore optional SIMD leaf construct.
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 3373f19a006ba..6de5492edc60b 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -4558,11 +4558,10 @@ static std::optional<int64_t> extractConstInteger(Value value) {
 /// function for the target region, so that they can be used to initialize the
 /// corresponding global `ConfigurationEnvironmentTy` structure.
 static void
-initTargetDefaultAttrs(omp::TargetOp targetOp,
+initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
                        llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &attrs,
                        bool isTargetDevice) {
   // TODO: Handle constant 'if' clauses.
-  Operation *capturedOp = targetOp.getInnermostCapturedOmpOp();
 
   Value numThreads, numTeamsLower, numTeamsUpper, threadLimit;
   if (!isTargetDevice) {
@@ -4644,7 +4643,7 @@ initTargetDefaultAttrs(omp::TargetOp targetOp,
     combinedMaxThreadsVal = maxThreadsVal;
 
   // Update kernel bounds structure for the `OpenMPIRBuilder` to use.
-  attrs.ExecFlags = targetOp.getKernelExecFlags();
+  attrs.ExecFlags = targetOp.getKernelExecFlags(capturedOp);
   attrs.MinTeams = minTeamsVal;
   attrs.MaxTeams.front() = maxTeamsVal;
   attrs.MinThreads = 1;
@@ -4660,10 +4659,9 @@ initTargetDefaultAttrs(omp::TargetOp targetOp,
 static void
 initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
                        LLVM::ModuleTranslation &moduleTranslation,
-                       omp::TargetOp targetOp,
+                       omp::TargetOp targetOp, Operation *capturedOp,
                        llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs &attrs) {
-  omp::LoopNestOp loopOp = castOrGetParentOfType<omp::LoopNestOp>(
-      targetOp.getInnermostCapturedOmpOp());
+  omp::LoopNestOp loopOp = castOrGetParentOfType<omp::LoopNestOp>(capturedOp);
   unsigned numLoops = loopOp ? loopOp.getNumLoops() : 0;
 
   Value numThreads, numTeamsLower, numTeamsUpper, teamsThreadLimit;
@@ -4690,7 +4688,8 @@ initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
   if (numThreads)
     attrs.MaxThreads = moduleTranslation.lookupValue(numThreads);
 
-  if (targetOp.getKernelExecFlags() != llvm::omp::OMP_TGT_EXEC_MODE_GENERIC) {
+  if (targetOp.getKernelExecFlags(capturedOp) !=
+      llvm::omp::OMP_TGT_EXEC_MODE_GENERIC) {
     llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
     attrs.LoopTripCount = nullptr;
 
@@ -4940,12 +4939,15 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
 
   llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs runtimeAttrs;
   llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs defaultAttrs;
-  initTargetDefaultAttrs(targetOp, defaultAttrs, isTargetDevice);
+  Operation *targetCapturedOp = targetOp.getInnermostCapturedOmpOp();
+  initTargetDefaultAttrs(targetOp, targetCapturedOp, defaultAttrs,
+                         isTargetDevice);
 
   // Collect host-evaluated values needed to properly launch the kernel from the
   // host.
   if (!isTargetDevice)
-    initTargetRuntimeAttrs(builder, moduleTranslation, targetOp, runtimeAttrs);
+    initTargetRuntimeAttrs(builder, moduleTranslation, targetOp,
+                           targetCapturedOp, runtimeAttrs);
 
   // Pass host-evaluated values as parameters to the kernel / host fallback,
   // except if they are constants. In any case, map the MLIR block argument to

``````````

</details>


https://github.com/llvm/llvm-project/pull/130945


More information about the Mlir-commits mailing list