[llvm-branch-commits] [llvm] [mlir] [OMPIRBuilder][MLIR] Add support for target 'if' clause (PR #122478)

via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Fri Jan 10 07:50:20 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-flang-openmp

Author: Sergio Afonso (skatrak)

<details>
<summary>Changes</summary>

This patch implements support for handling the 'if' clause of OpenMP 'target' constructs in the OMPIRBuilder and updates the MLIR to LLVM IR translation of the `omp.target` MLIR operation to make use of this new feature.

---
Full diff: https://github.com/llvm/llvm-project/pull/122478.diff


6 Files Affected:

- (modified) llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h (+15-11) 
- (modified) llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp (+77-53) 
- (modified) llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp (+6-5) 
- (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+6-3) 
- (added) mlir/test/Target/LLVMIR/omptarget-if.mlir (+68) 
- (modified) mlir/test/Target/LLVMIR/openmp-todo.mlir (-11) 


``````````diff
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index 4ce47b1c05d9b0..b1a23996c7bdd2 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -2965,21 +2965,25 @@ class OpenMPIRBuilder {
   /// \param NumThreads Number of teams specified in the thread_limit clause.
   /// \param Inputs The input values to the region that will be passed.
   /// as arguments to the outlined function.
+  /// \param IfCond value of the `if` clause.
   /// \param BodyGenCB Callback that will generate the region code.
   /// \param ArgAccessorFuncCB Callback that will generate accessors
   /// instructions for passed in target arguments where neccessary
   /// \param Dependencies A vector of DependData objects that carry
-  // dependency information as passed in the depend clause
-  // \param HasNowait Whether the target construct has a `nowait` clause or not.
-  InsertPointOrErrorTy createTarget(
-      const LocationDescription &Loc, bool IsOffloadEntry,
-      OpenMPIRBuilder::InsertPointTy AllocaIP,
-      OpenMPIRBuilder::InsertPointTy CodeGenIP,
-      TargetRegionEntryInfo &EntryInfo, ArrayRef<int32_t> NumTeams,
-      ArrayRef<int32_t> NumThreads, SmallVectorImpl<Value *> &Inputs,
-      GenMapInfoCallbackTy GenMapInfoCB, TargetBodyGenCallbackTy BodyGenCB,
-      TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
-      SmallVector<DependData> Dependencies = {}, bool HasNowait = false);
+  /// dependency information as passed in the depend clause
+  /// \param HasNowait Whether the target construct has a `nowait` clause or
+  /// not.
+  InsertPointOrErrorTy
+  createTarget(const LocationDescription &Loc, bool IsOffloadEntry,
+               OpenMPIRBuilder::InsertPointTy AllocaIP,
+               OpenMPIRBuilder::InsertPointTy CodeGenIP,
+               TargetRegionEntryInfo &EntryInfo, ArrayRef<int32_t> NumTeams,
+               ArrayRef<int32_t> NumThreads, SmallVectorImpl<Value *> &Inputs,
+               Value *IfCond, GenMapInfoCallbackTy GenMapInfoCB,
+               TargetBodyGenCallbackTy BodyGenCB,
+               TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
+               SmallVector<DependData> Dependencies = {},
+               bool HasNowait = false);
 
   /// Returns __kmpc_for_static_init_* runtime function for the specified
   /// size \a IVSize and sign \a IVSigned. Will create a distribute call
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index db77c6a5869764..0e190f4c64a8b3 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -7310,6 +7310,7 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
                OpenMPIRBuilder::InsertPointTy AllocaIP, Function *OutlinedFn,
                Constant *OutlinedFnID, ArrayRef<int32_t> NumTeams,
                ArrayRef<int32_t> NumThreads, SmallVectorImpl<Value *> &Args,
+               Value *IfCond,
                OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB,
                SmallVector<llvm::OpenMPIRBuilder::DependData> Dependencies = {},
                bool HasNoWait = false) {
@@ -7354,9 +7355,9 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
     return Error::success();
   };
 
-  // If we don't have an ID for the target region, it means an offload entry
-  // wasn't created. In this case we just run the host fallback directly.
-  if (!OutlinedFnID) {
+  auto &&EmitTargetCallElse =
+      [&](OpenMPIRBuilder::InsertPointTy AllocaIP,
+          OpenMPIRBuilder::InsertPointTy CodeGenIP) -> Error {
     // Assume no error was returned because EmitTargetCallFallbackCB doesn't
     // produce any.
     OpenMPIRBuilder::InsertPointTy AfterIP = cantFail([&]() {
@@ -7372,65 +7373,87 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
     }());
 
     Builder.restoreIP(AfterIP);
-    return;
-  }
+    return Error::success();
+  };
+
+  auto &&EmitTargetCallThen =
+      [&](OpenMPIRBuilder::InsertPointTy AllocaIP,
+          OpenMPIRBuilder::InsertPointTy CodeGenIP) -> Error {
+    OpenMPIRBuilder::TargetDataInfo Info(
+        /*RequiresDevicePointerInfo=*/false,
+        /*SeparateBeginEndCalls=*/true);
+
+    OpenMPIRBuilder::MapInfosTy &MapInfo = GenMapInfoCB(Builder.saveIP());
+    OpenMPIRBuilder::TargetDataRTArgs RTArgs;
+    OMPBuilder.emitOffloadingArraysAndArgs(AllocaIP, Builder.saveIP(), Info,
+                                           RTArgs, MapInfo,
+                                           /*IsNonContiguous=*/true,
+                                           /*ForEndCall=*/false);
+
+    SmallVector<Value *, 3> NumTeamsC;
+    SmallVector<Value *, 3> NumThreadsC;
+    for (auto V : NumTeams)
+      NumTeamsC.push_back(llvm::ConstantInt::get(Builder.getInt32Ty(), V));
+    for (auto V : NumThreads)
+      NumThreadsC.push_back(llvm::ConstantInt::get(Builder.getInt32Ty(), V));
+
+    unsigned NumTargetItems = Info.NumberOfPtrs;
+    // TODO: Use correct device ID
+    Value *DeviceID = Builder.getInt64(OMP_DEVICEID_UNDEF);
+    uint32_t SrcLocStrSize;
+    Constant *SrcLocStr = OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize);
+    Value *RTLoc = OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize,
+                                               llvm::omp::IdentFlag(0), 0);
+    // TODO: Use correct NumIterations
+    Value *NumIterations = Builder.getInt64(0);
+    // TODO: Use correct DynCGGroupMem
+    Value *DynCGGroupMem = Builder.getInt32(0);
+
+    KArgs = OpenMPIRBuilder::TargetKernelArgs(
+        NumTargetItems, RTArgs, NumIterations, NumTeamsC, NumThreadsC,
+        DynCGGroupMem, HasNoWait);
+
+    // Assume no error was returned because TaskBodyCB and
+    // EmitTargetCallFallbackCB don't produce any.
+    OpenMPIRBuilder::InsertPointTy AfterIP = cantFail([&]() {
+      // The presence of certain clauses on the target directive require the
+      // explicit generation of the target task.
+      if (RequiresOuterTargetTask)
+        return OMPBuilder.emitTargetTask(TaskBodyCB, DeviceID, RTLoc, AllocaIP,
+                                         Dependencies, HasNoWait);
+
+      return OMPBuilder.emitKernelLaunch(Builder, OutlinedFnID,
+                                         EmitTargetCallFallbackCB, KArgs,
+                                         DeviceID, RTLoc, AllocaIP);
+    }());
 
