[llvm-branch-commits] [llvm] [mlir] [OMPIRBuilder][MLIR] Add support for target 'if' clause (PR #122478)

Sergio Afonso via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Tue Jan 14 07:31:37 PST 2025


https://github.com/skatrak updated https://github.com/llvm/llvm-project/pull/122478

>From 8c348ba2796e08d45fe167d52db0fe047eaafa8a Mon Sep 17 00:00:00 2001
From: Sergio Afonso <safonsof at amd.com>
Date: Fri, 10 Jan 2025 15:40:05 +0000
Subject: [PATCH] [OMPIRBuilder][MLIR] Add support for target 'if' clause

This patch implements support for handling the 'if' clause of OpenMP 'target'
constructs in the OMPIRBuilder and updates MLIR to LLVM IR translation of the
`omp.target` MLIR operation to make use of this new feature.
---
 .../llvm/Frontend/OpenMP/OMPIRBuilder.h       |  14 +-
 llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp     | 188 ++++++++++--------
 .../Frontend/OpenMPIRBuilderTest.cpp          |  26 +--
 .../OpenMP/OpenMPToLLVMIRTranslation.cpp      |  11 +-
 mlir/test/Target/LLVMIR/omptarget-if.mlir     |  68 +++++++
 mlir/test/Target/LLVMIR/openmp-todo.mlir      |  11 -
 6 files changed, 200 insertions(+), 118 deletions(-)
 create mode 100644 mlir/test/Target/LLVMIR/omptarget-if.mlir

diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index 7eceec3d8cf8f5..6b6e5bc19d95a4 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -2994,27 +2994,29 @@ class OpenMPIRBuilder {
   /// \param Loc where the target data construct was encountered.
   /// \param IsOffloadEntry whether it is an offload entry.
   /// \param CodeGenIP The insertion point where the call to the outlined
-  /// function should be emitted.
+  ///        function should be emitted.
   /// \param EntryInfo The entry information about the function.
   /// \param DefaultAttrs Structure containing the default attributes, including
   ///        numbers of threads and teams to launch the kernel with.
   /// \param RuntimeAttrs Structure containing the runtime numbers of threads
   ///        and teams to launch the kernel with.
+  /// \param IfCond Value of the `if` clause, or null if there is none.
   /// \param Inputs The input values to the region that will be passed.
-  /// as arguments to the outlined function.
+  ///        as arguments to the outlined function.
   /// \param BodyGenCB Callback that will generate the region code.
   /// \param ArgAccessorFuncCB Callback that will generate accessors
-  /// instructions for passed in target arguments where neccessary
+  ///        instructions for passed in target arguments where necessary
   /// \param Dependencies A vector of DependData objects that carry
-  // dependency information as passed in the depend clause
-  // \param HasNowait Whether the target construct has a `nowait` clause or not.
+  ///        dependency information as passed in the depend clause
+  /// \param HasNowait Whether the target construct has a `nowait` clause or
+  ///        not.
   InsertPointOrErrorTy createTarget(
       const LocationDescription &Loc, bool IsOffloadEntry,
       OpenMPIRBuilder::InsertPointTy AllocaIP,
       OpenMPIRBuilder::InsertPointTy CodeGenIP,
       TargetRegionEntryInfo &EntryInfo,
       const TargetKernelDefaultAttrs &DefaultAttrs,
-      const TargetKernelRuntimeAttrs &RuntimeAttrs,
+      const TargetKernelRuntimeAttrs &RuntimeAttrs, Value *IfCond,
       SmallVectorImpl<Value *> &Inputs, GenMapInfoCallbackTy GenMapInfoCB,
       TargetBodyGenCallbackTy BodyGenCB,
       TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 3d461f0ad4228c..d29e22c762bd4f 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -7340,7 +7340,7 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
                OpenMPIRBuilder::InsertPointTy AllocaIP,
                const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
                const OpenMPIRBuilder::TargetKernelRuntimeAttrs &RuntimeAttrs,
-               Function *OutlinedFn, Constant *OutlinedFnID,
+               Value *IfCond, Function *OutlinedFn, Constant *OutlinedFnID,
                SmallVectorImpl<Value *> &Args,
                OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB,
                SmallVector<llvm::OpenMPIRBuilder::DependData> Dependencies = {},
@@ -7386,9 +7386,9 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
     return Error::success();
   };
 
-  // If we don't have an ID for the target region, it means an offload entry
-  // wasn't created. In this case we just run the host fallback directly.
-  if (!OutlinedFnID) {
+  auto &&EmitTargetCallElse =
+      [&](OpenMPIRBuilder::InsertPointTy AllocaIP,
+          OpenMPIRBuilder::InsertPointTy CodeGenIP) -> Error {
     // Assume no error was returned because EmitTargetCallFallbackCB doesn't
     // produce any.
     OpenMPIRBuilder::InsertPointTy AfterIP = cantFail([&]() {
@@ -7404,102 +7404,124 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
     }());
 
     Builder.restoreIP(AfterIP);
-    return;
-  }
-
-  OpenMPIRBuilder::TargetDataInfo Info(
-      /*RequiresDevicePointerInfo=*/false,
-      /*SeparateBeginEndCalls=*/true);
-
-  OpenMPIRBuilder::MapInfosTy &MapInfo = GenMapInfoCB(Builder.saveIP());
-  OpenMPIRBuilder::TargetDataRTArgs RTArgs;
-  OMPBuilder.emitOffloadingArraysAndArgs(AllocaIP, Builder.saveIP(), Info,
-                                         RTArgs, MapInfo,
-                                         /*IsNonContiguous=*/true,
-                                         /*ForEndCall=*/false);
-
-  SmallVector<Value *, 3> NumTeamsC;
-  for (auto [DefaultVal, RuntimeVal] :
-       zip_equal(DefaultAttrs.MaxTeams, RuntimeAttrs.MaxTeams))
-    NumTeamsC.push_back(RuntimeVal ? RuntimeVal : Builder.getInt32(DefaultVal));
-
-  // Calculate number of threads: 0 if no clauses specified, otherwise it is the
-  // minimum between optional THREAD_LIMIT and NUM_THREADS clauses.
-  auto InitMaxThreadsClause = [&Builder](Value *Clause) {
-    if (Clause)
-      Clause = Builder.CreateIntCast(Clause, Builder.getInt32Ty(),
-                                     /*isSigned=*/false);
-    return Clause;
-  };
-  auto CombineMaxThreadsClauses = [&Builder](Value *Clause, Value *&Result) {
-    if (Clause)
-      Result = Result
-                   ? Builder.CreateSelect(Builder.CreateICmpULT(Result, Clause),
-                                          Result, Clause)
-                   : Clause;
+    return Error::success();
   };
 
-  // If a multi-dimensional THREAD_LIMIT is set, it is the OMPX_BARE case, so
-  // the NUM_THREADS clause is overriden by THREAD_LIMIT.
-  SmallVector<Value *, 3> NumThreadsC;
-  Value *MaxThreadsClause = RuntimeAttrs.TeamsThreadLimit.size() == 1
-                                ? InitMaxThreadsClause(RuntimeAttrs.MaxThreads)
-                                : nullptr;
+  auto &&EmitTargetCallThen =
+      [&](OpenMPIRBuilder::InsertPointTy AllocaIP,
+          OpenMPIRBuilder::InsertPointTy CodeGenIP) -> Error {
+    OpenMPIRBuilder::TargetDataInfo Info(
+        /*RequiresDevicePointerInfo=*/false,
+        /*SeparateBeginEndCalls=*/true);
+
+    OpenMPIRBuilder::MapInfosTy &MapInfo = GenMapInfoCB(Builder.saveIP());
+    OpenMPIRBuilder::TargetDataRTArgs RTArgs;
+    OMPBuilder.emitOffloadingArraysAndArgs(AllocaIP, Builder.saveIP(), Info,
+                                           RTArgs, MapInfo,
+                                           /*IsNonContiguous=*/true,
+                                           /*ForEndCall=*/false);
+
+    SmallVector<Value *, 3> NumTeamsC;
+    for (auto [DefaultVal, RuntimeVal] :
+        zip_equal(DefaultAttrs.MaxTeams, RuntimeAttrs.MaxTeams))
+      NumTeamsC.push_back(RuntimeVal ? RuntimeVal : Builder.getInt32(DefaultVal));
+
+    // Calculate number of threads: 0 if no clauses specified, otherwise it is the
+    // minimum between optional THREAD_LIMIT and NUM_THREADS clauses.
+    auto InitMaxThreadsClause = [&Builder](Value *Clause) {
+      if (Clause)
+        Clause = Builder.CreateIntCast(Clause, Builder.getInt32Ty(),
+                                      /*isSigned=*/false);
+      return Clause;
+    };
+    auto CombineMaxThreadsClauses = [&Builder](Value *Clause, Value *&Result) {
+      if (Clause)
+        Result = Result
+                    ? Builder.CreateSelect(Builder.CreateICmpULT(Result, Clause),
+                                            Result, Clause)
+                    : Clause;
+    };
 
-  for (auto [TeamsVal, TargetVal] : zip_equal(RuntimeAttrs.TeamsThreadLimit,
-                                              RuntimeAttrs.TargetThreadLimit)) {
-    Value *TeamsThreadLimitClause = InitMaxThreadsClause(TeamsVal);
-    Value *NumThreads = InitMaxThreadsClause(TargetVal);
+    // If a multi-dimensional THREAD_LIMIT is set, it is the OMPX_BARE case, so
+    // the NUM_THREADS clause is overridden by THREAD_LIMIT.
+    SmallVector<Value *, 3> NumThreadsC;
+    Value *MaxThreadsClause = RuntimeAttrs.TeamsThreadLimit.size() == 1
+                                  ? InitMaxThreadsClause(RuntimeAttrs.MaxThreads)
+                                  : nullptr;
 
-    CombineMaxThreadsClauses(TeamsThreadLimitClause, NumThreads);
-    CombineMaxThreadsClauses(MaxThreadsClause, NumThreads);
+    for (auto [TeamsVal, TargetVal] : zip_equal(RuntimeAttrs.TeamsThreadLimit,
+                                                RuntimeAttrs.TargetThreadLimit)) {
+      Value *TeamsThreadLimitClause = InitMaxThreadsClause(TeamsVal);
+      Value *NumThreads = InitMaxThreadsClause(TargetVal);
 
-    NumThreadsC.push_back(NumThreads ? NumThreads : Builder.getInt32(0));
-  }
+      CombineMaxThreadsClauses(TeamsThreadLimitClause, NumThreads);
+      CombineMaxThreadsClauses(MaxThreadsClause, NumThreads);
 
-  unsigned NumTargetItems = Info.NumberOfPtrs;
-  // TODO: Use correct device ID
-  Value *DeviceID = Builder.getInt64(OMP_DEVICEID_UNDEF);
-  uint32_t SrcLocStrSize;
-  Constant *SrcLocStr = OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize);
-  Value *RTLoc = OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize,
-                                             llvm::omp::IdentFlag(0), 0);
+      NumThreadsC.push_back(NumThreads ? NumThreads : Builder.getInt32(0));
+    }
 
-  Value *TripCount = RuntimeAttrs.LoopTripCount
-                         ? Builder.CreateIntCast(RuntimeAttrs.LoopTripCount,
-                                                 Builder.getInt64Ty(),
-                                                 /*isSigned=*/false)
-                         : Builder.getInt64(0);
+    unsigned NumTargetItems = Info.NumberOfPtrs;
+    // TODO: Use correct device ID
+    Value *DeviceID = Builder.getInt64(OMP_DEVICEID_UNDEF);
+    uint32_t SrcLocStrSize;
+    Constant *SrcLocStr = OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize);
+    Value *RTLoc = OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize,
+                                              llvm::omp::IdentFlag(0), 0);
 
