[llvm] 9bc8828 - [OMPIRBuilder][MLIR] Add support for target 'if' clause (#122478)

via llvm-commits llvm-commits at lists.llvm.org
Wed Jan 15 02:16:22 PST 2025


Author: Sergio Afonso
Date: 2025-01-15T10:16:19Z
New Revision: 9bc88280931e3b08adfab6951047191dfe12392b

URL: https://github.com/llvm/llvm-project/commit/9bc88280931e3b08adfab6951047191dfe12392b
DIFF: https://github.com/llvm/llvm-project/commit/9bc88280931e3b08adfab6951047191dfe12392b.diff

LOG: [OMPIRBuilder][MLIR] Add support for target 'if' clause (#122478)

This patch implements support for handling the 'if' clause of OpenMP
'target' constructs in the OMPIRBuilder and updates MLIR to LLVM IR
translation of the `omp.target` MLIR operation to make use of this new
feature.

Added: 
    mlir/test/Target/LLVMIR/omptarget-if.mlir

Modified: 
    llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
    llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
    llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
    mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
    mlir/test/Target/LLVMIR/openmp-todo.mlir

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index 7eceec3d8cf8f5..6b6e5bc19d95a4 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -2994,27 +2994,29 @@ class OpenMPIRBuilder {
   /// \param Loc where the target data construct was encountered.
   /// \param IsOffloadEntry whether it is an offload entry.
   /// \param CodeGenIP The insertion point where the call to the outlined
-  /// function should be emitted.
+  ///        function should be emitted.
   /// \param EntryInfo The entry information about the function.
   /// \param DefaultAttrs Structure containing the default attributes, including
   ///        numbers of threads and teams to launch the kernel with.
   /// \param RuntimeAttrs Structure containing the runtime numbers of threads
   ///        and teams to launch the kernel with.
+  /// \param IfCond value of the `if` clause.
   /// \param Inputs The input values to the region that will be passed.
-  /// as arguments to the outlined function.
+  ///        as arguments to the outlined function.
   /// \param BodyGenCB Callback that will generate the region code.
   /// \param ArgAccessorFuncCB Callback that will generate accessors
-  /// instructions for passed in target arguments where neccessary
+  ///        instructions for passed in target arguments where neccessary
   /// \param Dependencies A vector of DependData objects that carry
-  // dependency information as passed in the depend clause
-  // \param HasNowait Whether the target construct has a `nowait` clause or not.
+  ///        dependency information as passed in the depend clause
+  /// \param HasNowait Whether the target construct has a `nowait` clause or
+  ///        not.
   InsertPointOrErrorTy createTarget(
       const LocationDescription &Loc, bool IsOffloadEntry,
       OpenMPIRBuilder::InsertPointTy AllocaIP,
       OpenMPIRBuilder::InsertPointTy CodeGenIP,
       TargetRegionEntryInfo &EntryInfo,
       const TargetKernelDefaultAttrs &DefaultAttrs,
-      const TargetKernelRuntimeAttrs &RuntimeAttrs,
+      const TargetKernelRuntimeAttrs &RuntimeAttrs, Value *IfCond,
       SmallVectorImpl<Value *> &Inputs, GenMapInfoCallbackTy GenMapInfoCB,
       TargetBodyGenCallbackTy BodyGenCB,
       TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,

diff  --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 3d461f0ad4228c..c6603635d5e281 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -5308,8 +5308,8 @@ void OpenMPIRBuilder::applySimd(CanonicalLoopInfo *CanonicalLoop,
       Value *Alignment = AlignedItem.second;
       Instruction *loadInst = dyn_cast<Instruction>(AlignedPtr);
       Builder.SetInsertPoint(loadInst->getNextNode());
-      Builder.CreateAlignmentAssumption(F->getDataLayout(),
-                                        AlignedPtr, Alignment);
+      Builder.CreateAlignmentAssumption(F->getDataLayout(), AlignedPtr,
+                                        Alignment);
     }
     Builder.restoreIP(IP);
   }
@@ -5457,16 +5457,16 @@ static int32_t computeHeuristicUnrollFactor(CanonicalLoopInfo *CLI) {
   Loop *L = LI.getLoopFor(CLI->getHeader());
   assert(L && "Expecting CanonicalLoopInfo to be recognized as a loop");
 
-  TargetTransformInfo::UnrollingPreferences UP =
-      gatherUnrollingPreferences(L, SE, TTI,
-                                 /*BlockFrequencyInfo=*/nullptr,
-                                 /*ProfileSummaryInfo=*/nullptr, ORE, static_cast<int>(OptLevel),
-                                 /*UserThreshold=*/std::nullopt,
-                                 /*UserCount=*/std::nullopt,
-                                 /*UserAllowPartial=*/true,
-                                 /*UserAllowRuntime=*/true,
-                                 /*UserUpperBound=*/std::nullopt,
-                                 /*UserFullUnrollMaxCount=*/std::nullopt);
+  TargetTransformInfo::UnrollingPreferences UP = gatherUnrollingPreferences(
+      L, SE, TTI,
+      /*BlockFrequencyInfo=*/nullptr,
+      /*ProfileSummaryInfo=*/nullptr, ORE, static_cast<int>(OptLevel),
+      /*UserThreshold=*/std::nullopt,
+      /*UserCount=*/std::nullopt,
+      /*UserAllowPartial=*/true,
+      /*UserAllowRuntime=*/true,
+      /*UserUpperBound=*/std::nullopt,
+      /*UserFullUnrollMaxCount=*/std::nullopt);
 
   UP.Force = true;
 
@@ -7340,7 +7340,7 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
                OpenMPIRBuilder::InsertPointTy AllocaIP,
                const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
                const OpenMPIRBuilder::TargetKernelRuntimeAttrs &RuntimeAttrs,
-               Function *OutlinedFn, Constant *OutlinedFnID,
+               Value *IfCond, Function *OutlinedFn, Constant *OutlinedFnID,
                SmallVectorImpl<Value *> &Args,
                OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB,
                SmallVector<llvm::OpenMPIRBuilder::DependData> Dependencies = {},
@@ -7386,9 +7386,9 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
     return Error::success();
   };
 
-  // If we don't have an ID for the target region, it means an offload entry
-  // wasn't created. In this case we just run the host fallback directly.
-  if (!OutlinedFnID) {
+  auto &&EmitTargetCallElse =
+      [&](OpenMPIRBuilder::InsertPointTy AllocaIP,
+          OpenMPIRBuilder::InsertPointTy CodeGenIP) -> Error {
     // Assume no error was returned because EmitTargetCallFallbackCB doesn't
     // produce any.
     OpenMPIRBuilder::InsertPointTy AfterIP = cantFail([&]() {
@@ -7404,102 +7404,126 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
     }());
 
     Builder.restoreIP(AfterIP);
-    return;
-  }
-
-  OpenMPIRBuilder::TargetDataInfo Info(
-      /*RequiresDevicePointerInfo=*/false,
-      /*SeparateBeginEndCalls=*/true);
-
-  OpenMPIRBuilder::MapInfosTy &MapInfo = GenMapInfoCB(Builder.saveIP());
-  OpenMPIRBuilder::TargetDataRTArgs RTArgs;
-  OMPBuilder.emitOffloadingArraysAndArgs(AllocaIP, Builder.saveIP(), Info,
-                                         RTArgs, MapInfo,
-                                         /*IsNonContiguous=*/true,
-                                         /*ForEndCall=*/false);
-
-  SmallVector<Value *, 3> NumTeamsC;
-  for (auto [DefaultVal, RuntimeVal] :
-       zip_equal(DefaultAttrs.MaxTeams, RuntimeAttrs.MaxTeams))
-    NumTeamsC.push_back(RuntimeVal ? RuntimeVal : Builder.getInt32(DefaultVal));
-
-  // Calculate number of threads: 0 if no clauses specified, otherwise it is the
-  // minimum between optional THREAD_LIMIT and NUM_THREADS clauses.
-  auto InitMaxThreadsClause = [&Builder](Value *Clause) {
-    if (Clause)
-      Clause = Builder.CreateIntCast(Clause, Builder.getInt32Ty(),
-                                     /*isSigned=*/false);
-    return Clause;
+    return Error::success();
   };
-  auto CombineMaxThreadsClauses = [&Builder](Value *Clause, Value *&Result) {
-    if (Clause)
-      Result = Result
-                   ? Builder.CreateSelect(Builder.CreateICmpULT(Result, Clause),
+
+  auto &&EmitTargetCallThen =
+      [&](OpenMPIRBuilder::InsertPointTy AllocaIP,
+          OpenMPIRBuilder::InsertPointTy CodeGenIP) -> Error {
+    OpenMPIRBuilder::TargetDataInfo Info(
+        /*RequiresDevicePointerInfo=*/false,
+        /*SeparateBeginEndCalls=*/true);
+
+    OpenMPIRBuilder::MapInfosTy &MapInfo = GenMapInfoCB(Builder.saveIP());
+    OpenMPIRBuilder::TargetDataRTArgs RTArgs;
+    OMPBuilder.emitOffloadingArraysAndArgs(AllocaIP, Builder.saveIP(), Info,
+                                           RTArgs, MapInfo,
+                                           /*IsNonContiguous=*/true,
+                                           /*ForEndCall=*/false);
+
+    SmallVector<Value *, 3> NumTeamsC;
+    for (auto [DefaultVal, RuntimeVal] :
+         zip_equal(DefaultAttrs.MaxTeams, RuntimeAttrs.MaxTeams))
+      NumTeamsC.push_back(RuntimeVal ? RuntimeVal
+                                     : Builder.getInt32(DefaultVal));
+
+    // Calculate number of threads: 0 if no clauses specified, otherwise it is
+    // the minimum between optional THREAD_LIMIT and NUM_THREADS clauses.
+    auto InitMaxThreadsClause = [&Builder](Value *Clause) {
+      if (Clause)
+        Clause = Builder.CreateIntCast(Clause, Builder.getInt32Ty(),
+                                       /*isSigned=*/false);
+      return Clause;
+    };
+    auto CombineMaxThreadsClauses = [&Builder](Value *Clause, Value *&Result) {
+      if (Clause)
+        Result =
+            Result ? Builder.CreateSelect(Builder.CreateICmpULT(Result, Clause),
                                           Result, Clause)
                    : Clause;
-  };
+    };
 
-  // If a multi-dimensional THREAD_LIMIT is set, it is the OMPX_BARE case, so
-  // the NUM_THREADS clause is overriden by THREAD_LIMIT.
-  SmallVector<Value *, 3> NumThreadsC;
-  Value *MaxThreadsClause = RuntimeAttrs.TeamsThreadLimit.size() == 1
-                                ? InitMaxThreadsClause(RuntimeAttrs.MaxThreads)
-                                : nullptr;
+    // If a multi-dimensional THREAD_LIMIT is set, it is the OMPX_BARE case, so
+    // the NUM_THREADS clause is overriden by THREAD_LIMIT.
+    SmallVector<Value *, 3> NumThreadsC;
+    Value *MaxThreadsClause =
+        RuntimeAttrs.TeamsThreadLimit.size() == 1
+            ? InitMaxThreadsClause(RuntimeAttrs.MaxThreads)
+            : nullptr;
 
-  for (auto [TeamsVal, TargetVal] : zip_equal(RuntimeAttrs.TeamsThreadLimit,
-                                              RuntimeAttrs.TargetThreadLimit)) {
-    Value *TeamsThreadLimitClause = InitMaxThreadsClause(TeamsVal);
-    Value *NumThreads = InitMaxThreadsClause(TargetVal);
+    for (auto [TeamsVal, TargetVal] : zip_equal(
+             RuntimeAttrs.TeamsThreadLimit, RuntimeAttrs.TargetThreadLimit)) {
+      Value *TeamsThreadLimitClause = InitMaxThreadsClause(TeamsVal);
+      Value *NumThreads = InitMaxThreadsClause(TargetVal);
 
-    CombineMaxThreadsClauses(TeamsThreadLimitClause, NumThreads);
-    CombineMaxThreadsClauses(MaxThreadsClause, NumThreads);
+      CombineMaxThreadsClauses(TeamsThreadLimitClause, NumThreads);
+      CombineMaxThreadsClauses(MaxThreadsClause, NumThreads);
 
-    NumThreadsC.push_back(NumThreads ? NumThreads : Builder.getInt32(0));
-  }
+      NumThreadsC.push_back(NumThreads ? NumThreads : Builder.getInt32(0));
+    }
 
-  unsigned NumTargetItems = Info.NumberOfPtrs;
-  // TODO: Use correct device ID
-  Value *DeviceID = Builder.getInt64(OMP_DEVICEID_UNDEF);
-  uint32_t SrcLocStrSize;
-  Constant *SrcLocStr = OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize);
-  Value *RTLoc = OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize,
-                                             llvm::omp::IdentFlag(0), 0);
+    unsigned NumTargetItems = Info.NumberOfPtrs;
+    // TODO: Use correct device ID
+    Value *DeviceID = Builder.getInt64(OMP_DEVICEID_UNDEF);
+    uint32_t SrcLocStrSize;
+    Constant *SrcLocStr = OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize);
+    Value *RTLoc = OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize,
+                                               llvm::omp::IdentFlag(0), 0);
 
-  Value *TripCount = RuntimeAttrs.LoopTripCount
-                         ? Builder.CreateIntCast(RuntimeAttrs.LoopTripCount,
-                                                 Builder.getInt64Ty(),
-                                                 /*isSigned=*/false)
-                         : Builder.getInt64(0);
+    Value *TripCount = RuntimeAttrs.LoopTripCount
+                           ? Builder.CreateIntCast(RuntimeAttrs.LoopTripCount,
+                                                   Builder.getInt64Ty(),
+                                                   /*isSigned=*/false)
+                           : Builder.getInt64(0);
 
-  // TODO: Use correct DynCGGroupMem
-  Value *DynCGGroupMem = Builder.getInt32(0);
+    // TODO: Use correct DynCGGroupMem
+    Value *DynCGGroupMem = Builder.getInt32(0);
 
-  KArgs = OpenMPIRBuilder::TargetKernelArgs(NumTargetItems, RTArgs, TripCount,
-                                            NumTeamsC, NumThreadsC,
-                                            DynCGGroupMem, HasNoWait);
+    KArgs = OpenMPIRBuilder::TargetKernelArgs(NumTargetItems, RTArgs, TripCount,
+                                              NumTeamsC, NumThreadsC,
+                                              DynCGGroupMem, HasNoWait);
 
-  // Assume no error was returned because TaskBodyCB and
-  // EmitTargetCallFallbackCB don't produce any.
-  OpenMPIRBuilder::InsertPointTy AfterIP = cantFail([&]() {
-    // The presence of certain clauses on the target directive require the
-    // explicit generation of the target task.
-    if (RequiresOuterTargetTask)
-      return OMPBuilder.emitTargetTask(TaskBodyCB, DeviceID, RTLoc, AllocaIP,
-                                       Dependencies, HasNoWait);
+    // Assume no error was returned because TaskBodyCB and
+    // EmitTargetCallFallbackCB don't produce any.
+    OpenMPIRBuilder::InsertPointTy AfterIP = cantFail([&]() {
+      // The presence of certain clauses on the target directive require the
+      // explicit generation of the target task.
+      if (RequiresOuterTargetTask)
+        return OMPBuilder.emitTargetTask(TaskBodyCB, DeviceID, RTLoc, AllocaIP,
+                                         Dependencies, HasNoWait);
+
+      return OMPBuilder.emitKernelLaunch(Builder, OutlinedFnID,
+                                         EmitTargetCallFallbackCB, KArgs,
+                                         DeviceID, RTLoc, AllocaIP);
+    }());
+
+    Builder.restoreIP(AfterIP);
+    return Error::success();
+  };
 
-    return OMPBuilder.emitKernelLaunch(Builder, OutlinedFnID,
-                                       EmitTargetCallFallbackCB, KArgs,
-                                       DeviceID, RTLoc, AllocaIP);
-  }());
+  // If we don't have an ID for the target region, it means an offload entry
+  // wasn't created. In this case we just run the host fallback directly and
+  // ignore any potential 'if' clauses.
+  if (!OutlinedFnID) {
+    cantFail(EmitTargetCallElse(AllocaIP, Builder.saveIP()));
+    return;
+  }
+
+  // If there's no 'if' clause, only generate the kernel launch code path.
+  if (!IfCond) {
+    cantFail(EmitTargetCallThen(AllocaIP, Builder.saveIP()));
+    return;
+  }
 
-  Builder.restoreIP(AfterIP);
+  cantFail(OMPBuilder.emitIfClause(IfCond, EmitTargetCallThen,
+                                   EmitTargetCallElse, AllocaIP));
 }
 
 OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
     const LocationDescription &Loc, bool IsOffloadEntry, InsertPointTy AllocaIP,
     InsertPointTy CodeGenIP, TargetRegionEntryInfo &EntryInfo,
     const TargetKernelDefaultAttrs &DefaultAttrs,