-  OpenMPIRBuilder::TargetDataInfo Info(
-      /*RequiresDevicePointerInfo=*/false,
-      /*SeparateBeginEndCalls=*/true);
+    Builder.restoreIP(AfterIP);
+    return Error::success();
+  };
 
-  OpenMPIRBuilder::MapInfosTy &MapInfo = GenMapInfoCB(Builder.saveIP());
-  OpenMPIRBuilder::TargetDataRTArgs RTArgs;
-  OMPBuilder.emitOffloadingArraysAndArgs(AllocaIP, Builder.saveIP(), Info,
-                                         RTArgs, MapInfo,
-                                         /*IsNonContiguous=*/true,
-                                         /*ForEndCall=*/false);
+  // If we don't have an ID for the target region, it means an offload entry
+  // wasn't created. In this case we just run the host fallback directly.
+  if (!OutlinedFnID) {
+    cantFail(EmitTargetCallElse(AllocaIP, Builder.saveIP()));
+    return;
+  }
 
-  SmallVector<Value *, 3> NumTeamsC;
-  SmallVector<Value *, 3> NumThreadsC;
-  for (auto V : NumTeams)
-    NumTeamsC.push_back(llvm::ConstantInt::get(Builder.getInt32Ty(), V));
-  for (auto V : NumThreads)
-    NumThreadsC.push_back(llvm::ConstantInt::get(Builder.getInt32Ty(), V));
+  // If there's no IF clause, only generate the kernel launch code path.
+  if (!IfCond) {
+    cantFail(EmitTargetCallThen(AllocaIP, Builder.saveIP()));
+    return;
+  }
 