-  // TODO: Use correct DynCGGroupMem
-  Value *DynCGGroupMem = Builder.getInt32(0);
+    Value *TripCount = RuntimeAttrs.LoopTripCount
+                          ? Builder.CreateIntCast(RuntimeAttrs.LoopTripCount,
+                                                  Builder.getInt64Ty(),
+                                                  /*isSigned=*/false)
+                          : Builder.getInt64(0);
 
-  KArgs = OpenMPIRBuilder::TargetKernelArgs(NumTargetItems, RTArgs, TripCount,
-                                            NumTeamsC, NumThreadsC,
-                                            DynCGGroupMem, HasNoWait);
+    // TODO: Use correct DynCGGroupMem
+    Value *DynCGGroupMem = Builder.getInt32(0);
 
-  // Assume no error was returned because TaskBodyCB and
-  // EmitTargetCallFallbackCB don't produce any.
-  OpenMPIRBuilder::InsertPointTy AfterIP = cantFail([&]() {
-    // The presence of certain clauses on the target directive require the
-    // explicit generation of the target task.
-    if (RequiresOuterTargetTask)
-      return OMPBuilder.emitTargetTask(TaskBodyCB, DeviceID, RTLoc, AllocaIP,
-                                       Dependencies, HasNoWait);
+    KArgs = OpenMPIRBuilder::TargetKernelArgs(NumTargetItems, RTArgs, TripCount,
+                                              NumTeamsC, NumThreadsC,
+                                              DynCGGroupMem, HasNoWait);
 
-    return OMPBuilder.emitKernelLaunch(Builder, OutlinedFnID,
-                                       EmitTargetCallFallbackCB, KArgs,
-                                       DeviceID, RTLoc, AllocaIP);
-  }());
+    // Assume no error was returned because TaskBodyCB and
+    // EmitTargetCallFallbackCB don't produce any.
+    OpenMPIRBuilder::InsertPointTy AfterIP = cantFail([&]() {
+      // The presence of certain clauses on the target directive require the
+      // explicit generation of the target task.
+      if (RequiresOuterTargetTask)
+        return OMPBuilder.emitTargetTask(TaskBodyCB, DeviceID, RTLoc, AllocaIP,
+                                         Dependencies, HasNoWait);
+
+      return OMPBuilder.emitKernelLaunch(Builder, OutlinedFnID,
+                                         EmitTargetCallFallbackCB, KArgs,
+                                         DeviceID, RTLoc, AllocaIP);
+    }());
+
+    Builder.restoreIP(AfterIP);
+    return Error::success();
+  };
+
+  // If we don't have an ID for the target region, it means an offload entry
+  // wasn't created. In this case we just run the host fallback directly and
+  // ignore any potential 'if' clauses.
+  if (!OutlinedFnID) {
+    cantFail(EmitTargetCallElse(AllocaIP, Builder.saveIP()));
+    return;
+  }
+
+  // If there's no 'if' clause, only generate the kernel launch code path.
+  if (!IfCond) {
+    cantFail(EmitTargetCallThen(AllocaIP, Builder.saveIP()));
+    return;
+  }
 
