[llvm] [mlir] [OMPIRBuilder][MLIR] Add support for target 'if' clause (PR #122478)

Sergio Afonso via llvm-commits llvm-commits at lists.llvm.org
Tue Jan 14 08:12:46 PST 2025


https://github.com/skatrak updated https://github.com/llvm/llvm-project/pull/122478

>From cb022ad010a1ef3e1faa418cec8b5c18c9ff067f Mon Sep 17 00:00:00 2001
From: Sergio Afonso <safonsof at amd.com>
Date: Fri, 10 Jan 2025 15:40:05 +0000
Subject: [PATCH] [OMPIRBuilder][MLIR] Add support for target 'if' clause

This patch implements support for handling the 'if' clause of OpenMP 'target'
constructs in the OMPIRBuilder and updates MLIR to LLVM IR translation of the
`omp.target` MLIR operation to make use of this new feature.
---
 .../llvm/Frontend/OpenMP/OMPIRBuilder.h       |  14 +-
 llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp     | 210 ++++++++++--------
 .../Frontend/OpenMPIRBuilderTest.cpp          |  26 ++-
 .../OpenMP/OpenMPToLLVMIRTranslation.cpp      |  11 +-
 mlir/test/Target/LLVMIR/omptarget-if.mlir     |  68 ++++++
 mlir/test/Target/LLVMIR/openmp-todo.mlir      |  11 -
 6 files changed, 212 insertions(+), 128 deletions(-)
 create mode 100644 mlir/test/Target/LLVMIR/omptarget-if.mlir

diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index 7eceec3d8cf8f5..6b6e5bc19d95a4 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -2994,27 +2994,29 @@ class OpenMPIRBuilder {
   /// \param Loc where the target data construct was encountered.
   /// \param IsOffloadEntry whether it is an offload entry.
   /// \param CodeGenIP The insertion point where the call to the outlined
-  /// function should be emitted.
+  ///        function should be emitted.
   /// \param EntryInfo The entry information about the function.
   /// \param DefaultAttrs Structure containing the default attributes, including
   ///        numbers of threads and teams to launch the kernel with.
   /// \param RuntimeAttrs Structure containing the runtime numbers of threads
   ///        and teams to launch the kernel with.
+  /// \param IfCond Value of the `if` clause, or nullptr if there is none.
   /// \param Inputs The input values to the region that will be passed.
-  /// as arguments to the outlined function.
+  ///        as arguments to the outlined function.
   /// \param BodyGenCB Callback that will generate the region code.
   /// \param ArgAccessorFuncCB Callback that will generate accessors
-  /// instructions for passed in target arguments where neccessary
+  ///        instructions for passed in target arguments where necessary
   /// \param Dependencies A vector of DependData objects that carry
-  // dependency information as passed in the depend clause
-  // \param HasNowait Whether the target construct has a `nowait` clause or not.
+  ///        dependency information as passed in the depend clause
+  /// \param HasNowait Whether the target construct has a `nowait` clause or
+  ///        not.
   InsertPointOrErrorTy createTarget(
       const LocationDescription &Loc, bool IsOffloadEntry,
       OpenMPIRBuilder::InsertPointTy AllocaIP,
       OpenMPIRBuilder::InsertPointTy CodeGenIP,
       TargetRegionEntryInfo &EntryInfo,
       const TargetKernelDefaultAttrs &DefaultAttrs,
-      const TargetKernelRuntimeAttrs &RuntimeAttrs,
+      const TargetKernelRuntimeAttrs &RuntimeAttrs, Value *IfCond,
       SmallVectorImpl<Value *> &Inputs, GenMapInfoCallbackTy GenMapInfoCB,
       TargetBodyGenCallbackTy BodyGenCB,
       TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 3d461f0ad4228c..c6603635d5e281 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -5308,8 +5308,8 @@ void OpenMPIRBuilder::applySimd(CanonicalLoopInfo *CanonicalLoop,
       Value *Alignment = AlignedItem.second;
       Instruction *loadInst = dyn_cast<Instruction>(AlignedPtr);
       Builder.SetInsertPoint(loadInst->getNextNode());
-      Builder.CreateAlignmentAssumption(F->getDataLayout(),
-                                        AlignedPtr, Alignment);
+      Builder.CreateAlignmentAssumption(F->getDataLayout(), AlignedPtr,
+                                        Alignment);
     }
     Builder.restoreIP(IP);
   }
@@ -5457,16 +5457,16 @@ static int32_t computeHeuristicUnrollFactor(CanonicalLoopInfo *CLI) {
   Loop *L = LI.getLoopFor(CLI->getHeader());
   assert(L && "Expecting CanonicalLoopInfo to be recognized as a loop");
 
-  TargetTransformInfo::UnrollingPreferences UP =
-      gatherUnrollingPreferences(L, SE, TTI,
-                                 /*BlockFrequencyInfo=*/nullptr,
-                                 /*ProfileSummaryInfo=*/nullptr, ORE, static_cast<int>(OptLevel),
-                                 /*UserThreshold=*/std::nullopt,
-                                 /*UserCount=*/std::nullopt,
-                                 /*UserAllowPartial=*/true,
-                                 /*UserAllowRuntime=*/true,
-                                 /*UserUpperBound=*/std::nullopt,
-                                 /*UserFullUnrollMaxCount=*/std::nullopt);
+  TargetTransformInfo::UnrollingPreferences UP = gatherUnrollingPreferences(
+      L, SE, TTI,
+      /*BlockFrequencyInfo=*/nullptr,
+      /*ProfileSummaryInfo=*/nullptr, ORE, static_cast<int>(OptLevel),
+      /*UserThreshold=*/std::nullopt,
+      /*UserCount=*/std::nullopt,
+      /*UserAllowPartial=*/true,
+      /*UserAllowRuntime=*/true,
+      /*UserUpperBound=*/std::nullopt,
+      /*UserFullUnrollMaxCount=*/std::nullopt);
 
   UP.Force = true;
 
@@ -7340,7 +7340,7 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
                OpenMPIRBuilder::InsertPointTy AllocaIP,
                const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
                const OpenMPIRBuilder::TargetKernelRuntimeAttrs &RuntimeAttrs,
-               Function *OutlinedFn, Constant *OutlinedFnID,
+               Value *IfCond, Function *OutlinedFn, Constant *OutlinedFnID,
                SmallVectorImpl<Value *> &Args,
                OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB,
                SmallVector<llvm::OpenMPIRBuilder::DependData> Dependencies = {},
@@ -7386,9 +7386,9 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
     return Error::success();
   };
 
-  // If we don't have an ID for the target region, it means an offload entry
-  // wasn't created. In this case we just run the host fallback directly.
-  if (!OutlinedFnID) {
+  auto &&EmitTargetCallElse =
+      [&](OpenMPIRBuilder::InsertPointTy AllocaIP,
+          OpenMPIRBuilder::InsertPointTy CodeGenIP) -> Error {
     // Assume no error was returned because EmitTargetCallFallbackCB doesn't
     // produce any.
     OpenMPIRBuilder::InsertPointTy AfterIP = cantFail([&]() {
@@ -7404,102 +7404,126 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
     }());
 
     Builder.restoreIP(AfterIP);
-    return;
-  }
-
-  OpenMPIRBuilder::TargetDataInfo Info(
-      /*RequiresDevicePointerInfo=*/false,
-      /*SeparateBeginEndCalls=*/true);
-
-  OpenMPIRBuilder::MapInfosTy &MapInfo = GenMapInfoCB(Builder.saveIP());
-  OpenMPIRBuilder::TargetDataRTArgs RTArgs;
-  OMPBuilder.emitOffloadingArraysAndArgs(AllocaIP, Builder.saveIP(), Info,
-                                         RTArgs, MapInfo,
-                                         /*IsNonContiguous=*/true,
-                                         /*ForEndCall=*/false);
-
-  SmallVector<Value *, 3> NumTeamsC;
-  for (auto [DefaultVal, RuntimeVal] :
-       zip_equal(DefaultAttrs.MaxTeams, RuntimeAttrs.MaxTeams))
-    NumTeamsC.push_back(RuntimeVal ? RuntimeVal : Builder.getInt32(DefaultVal));
-
-  // Calculate number of threads: 0 if no clauses specified, otherwise it is the
-  // minimum between optional THREAD_LIMIT and NUM_THREADS clauses.
-  auto InitMaxThreadsClause = [&Builder](Value *Clause) {
-    if (Clause)
-      Clause = Builder.CreateIntCast(Clause, Builder.getInt32Ty(),
-                                     /*isSigned=*/false);
-    return Clause;
+    return Error::success();
   };
-  auto CombineMaxThreadsClauses = [&Builder](Value *Clause, Value *&Result) {
-    if (Clause)
-      Result = Result
-                   ? Builder.CreateSelect(Builder.CreateICmpULT(Result, Clause),
+
+  auto &&EmitTargetCallThen =
+      [&](OpenMPIRBuilder::InsertPointTy AllocaIP,
+          OpenMPIRBuilder::InsertPointTy CodeGenIP) -> Error {
+    OpenMPIRBuilder::TargetDataInfo Info(
+        /*RequiresDevicePointerInfo=*/false,
+        /*SeparateBeginEndCalls=*/true);
+
+    OpenMPIRBuilder::MapInfosTy &MapInfo = GenMapInfoCB(Builder.saveIP());
+    OpenMPIRBuilder::TargetDataRTArgs RTArgs;
+    OMPBuilder.emitOffloadingArraysAndArgs(AllocaIP, Builder.saveIP(), Info,
+                                           RTArgs, MapInfo,
+                                           /*IsNonContiguous=*/true,
+                                           /*ForEndCall=*/false);
+
+    SmallVector<Value *, 3> NumTeamsC;
+    for (auto [DefaultVal, RuntimeVal] :
+         zip_equal(DefaultAttrs.MaxTeams, RuntimeAttrs.MaxTeams))
+      NumTeamsC.push_back(RuntimeVal ? RuntimeVal
+                                     : Builder.getInt32(DefaultVal));
+
+    // Calculate number of threads: 0 if no clauses specified, otherwise it is
+    // the minimum between optional THREAD_LIMIT and NUM_THREADS clauses.
+    auto InitMaxThreadsClause = [&Builder](Value *Clause) {
+      if (Clause)
+        Clause = Builder.CreateIntCast(Clause, Builder.getInt32Ty(),
+                                       /*isSigned=*/false);
+      return Clause;
+    };
+    auto CombineMaxThreadsClauses = [&Builder](Value *Clause, Value *&Result) {
+      if (Clause)
+        Result =
+            Result ? Builder.CreateSelect(Builder.CreateICmpULT(Result, Clause),
                                           Result, Clause)
                    : Clause;
-  };
+    };
 
-  // If a multi-dimensional THREAD_LIMIT is set, it is the OMPX_BARE case, so
-  // the NUM_THREADS clause is overriden by THREAD_LIMIT.
-  SmallVector<Value *, 3> NumThreadsC;
-  Value *MaxThreadsClause = RuntimeAttrs.TeamsThreadLimit.size() == 1
-                                ? InitMaxThreadsClause(RuntimeAttrs.MaxThreads)
-                                : nullptr;
+    // If a multi-dimensional THREAD_LIMIT is set, it is the OMPX_BARE case, so
+    // the NUM_THREADS clause is overridden by THREAD_LIMIT.
+    SmallVector<Value *, 3> NumThreadsC;
+    Value *MaxThreadsClause =
+        RuntimeAttrs.TeamsThreadLimit.size() == 1
+            ? InitMaxThreadsClause(RuntimeAttrs.MaxThreads)
+            : nullptr;
 
-  for (auto [TeamsVal, TargetVal] : zip_equal(RuntimeAttrs.TeamsThreadLimit,
-                                              RuntimeAttrs.TargetThreadLimit)) {
-    Value *TeamsThreadLimitClause = InitMaxThreadsClause(TeamsVal);
-    Value *NumThreads = InitMaxThreadsClause(TargetVal);
+    for (auto [TeamsVal, TargetVal] : zip_equal(
+             RuntimeAttrs.TeamsThreadLimit, RuntimeAttrs.TargetThreadLimit)) {
+      Value *TeamsThreadLimitClause = InitMaxThreadsClause(TeamsVal);
+      Value *NumThreads = InitMaxThreadsClause(TargetVal);
 
-    CombineMaxThreadsClauses(TeamsThreadLimitClause, NumThreads);
-    CombineMaxThreadsClauses(MaxThreadsClause, NumThreads);
+      CombineMaxThreadsClauses(TeamsThreadLimitClause, NumThreads);
+      CombineMaxThreadsClauses(MaxThreadsClause, NumThreads);
 
-    NumThreadsC.push_back(NumThreads ? NumThreads : Builder.getInt32(0));
-  }
+      NumThreadsC.push_back(NumThreads ? NumThreads : Builder.getInt32(0));
+    }
 
-  unsigned NumTargetItems = Info.NumberOfPtrs;
-  // TODO: Use correct device ID
-  Value *DeviceID = Builder.getInt64(OMP_DEVICEID_UNDEF);
-  uint32_t SrcLocStrSize;
-  Constant *SrcLocStr = OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize);
-  Value *RTLoc = OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize,
-                                             llvm::omp::IdentFlag(0), 0);
+    unsigned NumTargetItems = Info.NumberOfPtrs;
+    // TODO: Use correct device ID
+    Value *DeviceID = Builder.getInt64(OMP_DEVICEID_UNDEF);
+    uint32_t SrcLocStrSize;
+    Constant *SrcLocStr = OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize);
+    Value *RTLoc = OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize,
+                                               llvm::omp::IdentFlag(0), 0);
 
-  Value *TripCount = RuntimeAttrs.LoopTripCount
-                         ? Builder.CreateIntCast(RuntimeAttrs.LoopTripCount,
-                                                 Builder.getInt64Ty(),
-                                                 /*isSigned=*/false)
-                         : Builder.getInt64(0);
+    Value *TripCount = RuntimeAttrs.LoopTripCount
+                           ? Builder.CreateIntCast(RuntimeAttrs.LoopTripCount,
+                                                   Builder.getInt64Ty(),
+                                                   /*isSigned=*/false)
+                           : Builder.getInt64(0);
 
-  // TODO: Use correct DynCGGroupMem
-  Value *DynCGGroupMem = Builder.getInt32(0);
+    // TODO: Use correct DynCGGroupMem
+    Value *DynCGGroupMem = Builder.getInt32(0);
 
-  KArgs = OpenMPIRBuilder::TargetKernelArgs(NumTargetItems, RTArgs, TripCount,
-                                            NumTeamsC, NumThreadsC,
-                                            DynCGGroupMem, HasNoWait);
+    KArgs = OpenMPIRBuilder::TargetKernelArgs(NumTargetItems, RTArgs, TripCount,
+                                              NumTeamsC, NumThreadsC,
+                                              DynCGGroupMem, HasNoWait);
 
-  // Assume no error was returned because TaskBodyCB and
-  // EmitTargetCallFallbackCB don't produce any.
-  OpenMPIRBuilder::InsertPointTy AfterIP = cantFail([&]() {
-    // The presence of certain clauses on the target directive require the
-    // explicit generation of the target task.
-    if (RequiresOuterTargetTask)
-      return OMPBuilder.emitTargetTask(TaskBodyCB, DeviceID, RTLoc, AllocaIP,
-                                       Dependencies, HasNoWait);
+    // Assume no error was returned because TaskBodyCB and
+    // EmitTargetCallFallbackCB don't produce any.
+    OpenMPIRBuilder::InsertPointTy AfterIP = cantFail([&]() {
+      // The presence of certain clauses on the target directive requires the
+      // explicit generation of the target task.
+      if (RequiresOuterTargetTask)
+        return OMPBuilder.emitTargetTask(TaskBodyCB, DeviceID, RTLoc, AllocaIP,
+                                         Dependencies, HasNoWait);
+
+      return OMPBuilder.emitKernelLaunch(Builder, OutlinedFnID,
+                                         EmitTargetCallFallbackCB, KArgs,
+                                         DeviceID, RTLoc, AllocaIP);
+    }());
+
+    Builder.restoreIP(AfterIP);
+    return Error::success();
+  };
 
-    return OMPBuilder.emitKernelLaunch(Builder, OutlinedFnID,
-                                       EmitTargetCallFallbackCB, KArgs,
-                                       DeviceID, RTLoc, AllocaIP);
-  }());
+  // If we don't have an ID for the target region, it means an offload entry
+  // wasn't created. In this case we just run the host fallback directly and
+  // ignore any potential 'if' clauses.
+  if (!OutlinedFnID) {
+    cantFail(EmitTargetCallElse(AllocaIP, Builder.saveIP()));
+    return;
+  }
+
+  // If there's no 'if' clause, only generate the kernel launch code path.
+  if (!IfCond) {
+    cantFail(EmitTargetCallThen(AllocaIP, Builder.saveIP()));
+    return;
+  }
 