-    const TargetKernelRuntimeAttrs &RuntimeAttrs,
+    const TargetKernelRuntimeAttrs &RuntimeAttrs, Value *IfCond,
     SmallVectorImpl<Value *> &Args, GenMapInfoCallbackTy GenMapInfoCB,
     OpenMPIRBuilder::TargetBodyGenCallbackTy CBFunc,
     OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
@@ -7524,7 +7548,7 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
   // to make a remote call (offload) to the previously outlined function
   // that represents the target region. Do that now.
   if (!Config.isTargetDevice())
-    emitTargetCall(*this, Builder, AllocaIP, DefaultAttrs, RuntimeAttrs,
+    emitTargetCall(*this, Builder, AllocaIP, DefaultAttrs, RuntimeAttrs, IfCond,
                    OutlinedFn, OutlinedFnID, Args, GenMapInfoCB, Dependencies,
                    HasNowait);
   return Builder.saveIP();

diff  --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
index 3b571cce09a4f8..a7b513bdfdc667 100644
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -6243,8 +6243,8 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) {
       OpenMPIRBuilder::InsertPointTy, AfterIP,
       OMPBuilder.createTarget(OmpLoc, /*IsOffloadEntry=*/true, Builder.saveIP(),
                               Builder.saveIP(), EntryInfo, DefaultAttrs,
-                              RuntimeAttrs, Inputs, GenMapInfoCB, BodyGenCB,
-                              SimpleArgAccessorCB));
+                              RuntimeAttrs, /*IfCond=*/nullptr, Inputs,
+                              GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB));
   Builder.restoreIP(AfterIP);
 
   OMPBuilder.finalize();