-  Builder.restoreIP(AfterIP);
+  cantFail(OMPBuilder.emitIfClause(IfCond, EmitTargetCallThen,
+                                   EmitTargetCallElse, AllocaIP));
 }
 
 OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
     const LocationDescription &Loc, bool IsOffloadEntry, InsertPointTy AllocaIP,
     InsertPointTy CodeGenIP, TargetRegionEntryInfo &EntryInfo,
     const TargetKernelDefaultAttrs &DefaultAttrs,
-    const TargetKernelRuntimeAttrs &RuntimeAttrs,
+    const TargetKernelRuntimeAttrs &RuntimeAttrs, Value *IfCond,
     SmallVectorImpl<Value *> &Args, GenMapInfoCallbackTy GenMapInfoCB,
     OpenMPIRBuilder::TargetBodyGenCallbackTy CBFunc,
     OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
@@ -7524,7 +7546,7 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
   // to make a remote call (offload) to the previously outlined function
   // that represents the target region. Do that now.
   if (!Config.isTargetDevice())
-    emitTargetCall(*this, Builder, AllocaIP, DefaultAttrs, RuntimeAttrs,
+    emitTargetCall(*this, Builder, AllocaIP, DefaultAttrs, RuntimeAttrs, IfCond,
                    OutlinedFn, OutlinedFnID, Args, GenMapInfoCB, Dependencies,
                    HasNowait);
   return Builder.saveIP();
diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
index 3b571cce09a4f8..684b842bcbf494 100644
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -6243,8 +6243,8 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) {
       OpenMPIRBuilder::InsertPointTy, AfterIP,
       OMPBuilder.createTarget(OmpLoc, /*IsOffloadEntry=*/true, Builder.saveIP(),
                               Builder.saveIP(), EntryInfo, DefaultAttrs,
-                              RuntimeAttrs, Inputs, GenMapInfoCB, BodyGenCB,
-                              SimpleArgAccessorCB));
+                              RuntimeAttrs, /*IfCond=*/nullptr, Inputs,
+                              GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB));
   Builder.restoreIP(AfterIP);
 
   OMPBuilder.finalize();