-  Builder.restoreIP(AfterIP);
+  cantFail(OMPBuilder.emitIfClause(IfCond, EmitTargetCallThen,
+                                   EmitTargetCallElse, AllocaIP));
 }
 
 OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
     const LocationDescription &Loc, bool IsOffloadEntry, InsertPointTy AllocaIP,
     InsertPointTy CodeGenIP, TargetRegionEntryInfo &EntryInfo,
     const TargetKernelDefaultAttrs &DefaultAttrs,
-    const TargetKernelRuntimeAttrs &RuntimeAttrs,
+    const TargetKernelRuntimeAttrs &RuntimeAttrs, Value *IfCond,
     SmallVectorImpl<Value *> &Args, GenMapInfoCallbackTy GenMapInfoCB,
     OpenMPIRBuilder::TargetBodyGenCallbackTy CBFunc,
     OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
@@ -7524,7 +7548,7 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
   // to make a remote call (offload) to the previously outlined function
   // that represents the target region. Do that now.
   if (!Config.isTargetDevice())
-    emitTargetCall(*this, Builder, AllocaIP, DefaultAttrs, RuntimeAttrs,
+    emitTargetCall(*this, Builder, AllocaIP, DefaultAttrs, RuntimeAttrs, IfCond,
                    OutlinedFn, OutlinedFnID, Args, GenMapInfoCB, Dependencies,
                    HasNowait);
   return Builder.saveIP();
diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
index 3b571cce09a4f8..684b842bcbf494 100644
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -6243,8 +6243,8 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) {
       OpenMPIRBuilder::InsertPointTy, AfterIP,
       OMPBuilder.createTarget(OmpLoc, /*IsOffloadEntry=*/true, Builder.saveIP(),
                               Builder.saveIP(), EntryInfo, DefaultAttrs,
-                              RuntimeAttrs, Inputs, GenMapInfoCB, BodyGenCB,
-                              SimpleArgAccessorCB));
+                              RuntimeAttrs, /*IfCond=*/nullptr, Inputs,
+                              GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB));
   Builder.restoreIP(AfterIP);
 
   OMPBuilder.finalize();