@@ -6402,11 +6402,12 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) {
       /*ExecFlags=*/omp::OMPTgtExecModeFlags::OMP_TGT_EXEC_MODE_GENERIC,
       /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};
 
-  ASSERT_EXPECTED_INIT(OpenMPIRBuilder::InsertPointTy, AfterIP,
-                       OMPBuilder.createTarget(
-                           Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP,
-                           EntryInfo, DefaultAttrs, RuntimeAttrs, CapturedArgs,
-                           GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB));
+  ASSERT_EXPECTED_INIT(
+      OpenMPIRBuilder::InsertPointTy, AfterIP,
+      OMPBuilder.createTarget(Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP,
+                              EntryInfo, DefaultAttrs, RuntimeAttrs,
+                              /*IfCond=*/nullptr, CapturedArgs, GenMapInfoCB,
+                              BodyGenCB, SimpleArgAccessorCB));
   Builder.restoreIP(AfterIP);
 
   Builder.CreateRetVoid();
@@ -6561,8 +6562,8 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionSPMD) {
       OpenMPIRBuilder::InsertPointTy, AfterIP,
       OMPBuilder.createTarget(OmpLoc, /*IsOffloadEntry=*/true, Builder.saveIP(),
                               Builder.saveIP(), EntryInfo, DefaultAttrs,
-                              RuntimeAttrs, Inputs, GenMapInfoCB, BodyGenCB,
-                              SimpleArgAccessorCB));
+                              RuntimeAttrs, /*IfCond=*/nullptr, Inputs,
+                              GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB));
   Builder.restoreIP(AfterIP);
 
   OMPBuilder.finalize();
