[Mlir-commits] [llvm] [mlir] [flang][OpenMP] Support `target ... nowait` (PR #111823)

Kareem Ergawy llvmlistbot at llvm.org
Thu Oct 10 05:08:12 PDT 2024


https://github.com/ergawy created https://github.com/llvm/llvm-project/pull/111823

Adds MLIR to LLVM lowering support for `target ... nowait`. This leverages the already existing code-gen patterns for `task` by treating `target ... nowait` as `task ... if(1)` and `target` (without `nowait`) as `task ... if(0)`, similar to what clang does.

>From d0cc4d8e163d7e3ab0e4e2d350d0d86e90e5458d Mon Sep 17 00:00:00 2001
From: ergawy <kareem.ergawy at amd.com>
Date: Thu, 10 Oct 2024 04:51:12 -0500
Subject: [PATCH] [flang][OpenMP] Support `target ... nowait`

Adds MLIR to LLVM lowering support for `target ... nowait`. This
leverages the already existing code-gen patterns for `task` by treating
`target ... nowait` as `task ... if(1)` and `target` (without `nowait`) as
`task ... if(0)`, similar to what clang does.
---
 .../llvm/Frontend/OpenMP/OMPIRBuilder.h       | 19 ++++---
 llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp     | 47 ++++++++++------
 .../OpenMP/OpenMPToLLVMIRTranslation.cpp      |  7 +--
 .../Target/LLVMIR/omptarget-nowait-llvm.mlir  | 54 ++++++++-----------
 .../omptarget-nowait-unsupported-llvm.mlir    | 39 ++++++++++++++
 5 files changed, 102 insertions(+), 64 deletions(-)
 create mode 100644 mlir/test/Target/LLVMIR/omptarget-nowait-unsupported-llvm.mlir

diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index 1b8a6e47b3baf8..5d408ec6ac739f 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -2875,16 +2875,15 @@ class OpenMPIRBuilder {
   /// instructions for passed in target arguments where neccessary
   /// \param Dependencies A vector of DependData objects that carry
   // dependency information as passed in the depend clause
-  InsertPointTy
-  createTarget(const LocationDescription &Loc, bool IsOffloadEntry,
-               OpenMPIRBuilder::InsertPointTy AllocaIP,
-               OpenMPIRBuilder::InsertPointTy CodeGenIP,
-               TargetRegionEntryInfo &EntryInfo, ArrayRef<int32_t> NumTeams,
-               ArrayRef<int32_t> NumThreads, SmallVectorImpl<Value *> &Inputs,
-               GenMapInfoCallbackTy GenMapInfoCB,
-               TargetBodyGenCallbackTy BodyGenCB,
-               TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
-               SmallVector<DependData> Dependencies = {});
+  InsertPointTy createTarget(
+      const LocationDescription &Loc, bool IsOffloadEntry,
+      OpenMPIRBuilder::InsertPointTy AllocaIP,
+      OpenMPIRBuilder::InsertPointTy CodeGenIP,
+      TargetRegionEntryInfo &EntryInfo, ArrayRef<int32_t> NumTeams,
+      ArrayRef<int32_t> NumThreads, SmallVectorImpl<Value *> &Inputs,
+      GenMapInfoCallbackTy GenMapInfoCB, TargetBodyGenCallbackTy BodyGenCB,
+      TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
+      SmallVector<DependData> Dependencies = {}, bool HasNowait = false);
 
   /// Returns __kmpc_for_static_init_* runtime function for the specified
   /// size \a IVSize and sign \a IVSigned. Will create a distribute call
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 47cc6ff7655caf..4da05bed54757f 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -6969,7 +6969,7 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitTargetTask(
 
   OI.ExitBB = Builder.saveIP().getBlock();
   OI.PostOutlineCB = [this, ToBeDeleted, Dependencies,
-                      HasNoWait](Function &OutlinedFn) mutable {
+                      HasNoWait, DeviceID](Function &OutlinedFn) mutable {
     assert(OutlinedFn.getNumUses() == 1 &&
            "there must be a single user for the outlined function");
 
@@ -6989,9 +6989,14 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitTargetTask(
         getOrCreateSrcLocStr(LocationDescription(Builder), SrcLocStrSize);
     Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
 
-    // @__kmpc_omp_task_alloc
+    // @__kmpc_omp_task_alloc or @__kmpc_omp_target_task_alloc
+    //
+    // If `HasNoWait == true`, we call  @__kmpc_omp_target_task_alloc to provide
+    // the DeviceID to the deferred task.
     Function *TaskAllocFn =
-        getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_alloc);
+        !HasNoWait ? getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_alloc)
+                   : getOrCreateRuntimeFunctionPtr(
+                         OMPRTL___kmpc_omp_target_task_alloc);
 
     // Arguments - `loc_ref` (Ident) and `gtid` (ThreadID)
     // call.
@@ -7032,10 +7037,18 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitTargetTask(
     // Emit the @__kmpc_omp_task_alloc runtime call
     // The runtime call returns a pointer to an area where the task captured
     // variables must be copied before the task is run (TaskData)
-    CallInst *TaskData = Builder.CreateCall(
-        TaskAllocFn, {/*loc_ref=*/Ident, /*gtid=*/ThreadID, /*flags=*/Flags,
-                      /*sizeof_task=*/TaskSize, /*sizeof_shared=*/SharedsSize,
-                      /*task_func=*/ProxyFn});
+    CallInst *TaskData = nullptr;
+
+    SmallVector<llvm::Value *> TaskAllocArgs = {
+        /*loc_ref=*/Ident,        /*gtid=*/ThreadID,
+        /*flags=*/Flags,
+        /*sizeof_task=*/TaskSize, /*sizeof_shared=*/SharedsSize,
+        /*task_func=*/ProxyFn};
+
+    if (HasNoWait)
+      TaskAllocArgs.push_back(DeviceID);
+
+    TaskData = Builder.CreateCall(TaskAllocFn, TaskAllocArgs);
 
     if (HasShareds) {
       Value *Shareds = StaleCI->getArgOperand(1);
@@ -7118,13 +7131,14 @@ void OpenMPIRBuilder::emitOffloadingArraysAndArgs(
   emitOffloadingArraysArgument(Builder, RTArgs, Info, ForEndCall);
 }
 
-static void emitTargetCall(
-    OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
-    OpenMPIRBuilder::InsertPointTy AllocaIP, Function *OutlinedFn,
-    Constant *OutlinedFnID, ArrayRef<int32_t> NumTeams,
-    ArrayRef<int32_t> NumThreads, SmallVectorImpl<Value *> &Args,
-    OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB,
-    SmallVector<llvm::OpenMPIRBuilder::DependData> Dependencies = {}) {
+static void
+emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
+               OpenMPIRBuilder::InsertPointTy AllocaIP, Function *OutlinedFn,
+               Constant *OutlinedFnID, ArrayRef<int32_t> NumTeams,
+               ArrayRef<int32_t> NumThreads, SmallVectorImpl<Value *> &Args,
+               OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB,
+               SmallVector<llvm::OpenMPIRBuilder::DependData> Dependencies = {},
+               bool HasNoWait = false) {
   // Generate a function call to the host fallback implementation of the target
   // region. This is called by the host when no offload entry was generated for
   // the target region and when the offloading call fails at runtime.
@@ -7135,7 +7149,6 @@ static void emitTargetCall(
     return Builder.saveIP();
   };
 
-  bool HasNoWait = false;
   bool HasDependencies = Dependencies.size() > 0;
   bool RequiresOuterTargetTask = HasNoWait || HasDependencies;
 
@@ -7211,7 +7224,7 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTarget(
     SmallVectorImpl<Value *> &Args, GenMapInfoCallbackTy GenMapInfoCB,
     OpenMPIRBuilder::TargetBodyGenCallbackTy CBFunc,
     OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
-    SmallVector<DependData> Dependencies) {
+    SmallVector<DependData> Dependencies, bool HasNowait) {
 
   if (!updateToLocation(Loc))
     return InsertPointTy();
@@ -7232,7 +7245,7 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTarget(
   // that represents the target region. Do that now.
   if (!Config.isTargetDevice())
     emitTargetCall(*this, Builder, AllocaIP, OutlinedFn, OutlinedFnID, NumTeams,
-                   NumThreads, Args, GenMapInfoCB, Dependencies);
+                   NumThreads, Args, GenMapInfoCB, Dependencies, HasNowait);
   return Builder.saveIP();
 }
 
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 19d80fbbd699b0..745d636acfad5d 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -3242,11 +3242,6 @@ static bool targetOpSupported(Operation &opInst) {
     return false;
   }
 
-  if (targetOp.getNowait()) {
-    opInst.emitError("Nowait clause not yet supported");
-    return false;
-  }
-
   if (!targetOp.getAllocateVars().empty() ||
       !targetOp.getAllocatorVars().empty()) {
     opInst.emitError("Allocate clause not yet supported");
@@ -3569,7 +3564,7 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
   builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createTarget(
       ompLoc, isOffloadEntry, allocaIP, builder.saveIP(), entryInfo,
       defaultValTeams, defaultValThreads, kernelInput, genMapInfoCB, bodyCB,
-      argAccessorCB, dds));
+      argAccessorCB, dds, targetOp.getNowait()));
 
   // Remap access operations to declare target reference pointers for the
   // device, essentially generating extra loadop's as necessary
diff --git a/mlir/test/Target/LLVMIR/omptarget-nowait-llvm.mlir b/mlir/test/Target/LLVMIR/omptarget-nowait-llvm.mlir
index 1e2fbe86d13c47..b487b31d544777 100644
--- a/mlir/test/Target/LLVMIR/omptarget-nowait-llvm.mlir
+++ b/mlir/test/Target/LLVMIR/omptarget-nowait-llvm.mlir
@@ -1,39 +1,31 @@
-// RUN: not mlir-translate -mlir-to-llvmir -split-input-file %s 2>&1 | FileCheck %s
+// RUN: mlir-translate -mlir-to-llvmir %s 2>&1 | FileCheck %s
 
-llvm.func @_QPopenmp_target_data_update() {
-  %0 = llvm.mlir.constant(1 : i64) : i64
-  %1 = llvm.alloca %0 x i32 {bindc_name = "i", in_type = i32, operand_segment_sizes = array<i32: 0, 0>, uniq_name = "_QFopenmp_target_dataEi"} : (i64) -> !llvm.ptr
-  %2 = omp.map.info var_ptr(%1 : !llvm.ptr, i32)   map_clauses(to) capture(ByRef) -> !llvm.ptr {name = ""}
+// Set a dummy target triple to enable target region outlining.
+module attributes {omp.target_triples = ["dummy-target-triple"]} {
+  llvm.func @_QPfoo() {
+    %0 = llvm.mlir.constant(1 : i64) : i64
+    %1 = llvm.alloca %0 x i32 : (i64) -> !llvm.ptr
+    %2 = omp.map.info var_ptr(%1 : !llvm.ptr, i32) map_clauses(implicit) capture(ByCopy) -> !llvm.ptr
+    omp.target nowait map_entries(%2 -> %arg0 : !llvm.ptr) {
+      %3 = llvm.mlir.constant(2 : i32) : i32
+      llvm.store %3, %arg0 : i32, !llvm.ptr
+      omp.terminator
+    }
+    llvm.return
+  }
 
-  // CHECK: error: `nowait` is not supported yet
-  omp.target_update map_entries(%2 : !llvm.ptr) nowait
-
-  llvm.return
-}
-
-// -----
-
-llvm.func @_QPopenmp_target_data_enter() {
-  %0 = llvm.mlir.constant(1 : i64) : i64
-  %1 = llvm.alloca %0 x i32 {bindc_name = "i", in_type = i32, operand_segment_sizes = array<i32: 0, 0>, uniq_name = "_QFopenmp_target_dataEi"} : (i64) -> !llvm.ptr
-  %2 = omp.map.info var_ptr(%1 : !llvm.ptr, i32)   map_clauses(to) capture(ByRef) -> !llvm.ptr {name = ""}
-
-  // CHECK: error: `nowait` is not supported yet
-  omp.target_enter_data map_entries(%2 : !llvm.ptr) nowait
-
-  llvm.return
-}
 
+// CHECK: define void @_QPfoo() {
 
-// -----
+// CHECK:   %[[TASK:.*]] = call ptr @__kmpc_omp_target_task_alloc
+// CHECK-SAME:     (ptr @{{.*}}, i32 %{{.*}}, i32 {{.*}}, i64 {{.*}}, i64 {{.*}}, ptr
+// CHECK-SAME:     @[[TASK_PROXY_FUNC:.*]], i64 {{.*}})
 
-llvm.func @_QPopenmp_target_data_exit() {
-  %0 = llvm.mlir.constant(1 : i64) : i64
-  %1 = llvm.alloca %0 x i32 {bindc_name = "i", in_type = i32, operand_segment_sizes = array<i32: 0, 0>, uniq_name = "_QFopenmp_target_dataEi"} : (i64) -> !llvm.ptr
-  %2 = omp.map.info var_ptr(%1 : !llvm.ptr, i32)   map_clauses(from) capture(ByRef) -> !llvm.ptr {name = ""}
+// CHECK:   call i32 @__kmpc_omp_task(ptr {{.*}}, i32 %{{.*}}, ptr %[[TASK]])
+// CHECK: }
 
-  // CHECK: error: `nowait` is not supported yet
-  omp.target_exit_data map_entries(%2 : !llvm.ptr) nowait
 
-  llvm.return
+// CHECK: define internal void @[[TASK_PROXY_FUNC]](i32 %{{.*}}, ptr %{{.*}}) {
+// CHECK:   call void @_QPfoo..omp_par(i32 %{{.*}}, ptr %{{.*}})
+// CHECK: }
 }
diff --git a/mlir/test/Target/LLVMIR/omptarget-nowait-unsupported-llvm.mlir b/mlir/test/Target/LLVMIR/omptarget-nowait-unsupported-llvm.mlir
new file mode 100644
index 00000000000000..1e2fbe86d13c47
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/omptarget-nowait-unsupported-llvm.mlir
@@ -0,0 +1,39 @@
+// RUN: not mlir-translate -mlir-to-llvmir -split-input-file %s 2>&1 | FileCheck %s
+
+llvm.func @_QPopenmp_target_data_update() {
+  %0 = llvm.mlir.constant(1 : i64) : i64
+  %1 = llvm.alloca %0 x i32 {bindc_name = "i", in_type = i32, operand_segment_sizes = array<i32: 0, 0>, uniq_name = "_QFopenmp_target_dataEi"} : (i64) -> !llvm.ptr
+  %2 = omp.map.info var_ptr(%1 : !llvm.ptr, i32)   map_clauses(to) capture(ByRef) -> !llvm.ptr {name = ""}
+
+  // CHECK: error: `nowait` is not supported yet
+  omp.target_update map_entries(%2 : !llvm.ptr) nowait
+
+  llvm.return
+}
+
+// -----
+
+llvm.func @_QPopenmp_target_data_enter() {
+  %0 = llvm.mlir.constant(1 : i64) : i64
+  %1 = llvm.alloca %0 x i32 {bindc_name = "i", in_type = i32, operand_segment_sizes = array<i32: 0, 0>, uniq_name = "_QFopenmp_target_dataEi"} : (i64) -> !llvm.ptr
+  %2 = omp.map.info var_ptr(%1 : !llvm.ptr, i32)   map_clauses(to) capture(ByRef) -> !llvm.ptr {name = ""}
+
+  // CHECK: error: `nowait` is not supported yet
+  omp.target_enter_data map_entries(%2 : !llvm.ptr) nowait
+
+  llvm.return
+}
+
+
+// -----
+
+llvm.func @_QPopenmp_target_data_exit() {
+  %0 = llvm.mlir.constant(1 : i64) : i64
+  %1 = llvm.alloca %0 x i32 {bindc_name = "i", in_type = i32, operand_segment_sizes = array<i32: 0, 0>, uniq_name = "_QFopenmp_target_dataEi"} : (i64) -> !llvm.ptr
+  %2 = omp.map.info var_ptr(%1 : !llvm.ptr, i32)   map_clauses(from) capture(ByRef) -> !llvm.ptr {name = ""}
+
+  // CHECK: error: `nowait` is not supported yet
+  omp.target_exit_data map_entries(%2 : !llvm.ptr) nowait
+
+  llvm.return
+}



More information about the Mlir-commits mailing list