@@ -6402,11 +6402,12 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) {
       /*ExecFlags=*/omp::OMPTgtExecModeFlags::OMP_TGT_EXEC_MODE_GENERIC,
       /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};
 
-  ASSERT_EXPECTED_INIT(OpenMPIRBuilder::InsertPointTy, AfterIP,
-                       OMPBuilder.createTarget(
-                           Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP,
-                           EntryInfo, DefaultAttrs, RuntimeAttrs, CapturedArgs,
-                           GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB));
+  ASSERT_EXPECTED_INIT(
+      OpenMPIRBuilder::InsertPointTy, AfterIP,
+      OMPBuilder.createTarget(Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP,
+                              EntryInfo, DefaultAttrs, RuntimeAttrs,
+                              /*IfCond=*/nullptr, CapturedArgs, GenMapInfoCB,
+                              BodyGenCB, SimpleArgAccessorCB));
   Builder.restoreIP(AfterIP);
 
   Builder.CreateRetVoid();
@@ -6774,11 +6775,12 @@ TEST_F(OpenMPIRBuilderTest, ConstantAllocaRaise) {
       /*ExecFlags=*/omp::OMPTgtExecModeFlags::OMP_TGT_EXEC_MODE_GENERIC,
       /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};
 
-  ASSERT_EXPECTED_INIT(OpenMPIRBuilder::InsertPointTy, AfterIP,
-                       OMPBuilder.createTarget(
-                           Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP,
-                           EntryInfo, DefaultAttrs, RuntimeAttrs, CapturedArgs,
-                           GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB));
+  ASSERT_EXPECTED_INIT(
+      OpenMPIRBuilder::InsertPointTy, AfterIP,
+      OMPBuilder.createTarget(Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP,
+                              EntryInfo, DefaultAttrs, RuntimeAttrs,
+                              /*IfCond=*/nullptr, CapturedArgs, GenMapInfoCB,
+                              BodyGenCB, SimpleArgAccessorCB));
   Builder.restoreIP(AfterIP);
 
   Builder.CreateRetVoid();
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 0be515e63b470c..abef2cb7411aaf 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -183,10 +183,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
           result = op.emitError("not yet implemented: host evaluation of loop "
                                 "bounds in omp.target operation");
   };