@@ -6402,11 +6402,12 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) {
       /*ExecFlags=*/omp::OMPTgtExecModeFlags::OMP_TGT_EXEC_MODE_GENERIC,
       /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};
 
-  ASSERT_EXPECTED_INIT(OpenMPIRBuilder::InsertPointTy, AfterIP,
-                       OMPBuilder.createTarget(
-                           Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP,
-                           EntryInfo, DefaultAttrs, RuntimeAttrs, CapturedArgs,
-                           GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB));
+  ASSERT_EXPECTED_INIT(
+      OpenMPIRBuilder::InsertPointTy, AfterIP,
+      OMPBuilder.createTarget(Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP,
+                              EntryInfo, DefaultAttrs, RuntimeAttrs,
+                              /*IfCond=*/nullptr, CapturedArgs, GenMapInfoCB,
+                              BodyGenCB, SimpleArgAccessorCB));
   Builder.restoreIP(AfterIP);
 
   Builder.CreateRetVoid();
@@ -6774,11 +6775,12 @@ TEST_F(OpenMPIRBuilderTest, ConstantAllocaRaise) {
       /*ExecFlags=*/omp::OMPTgtExecModeFlags::OMP_TGT_EXEC_MODE_GENERIC,
       /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};
 
-  ASSERT_EXPECTED_INIT(OpenMPIRBuilder::InsertPointTy, AfterIP,
-                       OMPBuilder.createTarget(
-                           Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP,
-                           EntryInfo, DefaultAttrs, RuntimeAttrs, CapturedArgs,
-                           GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB));
+  ASSERT_EXPECTED_INIT(
+      OpenMPIRBuilder::InsertPointTy, AfterIP,
+      OMPBuilder.createTarget(Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP,
+                              EntryInfo, DefaultAttrs, RuntimeAttrs,
+                              /*IfCond=*/nullptr, CapturedArgs, GenMapInfoCB,
+                              BodyGenCB, SimpleArgAccessorCB));
   Builder.restoreIP(AfterIP);
 
   Builder.CreateRetVoid();
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 0be515e63b470c..abef2cb7411aaf 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -183,10 +183,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
           result = op.emitError("not yet implemented: host evaluation of loop "
                                 "bounds in omp.target operation");
   };