@@ -6660,11 +6661,12 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDeviceSPMD) {
       /*ExecFlags=*/omp::OMPTgtExecModeFlags::OMP_TGT_EXEC_MODE_SPMD,
       /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};
 
-  ASSERT_EXPECTED_INIT(OpenMPIRBuilder::InsertPointTy, AfterIP,
-                       OMPBuilder.createTarget(
-                           Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP,
-                           EntryInfo, DefaultAttrs, RuntimeAttrs, CapturedArgs,
-                           GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB));
+  ASSERT_EXPECTED_INIT(
+      OpenMPIRBuilder::InsertPointTy, AfterIP,
+      OMPBuilder.createTarget(Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP,
+                              EntryInfo, DefaultAttrs, RuntimeAttrs,
+                              /*IfCond=*/nullptr, CapturedArgs, GenMapInfoCB,
+                              BodyGenCB, SimpleArgAccessorCB));
   Builder.restoreIP(AfterIP);
 
   Builder.CreateRetVoid();
@@ -6774,11 +6776,12 @@ TEST_F(OpenMPIRBuilderTest, ConstantAllocaRaise) {
       /*ExecFlags=*/omp::OMPTgtExecModeFlags::OMP_TGT_EXEC_MODE_GENERIC,
       /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};
 
-  ASSERT_EXPECTED_INIT(OpenMPIRBuilder::InsertPointTy, AfterIP,
-                       OMPBuilder.createTarget(
-                           Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP,
-                           EntryInfo, DefaultAttrs, RuntimeAttrs, CapturedArgs,
-                           GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB));
+  ASSERT_EXPECTED_INIT(
+      OpenMPIRBuilder::InsertPointTy, AfterIP,
+      OMPBuilder.createTarget(Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP,
+                              EntryInfo, DefaultAttrs, RuntimeAttrs,
+                              /*IfCond=*/nullptr, CapturedArgs, GenMapInfoCB,
+                              BodyGenCB, SimpleArgAccessorCB));
   Builder.restoreIP(AfterIP);
 
   Builder.CreateRetVoid();