-  auto checkIf = [&todo](auto op, LogicalResult &result) {
-    if (op.getIfExpr())
-      result = todo("if");
-  };
   auto checkInReduction = [&todo](auto op, LogicalResult &result) {
     if (!op.getInReductionVars().empty() || op.getInReductionByref() ||
         op.getInReductionSyms())
@@ -306,7 +302,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
         checkDevice(op, result);
         checkHasDeviceAddr(op, result);
         checkHostEval(op, result);
-        checkIf(op, result);
         checkInReduction(op, result);
         checkIsDevicePtr(op, result);
         checkPrivate(op, result);
@@ -4378,10 +4373,14 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
       findAllocaInsertPoint(builder, moduleTranslation);
   llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
 
+  llvm::Value *ifCond = nullptr;
+  if (Value targetIfCond = targetOp.getIfExpr())
+    ifCond = moduleTranslation.lookupValue(targetIfCond);
+
   llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
       moduleTranslation.getOpenMPBuilder()->createTarget(
           ompLoc, isOffloadEntry, allocaIP, builder.saveIP(), entryInfo,
-          defaultAttrs, runtimeAttrs, kernelInput, genMapInfoCB, bodyCB,
+          defaultAttrs, runtimeAttrs, ifCond, kernelInput, genMapInfoCB, bodyCB,
           argAccessorCB, dds, targetOp.getNowait());
 
   if (failed(handleError(afterIP, opInst)))
diff --git a/mlir/test/Target/LLVMIR/omptarget-if.mlir b/mlir/test/Target/LLVMIR/omptarget-if.mlir
new file mode 100644
index 00000000000000..706ad4411438ba
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/omptarget-if.mlir
@@ -0,0 +1,68 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+module attributes {omp.is_target_device = false, omp.target_triples = ["amdgcn-amd-amdhsa"]} {
+  llvm.func @target_if_variable(%x : i1) {
+    omp.target if(%x) {
+      omp.terminator
+    }
+    llvm.return
+  }
+
+  // CHECK-LABEL: define void @target_if_variable(
+  // CHECK-SAME: i1 %[[IF_COND:.*]])
+  // CHECK: br i1 %[[IF_COND]], label %[[THEN_LABEL:.*]], label %[[ELSE_LABEL:.*]]
+
+  // CHECK: [[THEN_LABEL]]:
+  // CHECK-NOT: {{^.*}}:
+  // CHECK: %[[RC:.*]] = call i32 @__tgt_target_kernel
+  // CHECK-NEXT: %[[OFFLOAD_SUCCESS:.*]] = icmp ne i32 %[[RC]], 0
+  // CHECK-NEXT: br i1 %[[OFFLOAD_SUCCESS]], label %[[OFFLOAD_FAIL_LABEL:.*]], label %[[OFFLOAD_CONT_LABEL:.*]]
+
+  // CHECK: [[OFFLOAD_FAIL_LABEL]]:
+  // CHECK-NEXT: call void @[[FALLBACK_FN:__omp_offloading_.*_.*_target_if_variable_l.*]]()
+  // CHECK-NEXT: br label %[[OFFLOAD_CONT_LABEL]]
+
+  // CHECK: [[OFFLOAD_CONT_LABEL]]:
+  // CHECK-NEXT: br label %[[END_LABEL:.*]]
+
+  // CHECK: [[ELSE_LABEL]]:
+  // CHECK-NEXT: call void @[[FALLBACK_FN]]()
+  // CHECK-NEXT: br label %[[END_LABEL]]
+
+  llvm.func @target_if_true() {
+    %0 = llvm.mlir.constant(true) : i1
+    omp.target if(%0) {
+      omp.terminator
+    }
+    llvm.return
+  }
+
+  // CHECK-LABEL: define void @target_if_true()
+  // CHECK-NOT: {{^.*}}:
+  // CHECK: br label %[[ENTRY:.*]]
+
+  // CHECK: [[ENTRY]]:
+  // CHECK-NOT: {{^.*}}:
+  // CHECK: %[[RC:.*]] = call i32 @__tgt_target_kernel
+  // CHECK-NEXT: %[[OFFLOAD_SUCCESS:.*]] = icmp ne i32 %[[RC]], 0
+  // CHECK-NEXT: br i1 %[[OFFLOAD_SUCCESS]], label %[[OFFLOAD_FAIL_LABEL:.*]], label %[[OFFLOAD_CONT_LABEL:.*]]
+
+  // CHECK: [[OFFLOAD_FAIL_LABEL]]:
+  // CHECK-NEXT: call void @[[FALLBACK_FN:.*]]()
+  // CHECK-NEXT: br label %[[OFFLOAD_CONT_LABEL]]
+
+  llvm.func @target_if_false() {
+    %0 = llvm.mlir.constant(false) : i1
+    omp.target if(%0) {
+      omp.terminator
+    }
+    llvm.return
+  }
+
+  // CHECK-LABEL: define void @target_if_false()
+  // CHECK-NEXT: br label %[[ENTRY:.*]]
+
+  // CHECK: [[ENTRY]]:
+  // CHECK-NEXT: call void @__omp_offloading_{{.*}}_{{.*}}_target_if_false_l{{.*}}()
+}
+
diff --git a/mlir/test/Target/LLVMIR/openmp-todo.mlir b/mlir/test/Target/LLVMIR/openmp-todo.mlir
index 392a6558dcfa69..c1e30964b25078 100644
--- a/mlir/test/Target/LLVMIR/openmp-todo.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-todo.mlir
@@ -271,17 +271,6 @@ llvm.func @target_host_eval(%x : i32) {
 
 // -----
 
-llvm.func @target_if(%x : i1) {
-  // expected-error at below {{not yet implemented: Unhandled clause if in omp.target operation}}
-  // expected-error at below {{LLVM Translation failed for operation: omp.target}}
-  omp.target if(%x) {
-    omp.terminator
-  }
-  llvm.return
-}
-
-// -----
-
 omp.declare_reduction @add_f32 : f32
 init {
 ^bb0(%arg: f32):



More information about the llvm-branch-commits mailing list