-  auto checkIf = [&todo](auto op, LogicalResult &result) {
-    if (op.getIfExpr())
-      result = todo("if");
-  };
   auto checkInReduction = [&todo](auto op, LogicalResult &result) {
     if (!op.getInReductionVars().empty() || op.getInReductionByref() ||
         op.getInReductionSyms())
@@ -306,7 +302,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
         checkDevice(op, result);
         checkHasDeviceAddr(op, result);
         checkHostEval(op, result);
-        checkIf(op, result);
         checkInReduction(op, result);
         checkIsDevicePtr(op, result);
         checkPrivate(op, result);
@@ -4378,10 +4373,14 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
       findAllocaInsertPoint(builder, moduleTranslation);
   llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
 
+  llvm::Value *ifCond = nullptr;
+  if (Value targetIfCond = targetOp.getIfExpr())
+    ifCond = moduleTranslation.lookupValue(targetIfCond);
+
   llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
       moduleTranslation.getOpenMPBuilder()->createTarget(
           ompLoc, isOffloadEntry, allocaIP, builder.saveIP(), entryInfo,
-          defaultAttrs, runtimeAttrs, kernelInput, genMapInfoCB, bodyCB,
+          defaultAttrs, runtimeAttrs, ifCond, kernelInput, genMapInfoCB, bodyCB,
           argAccessorCB, dds, targetOp.getNowait());
 
   if (failed(handleError(afterIP, opInst)))
diff --git a/mlir/test/Target/LLVMIR/omptarget-if.mlir b/mlir/test/Target/LLVMIR/omptarget-if.mlir
new file mode 100644
index 00000000000000..706ad4411438ba
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/omptarget-if.mlir
@@ -0,0 +1,68 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+module attributes {omp.is_target_device = false, omp.target_triples = ["amdgcn-amd-amdhsa"]} {
+  llvm.func @target_if_variable(%x : i1) {
+    omp.target if(%x) {
+      omp.terminator
+    }
+    llvm.return
+  }
+
+  // CHECK-LABEL: define void @target_if_variable(
+  // CHECK-SAME: i1 %[[IF_COND:.*]])
+  // CHECK: br i1 %[[IF_COND]], label %[[THEN_LABEL:.*]], label %[[ELSE_LABEL:.*]]
+
+  // CHECK: [[THEN_LABEL]]:
+  // CHECK-NOT: {{^.*}}:
+  // CHECK: %[[RC:.*]] = call i32 @__tgt_target_kernel
+  // CHECK-NEXT: %[[OFFLOAD_SUCCESS:.*]] = icmp ne i32 %[[RC]], 0
+  // CHECK-NEXT: br i1 %[[OFFLOAD_SUCCESS]], label %[[OFFLOAD_FAIL_LABEL:.*]], label %[[OFFLOAD_CONT_LABEL:.*]]
+
+  // CHECK: [[OFFLOAD_FAIL_LABEL]]:
+  // CHECK-NEXT: call void @[[FALLBACK_FN:__omp_offloading_.*_.*_target_if_variable_l.*]]()
+  // CHECK-NEXT: br label %[[OFFLOAD_CONT_LABEL]]
+
+  // CHECK: [[OFFLOAD_CONT_LABEL]]:
+  // CHECK-NEXT: br label %[[END_LABEL:.*]]
+
+  // CHECK: [[ELSE_LABEL]]:
+  // CHECK-NEXT: call void @[[FALLBACK_FN]]()
+  // CHECK-NEXT: br label %[[END_LABEL]]
+
+  llvm.func @target_if_true() {
+    %0 = llvm.mlir.constant(true) : i1
+    omp.target if(%0) {
+      omp.terminator
+    }
+    llvm.return
+  }
+
+  // CHECK-LABEL: define void @target_if_true()
+  // CHECK-NOT: {{^.*}}:
+  // CHECK: br label %[[ENTRY:.*]]
+
+  // CHECK: [[ENTRY]]:
+  // CHECK-NOT: {{^.*}}:
+  // CHECK: %[[RC:.*]] = call i32 @__tgt_target_kernel
+  // CHECK-NEXT: %[[OFFLOAD_SUCCESS:.*]] = icmp ne i32 %[[RC]], 0
+  // CHECK-NEXT: br i1 %[[OFFLOAD_SUCCESS]], label %[[OFFLOAD_FAIL_LABEL:.*]], label %[[OFFLOAD_CONT_LABEL:.*]]
+
+  // CHECK: [[OFFLOAD_FAIL_LABEL]]:
+  // CHECK-NEXT: call void @[[FALLBACK_FN:.*]]()
+  // CHECK-NEXT: br label %[[OFFLOAD_CONT_LABEL]]
+
+  llvm.func @target_if_false() {
+    %0 = llvm.mlir.constant(false) : i1
+    omp.target if(%0) {
+      omp.terminator
+    }
+    llvm.return
+  }
+
+  // CHECK-LABEL: define void @target_if_false()
+  // CHECK-NEXT: br label %[[ENTRY:.*]]
+
+  // CHECK: [[ENTRY]]:
+  // CHECK-NEXT: call void @__omp_offloading_{{.*}}_{{.*}}_target_if_false_l{{.*}}()
+}
+
diff --git a/mlir/test/Target/LLVMIR/openmp-todo.mlir b/mlir/test/Target/LLVMIR/openmp-todo.mlir
index 392a6558dcfa69..c1e30964b25078 100644
--- a/mlir/test/Target/LLVMIR/openmp-todo.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-todo.mlir
@@ -271,17 +271,6 @@ llvm.func @target_host_eval(%x : i32) {
 
 // -----
 
-llvm.func @target_if(%x : i1) {
-  // expected-error@below {{not yet implemented: Unhandled clause if in omp.target operation}}
-  // expected-error@below {{LLVM Translation failed for operation: omp.target}}
-  omp.target if(%x) {
-    omp.terminator
-  }
-  llvm.return
-}
-
-// -----
-
 omp.declare_reduction @add_f32 : f32
 init {
 ^bb0(%arg: f32):



More information about the llvm-commits mailing list