diff  --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 0be515e63b470c..abef2cb7411aaf 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -183,10 +183,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
           result = op.emitError("not yet implemented: host evaluation of loop "
                                 "bounds in omp.target operation");
   };
-  auto checkIf = [&todo](auto op, LogicalResult &result) {
-    if (op.getIfExpr())
-      result = todo("if");
-  };
   auto checkInReduction = [&todo](auto op, LogicalResult &result) {
     if (!op.getInReductionVars().empty() || op.getInReductionByref() ||
         op.getInReductionSyms())
@@ -306,7 +302,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
         checkDevice(op, result);
         checkHasDeviceAddr(op, result);
         checkHostEval(op, result);
-        checkIf(op, result);
         checkInReduction(op, result);
         checkIsDevicePtr(op, result);
         checkPrivate(op, result);
@@ -4378,10 +4373,14 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
       findAllocaInsertPoint(builder, moduleTranslation);
   llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
 
+  llvm::Value *ifCond = nullptr;
+  if (Value targetIfCond = targetOp.getIfExpr())
+    ifCond = moduleTranslation.lookupValue(targetIfCond);
+
   llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
       moduleTranslation.getOpenMPBuilder()->createTarget(
           ompLoc, isOffloadEntry, allocaIP, builder.saveIP(), entryInfo,
-          defaultAttrs, runtimeAttrs, kernelInput, genMapInfoCB, bodyCB,
+          defaultAttrs, runtimeAttrs, ifCond, kernelInput, genMapInfoCB, bodyCB,
           argAccessorCB, dds, targetOp.getNowait());
 
   if (failed(handleError(afterIP, opInst)))

diff  --git a/mlir/test/Target/LLVMIR/omptarget-if.mlir b/mlir/test/Target/LLVMIR/omptarget-if.mlir
new file mode 100644
index 00000000000000..706ad4411438ba
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/omptarget-if.mlir
@@ -0,0 +1,68 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+module attributes {omp.is_target_device = false, omp.target_triples = ["amdgcn-amd-amdhsa"]} {
+  llvm.func @target_if_variable(%x : i1) {
+    omp.target if(%x) {
+      omp.terminator
+    }
+    llvm.return
+  }
+
+  // CHECK-LABEL: define void @target_if_variable(
+  // CHECK-SAME: i1 %[[IF_COND:.*]])
+  // CHECK: br i1 %[[IF_COND]], label %[[THEN_LABEL:.*]], label %[[ELSE_LABEL:.*]]
+
+  // CHECK: [[THEN_LABEL]]:
+  // CHECK-NOT: {{^.*}}:
+  // CHECK: %[[RC:.*]] = call i32 @__tgt_target_kernel
+  // CHECK-NEXT: %[[OFFLOAD_SUCCESS:.*]] = icmp ne i32 %[[RC]], 0
+  // CHECK-NEXT: br i1 %[[OFFLOAD_SUCCESS]], label %[[OFFLOAD_FAIL_LABEL:.*]], label %[[OFFLOAD_CONT_LABEL:.*]]
+
+  // CHECK: [[OFFLOAD_FAIL_LABEL]]:
+  // CHECK-NEXT: call void @[[FALLBACK_FN:__omp_offloading_.*_.*_target_if_variable_l.*]]()
+  // CHECK-NEXT: br label %[[OFFLOAD_CONT_LABEL]]
+
+  // CHECK: [[OFFLOAD_CONT_LABEL]]:
+  // CHECK-NEXT: br label %[[END_LABEL:.*]]
+
+  // CHECK: [[ELSE_LABEL]]:
+  // CHECK-NEXT: call void @[[FALLBACK_FN]]()
+  // CHECK-NEXT: br label %[[END_LABEL]]
+
+  llvm.func @target_if_true() {
+    %0 = llvm.mlir.constant(true) : i1
+    omp.target if(%0) {
+      omp.terminator
+    }
+    llvm.return
+  }
+
+  // CHECK-LABEL: define void @target_if_true()
+  // CHECK-NOT: {{^.*}}:
+  // CHECK: br label %[[ENTRY:.*]]
+
+  // CHECK: [[ENTRY]]:
+  // CHECK-NOT: {{^.*}}:
+  // CHECK: %[[RC:.*]] = call i32 @__tgt_target_kernel
+  // CHECK-NEXT: %[[OFFLOAD_SUCCESS:.*]] = icmp ne i32 %[[RC]], 0
+  // CHECK-NEXT: br i1 %[[OFFLOAD_SUCCESS]], label %[[OFFLOAD_FAIL_LABEL:.*]], label %[[OFFLOAD_CONT_LABEL:.*]]
+
+  // CHECK: [[OFFLOAD_FAIL_LABEL]]:
+  // CHECK-NEXT: call void @[[FALLBACK_FN:.*]]()
+  // CHECK-NEXT: br label %[[OFFLOAD_CONT_LABEL]]
+
+  llvm.func @target_if_false() {
+    %0 = llvm.mlir.constant(false) : i1
+    omp.target if(%0) {
+      omp.terminator
+    }
+    llvm.return
+  }
+
+  // CHECK-LABEL: define void @target_if_false()
+  // CHECK-NEXT: br label %[[ENTRY:.*]]
+
+  // CHECK: [[ENTRY]]:
+  // CHECK-NEXT: call void @__omp_offloading_{{.*}}_{{.*}}_target_if_false_l{{.*}}()
+}
+

diff  --git a/mlir/test/Target/LLVMIR/openmp-todo.mlir b/mlir/test/Target/LLVMIR/openmp-todo.mlir
index 392a6558dcfa69..c1e30964b25078 100644
--- a/mlir/test/Target/LLVMIR/openmp-todo.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-todo.mlir
@@ -271,17 +271,6 @@ llvm.func @target_host_eval(%x : i32) {
 
 // -----
 
-llvm.func @target_if(%x : i1) {
-  // expected-error@below {{not yet implemented: Unhandled clause if in omp.target operation}}
-  // expected-error@below {{LLVM Translation failed for operation: omp.target}}
-  omp.target if(%x) {
-    omp.terminator
-  }
-  llvm.return
-}
-
-// -----
-
 omp.declare_reduction @add_f32 : f32
 init {
 ^bb0(%arg: f32):


        


More information about the llvm-commits mailing list