-  unsigned NumTargetItems = Info.NumberOfPtrs;
-  // TODO: Use correct device ID
-  Value *DeviceID = Builder.getInt64(OMP_DEVICEID_UNDEF);
-  uint32_t SrcLocStrSize;
-  Constant *SrcLocStr = OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize);
-  Value *RTLoc = OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize,
-                                             llvm::omp::IdentFlag(0), 0);
-  // TODO: Use correct NumIterations
-  Value *NumIterations = Builder.getInt64(0);
-  // TODO: Use correct DynCGGroupMem
-  Value *DynCGGroupMem = Builder.getInt32(0);
-
-  KArgs = OpenMPIRBuilder::TargetKernelArgs(
-      NumTargetItems, RTArgs, NumIterations, NumTeamsC, NumThreadsC,
-      DynCGGroupMem, HasNoWait);
-
-  // Assume no error was returned because TaskBodyCB and
-  // EmitTargetCallFallbackCB don't produce any.
-  OpenMPIRBuilder::InsertPointTy AfterIP = cantFail([&]() {
-    // The presence of certain clauses on the target directive require the
-    // explicit generation of the target task.
-    if (RequiresOuterTargetTask)
-      return OMPBuilder.emitTargetTask(TaskBodyCB, DeviceID, RTLoc, AllocaIP,
-                                       Dependencies, HasNoWait);
-
-    return OMPBuilder.emitKernelLaunch(Builder, OutlinedFnID,
-                                       EmitTargetCallFallbackCB, KArgs,
-                                       DeviceID, RTLoc, AllocaIP);
-  }());
-
-  Builder.restoreIP(AfterIP);
+  cantFail(OMPBuilder.emitIfClause(IfCond, EmitTargetCallThen,
+                                   EmitTargetCallElse, AllocaIP));
 }
 
 OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
     const LocationDescription &Loc, bool IsOffloadEntry, InsertPointTy AllocaIP,
     InsertPointTy CodeGenIP, TargetRegionEntryInfo &EntryInfo,
     ArrayRef<int32_t> NumTeams, ArrayRef<int32_t> NumThreads,
-    SmallVectorImpl<Value *> &Args, GenMapInfoCallbackTy GenMapInfoCB,
+    SmallVectorImpl<Value *> &Args, Value *IfCond,
+    GenMapInfoCallbackTy GenMapInfoCB,
     OpenMPIRBuilder::TargetBodyGenCallbackTy CBFunc,
     OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
     SmallVector<DependData> Dependencies, bool HasNowait) {
@@ -7455,7 +7478,8 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
   // that represents the target region. Do that now.
   if (!Config.isTargetDevice())
     emitTargetCall(*this, Builder, AllocaIP, OutlinedFn, OutlinedFnID, NumTeams,
-                   NumThreads, Args, GenMapInfoCB, Dependencies, HasNowait);
+                   NumThreads, Args, IfCond, GenMapInfoCB, Dependencies,
+                   HasNowait);
   return Builder.saveIP();
 }
 
diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
index cdca725b147436..94dce5243d7004 100644
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -6232,7 +6232,8 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) {
   ASSERT_EXPECTED_INIT(
       OpenMPIRBuilder::InsertPointTy, AfterIP,
       OMPBuilder.createTarget(OmpLoc, /*IsOffloadEntry=*/true, Builder.saveIP(),
-                              Builder.saveIP(), EntryInfo, -1, 0, Inputs,
+                              Builder.saveIP(), EntryInfo, /*NumTeams=*/-1,
+                              /*NumThreads=*/0, Inputs, /*IfCond=*/nullptr,
                               GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB));
   Builder.restoreIP(AfterIP);
   OMPBuilder.finalize();
@@ -6343,8 +6344,8 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) {
   ASSERT_EXPECTED_INIT(
       OpenMPIRBuilder::InsertPointTy, AfterIP,
       OMPBuilder.createTarget(Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP,
-                              EntryInfo, /*NumTeams=*/-1,
-                              /*NumThreads=*/0, CapturedArgs, GenMapInfoCB,
+                              EntryInfo, /*NumTeams=*/-1, /*NumThreads=*/0,
+                              CapturedArgs, /*IfCond=*/nullptr, GenMapInfoCB,
                               BodyGenCB, SimpleArgAccessorCB));
   Builder.restoreIP(AfterIP);
 
@@ -6500,8 +6501,8 @@ TEST_F(OpenMPIRBuilderTest, ConstantAllocaRaise) {
   ASSERT_EXPECTED_INIT(
       OpenMPIRBuilder::InsertPointTy, AfterIP,
       OMPBuilder.createTarget(Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP,
-                              EntryInfo, /*NumTeams=*/-1,
-                              /*NumThreads=*/0, CapturedArgs, GenMapInfoCB,
+                              EntryInfo, /*NumTeams=*/-1, /*NumThreads=*/0,
+                              CapturedArgs, /*IfCond=*/nullptr, GenMapInfoCB,
                               BodyGenCB, SimpleArgAccessorCB));
   Builder.restoreIP(AfterIP);
 
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index a364098e0bd8a6..0c637bd32ab3f3 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -285,7 +285,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
         checkBare(op, result);
         checkDevice(op, result);
         checkHasDeviceAddr(op, result);
-        checkIf(op, result);
         checkInReduction(op, result);
         checkIsDevicePtr(op, result);
         // Privatization clauses are supported, except on some situations, so we
@@ -4112,11 +4111,15 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
       findAllocaInsertPoint(builder, moduleTranslation);
   llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
 
+  llvm::Value *ifCond = nullptr;
+  if (Value targetIfCond = targetOp.getIfExpr())
+    ifCond = moduleTranslation.lookupValue(targetIfCond);
+
   llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
       moduleTranslation.getOpenMPBuilder()->createTarget(
           ompLoc, isOffloadEntry, allocaIP, builder.saveIP(), entryInfo,
-          defaultValTeams, defaultValThreads, kernelInput, genMapInfoCB, bodyCB,
-          argAccessorCB, dds, targetOp.getNowait());
+          defaultValTeams, defaultValThreads, kernelInput, ifCond, genMapInfoCB,
+          bodyCB, argAccessorCB, dds, targetOp.getNowait());
 
   if (failed(handleError(afterIP, opInst)))
     return failure();
diff --git a/mlir/test/Target/LLVMIR/omptarget-if.mlir b/mlir/test/Target/LLVMIR/omptarget-if.mlir
new file mode 100644
index 00000000000000..706ad4411438ba
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/omptarget-if.mlir
@@ -0,0 +1,68 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+module attributes {omp.is_target_device = false, omp.target_triples = ["amdgcn-amd-amdhsa"]} {
+  llvm.func @target_if_variable(%x : i1) {
+    omp.target if(%x) {
+      omp.terminator
+    }
+    llvm.return
+  }
+
+  // CHECK-LABEL: define void @target_if_variable(
+  // CHECK-SAME: i1 %[[IF_COND:.*]])
+  // CHECK: br i1 %[[IF_COND]], label %[[THEN_LABEL:.*]], label %[[ELSE_LABEL:.*]]
+
+  // CHECK: [[THEN_LABEL]]:
+  // CHECK-NOT: {{^.*}}:
+  // CHECK: %[[RC:.*]] = call i32 @__tgt_target_kernel
+  // CHECK-NEXT: %[[OFFLOAD_SUCCESS:.*]] = icmp ne i32 %[[RC]], 0
+  // CHECK-NEXT: br i1 %[[OFFLOAD_SUCCESS]], label %[[OFFLOAD_FAIL_LABEL:.*]], label %[[OFFLOAD_CONT_LABEL:.*]]
+
+  // CHECK: [[OFFLOAD_FAIL_LABEL]]:
+  // CHECK-NEXT: call void @[[FALLBACK_FN:__omp_offloading_.*_.*_target_if_variable_l.*]]()
+  // CHECK-NEXT: br label %[[OFFLOAD_CONT_LABEL]]
+
+  // CHECK: [[OFFLOAD_CONT_LABEL]]:
+  // CHECK-NEXT: br label %[[END_LABEL:.*]]
+
+  // CHECK: [[ELSE_LABEL]]:
+  // CHECK-NEXT: call void @[[FALLBACK_FN]]()
+  // CHECK-NEXT: br label %[[END_LABEL]]
+
+  llvm.func @target_if_true() {
+    %0 = llvm.mlir.constant(true) : i1
+    omp.target if(%0) {
+      omp.terminator
+    }
+    llvm.return
+  }
+
+  // CHECK-LABEL: define void @target_if_true()
+  // CHECK-NOT: {{^.*}}:
+  // CHECK: br label %[[ENTRY:.*]]
+
+  // CHECK: [[ENTRY]]:
+  // CHECK-NOT: {{^.*}}:
+  // CHECK: %[[RC:.*]] = call i32 @__tgt_target_kernel
+  // CHECK-NEXT: %[[OFFLOAD_SUCCESS:.*]] = icmp ne i32 %[[RC]], 0
+  // CHECK-NEXT: br i1 %[[OFFLOAD_SUCCESS]], label %[[OFFLOAD_FAIL_LABEL:.*]], label %[[OFFLOAD_CONT_LABEL:.*]]
+
+  // CHECK: [[OFFLOAD_FAIL_LABEL]]:
+  // CHECK-NEXT: call void @[[FALLBACK_FN:.*]]()
+  // CHECK-NEXT: br label %[[OFFLOAD_CONT_LABEL]]
+
+  llvm.func @target_if_false() {
+    %0 = llvm.mlir.constant(false) : i1
+    omp.target if(%0) {
+      omp.terminator
+    }
+    llvm.return
+  }
+
+  // CHECK-LABEL: define void @target_if_false()
+  // CHECK-NEXT: br label %[[ENTRY:.*]]
+
+  // CHECK: [[ENTRY]]:
+  // CHECK-NEXT: call void @__omp_offloading_{{.*}}_{{.*}}_target_if_false_l{{.*}}()
+}
+
diff --git a/mlir/test/Target/LLVMIR/openmp-todo.mlir b/mlir/test/Target/LLVMIR/openmp-todo.mlir
index 83a0990d631620..4e0925c833c3b7 100644
--- a/mlir/test/Target/LLVMIR/openmp-todo.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-todo.mlir
@@ -266,17 +266,6 @@ llvm.func @target_has_device_addr(%x : !llvm.ptr) {
 
 // -----
 
-llvm.func @target_if(%x : i1) {
-  // expected-error@below {{not yet implemented: Unhandled clause if in omp.target operation}}
-  // expected-error@below {{LLVM Translation failed for operation: omp.target}}
-  omp.target if(%x) {
-    omp.terminator
-  }
-  llvm.return
-}
-
-// -----
-
 omp.declare_reduction @add_f32 : f32
 init {
 ^bb0(%arg: f32):

``````````

</details>


https://github.com/llvm/llvm-project/pull/122478


More information about the llvm-branch-commits mailing list