[libc-commits] [libc] [OpenMPIRBuilder] Remove wrapper function in `createTask` (PR #67723)

via libc-commits libc-commits at lists.llvm.org
Mon Oct 2 07:23:49 PDT 2023


https://github.com/shraiysh updated https://github.com/llvm/llvm-project/pull/67723

>From 6aabc3c10ea2d587120b74966b7ce96f9b8167af Mon Sep 17 00:00:00 2001
From: Shraiysh Vaishay <shraiysh.vaishay at amd.com>
Date: Thu, 28 Sep 2023 13:35:07 -0500
Subject: [PATCH 1/2] [OpenMPIRBuilder] Remove wrapper function in `createTask`

This patch removes the wrapper function in
`OpenMPIRBuilder::createTask`. The outlined function is directly of the
form that is expected by the runtime library calls. This also fixes the
global thread ID argument, which should be used whenever
`kmpc_global_thread_num()` is called inside the outlined function.
---
 llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp     | 106 ++++++++----------
 .../Frontend/OpenMPIRBuilderTest.cpp          |  56 +++++----
 mlir/test/Target/LLVMIR/openmp-llvm.mlir      |  51 +++------
 3 files changed, 99 insertions(+), 114 deletions(-)

diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 9c70d384e55db2b..54012b488c6b671 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -35,6 +35,7 @@
 #include "llvm/IR/Function.h"
 #include "llvm/IR/GlobalVariable.h"
 #include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/InstIterator.h"
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/MDBuilder.h"
 #include "llvm/IR/Metadata.h"
@@ -1496,6 +1497,14 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc,
                             InsertPointTy AllocaIP, BodyGenCallbackTy BodyGenCB,
                             bool Tied, Value *Final, Value *IfCondition,
                             SmallVector<DependData> Dependencies) {
+  // We create a temporary i32 value that will represent the global tid after
+  // outlining.
+  SmallVector<Instruction *, 4> ToBeDeleted;
+  Builder.restoreIP(AllocaIP);
+  AllocaInst *TIDAddr = Builder.CreateAlloca(Int32, nullptr, "tid.addr");
+  LoadInst *TID = Builder.CreateLoad(Int32, TIDAddr, "tid.addr.use");
+  ToBeDeleted.append({TID, TIDAddr});
+
   if (!updateToLocation(Loc))
     return InsertPointTy();
 
@@ -1523,41 +1532,27 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc,
   BasicBlock *TaskAllocaBB =
       splitBB(Builder, /*CreateBranch=*/true, "task.alloca");
 
+  // Fake use of TID
+  Builder.SetInsertPoint(TaskAllocaBB, TaskAllocaBB->begin());
+  BinaryOperator *AddInst =
+      dyn_cast<BinaryOperator>(Builder.CreateAdd(TID, Builder.getInt32(10)));
+  ToBeDeleted.push_back(AddInst);
+
   OutlineInfo OI;
   OI.EntryBB = TaskAllocaBB;
   OI.OuterAllocaBB = AllocaIP.getBlock();
   OI.ExitBB = TaskExitBB;
-  OI.PostOutlineCB = [this, Ident, Tied, Final, IfCondition,
-                      Dependencies](Function &OutlinedFn) {
-    // The input IR here looks like the following-
-    // ```
-    // func @current_fn() {
-    //   outlined_fn(%args)
-    // }
-    // func @outlined_fn(%args) { ... }
-    // ```
-    //
-    // This is changed to the following-
-    //
-    // ```
-    // func @current_fn() {
-    //   runtime_call(..., wrapper_fn, ...)
-    // }
-    // func @wrapper_fn(..., %args) {
-    //   outlined_fn(%args)
-    // }
-    // func @outlined_fn(%args) { ... }
-    // ```
-
-    // The stale call instruction will be replaced with a new call instruction
-    // for runtime call with a wrapper function.
+  OI.ExcludeArgsFromAggregate = {TID};
+  OI.PostOutlineCB = [this, Ident, Tied, Final, IfCondition, Dependencies,
+                      TaskAllocaBB, ToBeDeleted](Function &OutlinedFn) {
+    // Replace the Stale CI by appropriate RTL function call.
     assert(OutlinedFn.getNumUses() == 1 &&
            "there must be a single user for the outlined function");
     CallInst *StaleCI = cast<CallInst>(OutlinedFn.user_back());
 
     // HasShareds is true if any variables are captured in the outlined region,
     // false otherwise.
-    bool HasShareds = StaleCI->arg_size() > 0;
+    bool HasShareds = StaleCI->arg_size() > 1;
     Builder.SetInsertPoint(StaleCI);
 
     // Gather the arguments for emitting the runtime call for
@@ -1595,7 +1590,7 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc,
     Value *SharedsSize = Builder.getInt64(0);
     if (HasShareds) {
       AllocaInst *ArgStructAlloca =
-          dyn_cast<AllocaInst>(StaleCI->getArgOperand(0));
+          dyn_cast<AllocaInst>(StaleCI->getArgOperand(1));
       assert(ArgStructAlloca &&
              "Unable to find the alloca instruction corresponding to arguments "
              "for extracted function");
@@ -1606,31 +1601,17 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc,
       SharedsSize =
           Builder.getInt64(M.getDataLayout().getTypeStoreSize(ArgStructType));
     }
-
-    // Argument - task_entry (the wrapper function)
-    // If the outlined function has some captured variables (i.e. HasShareds is
-    // true), then the wrapper function will have an additional argument (the
-    // struct containing captured variables). Otherwise, no such argument will
-    // be present.
-    SmallVector<Type *> WrapperArgTys{Builder.getInt32Ty()};
-    if (HasShareds)
-      WrapperArgTys.push_back(OutlinedFn.getArg(0)->getType());
-    FunctionCallee WrapperFuncVal = M.getOrInsertFunction(
-        (Twine(OutlinedFn.getName()) + ".wrapper").str(),
-        FunctionType::get(Builder.getInt32Ty(), WrapperArgTys, false));
-    Function *WrapperFunc = dyn_cast<Function>(WrapperFuncVal.getCallee());
-
     // Emit the @__kmpc_omp_task_alloc runtime call
     // The runtime call returns a pointer to an area where the task captured
     // variables must be copied before the task is run (TaskData)
     CallInst *TaskData = Builder.CreateCall(
         TaskAllocFn, {/*loc_ref=*/Ident, /*gtid=*/ThreadID, /*flags=*/Flags,
                       /*sizeof_task=*/TaskSize, /*sizeof_shared=*/SharedsSize,
-                      /*task_func=*/WrapperFunc});
+                      /*task_func=*/&OutlinedFn});
 
     // Copy the arguments for outlined function
     if (HasShareds) {
-      Value *Shareds = StaleCI->getArgOperand(0);
+      Value *Shareds = StaleCI->getArgOperand(1);
       Align Alignment = TaskData->getPointerAlignment(M.getDataLayout());
       Value *TaskShareds = Builder.CreateLoad(VoidPtr, TaskData);
       Builder.CreateMemCpy(TaskShareds, Alignment, Shareds, Alignment,
@@ -1697,10 +1678,9 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc,
     if (IfCondition) {
       // `SplitBlockAndInsertIfThenElse` requires the block to have a
       // terminator.
-      BasicBlock *NewBasicBlock =
-          splitBB(Builder, /*CreateBranch=*/true, "if.end");
+      splitBB(Builder, /*CreateBranch=*/true, "if.end");
       Instruction *IfTerminator =
-          NewBasicBlock->getSinglePredecessor()->getTerminator();
+          Builder.GetInsertPoint()->getParent()->getTerminator();
       Instruction *ThenTI = IfTerminator, *ElseTI = nullptr;
       Builder.SetInsertPoint(IfTerminator);
       SplitBlockAndInsertIfThenElse(IfCondition, IfTerminator, &ThenTI,
@@ -1711,10 +1691,12 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc,
       Function *TaskCompleteFn =
           getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_complete_if0);
       Builder.CreateCall(TaskBeginFn, {Ident, ThreadID, TaskData});
+      CallInst *CI = nullptr;
       if (HasShareds)
-        Builder.CreateCall(WrapperFunc, {ThreadID, TaskData});
+        CI = Builder.CreateCall(&OutlinedFn, {ThreadID, TaskData});
       else
-        Builder.CreateCall(WrapperFunc, {ThreadID});
+        CI = Builder.CreateCall(&OutlinedFn, {ThreadID});
+      CI->setDebugLoc(StaleCI->getDebugLoc());
       Builder.CreateCall(TaskCompleteFn, {Ident, ThreadID, TaskData});
       Builder.SetInsertPoint(ThenTI);
     }
@@ -1736,18 +1718,28 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc,
 
     StaleCI->eraseFromParent();
 
-    // Emit the body for wrapper function
-    BasicBlock *WrapperEntryBB =
-        BasicBlock::Create(M.getContext(), "", WrapperFunc);
-    Builder.SetInsertPoint(WrapperEntryBB);
+    Builder.SetInsertPoint(TaskAllocaBB, TaskAllocaBB->begin());
     if (HasShareds) {
-      llvm::Value *Shareds =
-          Builder.CreateLoad(VoidPtr, WrapperFunc->getArg(1));
-      Builder.CreateCall(&OutlinedFn, {Shareds});
-    } else {
-      Builder.CreateCall(&OutlinedFn);
+      LoadInst *Shareds = Builder.CreateLoad(VoidPtr, OutlinedFn.getArg(1));
+      OutlinedFn.getArg(1)->replaceUsesWithIf(
+          Shareds, [Shareds](Use &U) { return U.getUser() != Shareds; });
+    }
+
+    // Replace kmpc_global_thread_num() calls with the global thread id
+    // argument.
+    OutlinedFn.getArg(0)->setName("global.tid");
+    FunctionCallee TIDRTLFn =
+        getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_global_thread_num);
+    for (Instruction &Inst : instructions(OutlinedFn)) {
+      CallInst *CI = dyn_cast<CallInst>(&Inst);
+      if (!CI)
+        continue;
+      if (CI->getCalledFunction() == TIDRTLFn.getCallee())
+        CI->replaceAllUsesWith(OutlinedFn.getArg(0));
     }
-    Builder.CreateRet(Builder.getInt32(0));
+
+    for (Instruction *I : ToBeDeleted)
+      I->eraseFromParent();
   };
 
   addOutlineInfo(std::move(OI));
diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
index fd524f6067ee0ea..643b34270c01693 100644
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -5486,25 +5486,28 @@ TEST_F(OpenMPIRBuilderTest, CreateTask) {
             24); // 64-bit pointer + 128-bit integer
 
   // Verify Wrapper function
-  Function *WrapperFunc =
+  Function *OutlinedFn =
       dyn_cast<Function>(TaskAllocCall->getArgOperand(5)->stripPointerCasts());
-  ASSERT_NE(WrapperFunc, nullptr);
+  ASSERT_NE(OutlinedFn, nullptr);
 
-  LoadInst *SharedsLoad = dyn_cast<LoadInst>(WrapperFunc->begin()->begin());
+  LoadInst *SharedsLoad = dyn_cast<LoadInst>(OutlinedFn->begin()->begin());
   ASSERT_NE(SharedsLoad, nullptr);
-  EXPECT_EQ(SharedsLoad->getPointerOperand(), WrapperFunc->getArg(1));
-
-  EXPECT_FALSE(WrapperFunc->isDeclaration());
-  CallInst *OutlinedFnCall =
-      dyn_cast<CallInst>(++WrapperFunc->begin()->begin());
-  ASSERT_NE(OutlinedFnCall, nullptr);
-  EXPECT_EQ(WrapperFunc->getArg(0)->getType(), Builder.getInt32Ty());
-  EXPECT_EQ(OutlinedFnCall->getArgOperand(0),
-            WrapperFunc->getArg(1)->uses().begin()->getUser());
+  EXPECT_EQ(SharedsLoad->getPointerOperand(), OutlinedFn->getArg(1));
+
+  EXPECT_FALSE(OutlinedFn->isDeclaration());
+  EXPECT_EQ(OutlinedFn->getArg(0)->getType(), Builder.getInt32Ty());
+
+  // Verify that the data argument has exactly one use: the load instruction
+  // whose result is then used for accessing shared data.
+  Value *DataPtr = OutlinedFn->getArg(1);
+  EXPECT_EQ(DataPtr->getNumUses(), 1);
+  EXPECT_TRUE(isa<LoadInst>(DataPtr->uses().begin()->getUser()));
+  Value *Data = DataPtr->uses().begin()->getUser();
+  EXPECT_TRUE(all_of(Data->uses(), [](Use &U) {
+    return isa<GetElementPtrInst>(U.getUser());
+  }));
 
   // Verify the presence of `trunc` and `icmp` instructions in Outlined function
-  Function *OutlinedFn = OutlinedFnCall->getCalledFunction();
-  ASSERT_NE(OutlinedFn, nullptr);
   EXPECT_TRUE(any_of(instructions(OutlinedFn),
                      [](Instruction &inst) { return isa<TruncInst>(&inst); }));
   EXPECT_TRUE(any_of(instructions(OutlinedFn),
@@ -5547,6 +5550,14 @@ TEST_F(OpenMPIRBuilderTest, CreateTaskNoArgs) {
   Builder.CreateRetVoid();
 
   EXPECT_FALSE(verifyModule(*M, &errs()));
+
+  // Check that the outlined function has only one argument.
+  CallInst *TaskAllocCall = dyn_cast<CallInst>(
+      OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_alloc)
+          ->user_back());
+  Function *OutlinedFn = dyn_cast<Function>(TaskAllocCall->getArgOperand(5));
+  ASSERT_NE(OutlinedFn, nullptr);
+  ASSERT_EQ(OutlinedFn->arg_size(), 1);
 }
 
 TEST_F(OpenMPIRBuilderTest, CreateTaskUntied) {
@@ -5658,8 +5669,8 @@ TEST_F(OpenMPIRBuilderTest, CreateTaskFinal) {
   F->setName("func");
   IRBuilder<> Builder(BB);
   auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {};
-  IRBuilderBase::InsertPoint AllocaIP = Builder.saveIP();
   BasicBlock *BodyBB = splitBB(Builder, /*CreateBranch=*/true, "alloca.split");
+  IRBuilderBase::InsertPoint AllocaIP = Builder.saveIP();
   Builder.SetInsertPoint(BodyBB);
   Value *Final = Builder.CreateICmp(
       CmpInst::Predicate::ICMP_EQ, F->getArg(0),
@@ -5711,8 +5722,8 @@ TEST_F(OpenMPIRBuilderTest, CreateTaskIfCondition) {
   F->setName("func");
   IRBuilder<> Builder(BB);
   auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {};
-  IRBuilderBase::InsertPoint AllocaIP = Builder.saveIP();
   BasicBlock *BodyBB = splitBB(Builder, /*CreateBranch=*/true, "alloca.split");
+  IRBuilderBase::InsertPoint AllocaIP = Builder.saveIP();
   Builder.SetInsertPoint(BodyBB);
   Value *IfCondition = Builder.CreateICmp(
       CmpInst::Predicate::ICMP_EQ, F->getArg(0),
@@ -5758,15 +5769,16 @@ TEST_F(OpenMPIRBuilderTest, CreateTaskIfCondition) {
           ->user_back());
   ASSERT_NE(TaskBeginIfCall, nullptr);
   ASSERT_NE(TaskCompleteCall, nullptr);
-  Function *WrapperFunc =
+  Function *OulinedFn =
       dyn_cast<Function>(TaskAllocCall->getArgOperand(5)->stripPointerCasts());
-  ASSERT_NE(WrapperFunc, nullptr);
-  CallInst *WrapperFuncCall = dyn_cast<CallInst>(WrapperFunc->user_back());
-  ASSERT_NE(WrapperFuncCall, nullptr);
+  ASSERT_NE(OulinedFn, nullptr);
+  CallInst *OulinedFnCall = dyn_cast<CallInst>(OulinedFn->user_back());
+  ASSERT_NE(OulinedFnCall, nullptr);
   EXPECT_EQ(TaskBeginIfCall->getParent(),
             IfConditionBranchInst->getSuccessor(1));
-  EXPECT_EQ(TaskBeginIfCall->getNextNonDebugInstruction(), WrapperFuncCall);
-  EXPECT_EQ(WrapperFuncCall->getNextNonDebugInstruction(), TaskCompleteCall);
+
+  EXPECT_EQ(TaskBeginIfCall->getNextNonDebugInstruction(), OulinedFnCall);
+  EXPECT_EQ(OulinedFnCall->getNextNonDebugInstruction(), TaskCompleteCall);
 }
 
 TEST_F(OpenMPIRBuilderTest, CreateTaskgroup) {
diff --git a/mlir/test/Target/LLVMIR/openmp-llvm.mlir b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
index 28b0113a19d61b8..2cd561cb021075f 100644
--- a/mlir/test/Target/LLVMIR/openmp-llvm.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
@@ -2209,7 +2209,7 @@ llvm.func @omp_task(%x: i32, %y: i32, %zaddr: !llvm.ptr<i32>) {
   // CHECK: %[[omp_global_thread_num:.+]] = call i32 @__kmpc_global_thread_num({{.+}})
   // CHECK: %[[task_data:.+]] = call ptr @__kmpc_omp_task_alloc
   // CHECK-SAME: (ptr @{{.+}}, i32 %[[omp_global_thread_num]], i32 1, i64 40,
-  // CHECK-SAME:  i64 0, ptr @[[wrapper_fn:.+]])
+  // CHECK-SAME:  i64 0, ptr @[[outlined_fn:.+]])
   // CHECK: call i32 @__kmpc_omp_task(ptr @{{.+}}, i32 %[[omp_global_thread_num]], ptr %[[task_data]])
   omp.task {
     %n = llvm.mlir.constant(1 : i64) : i64
@@ -2222,7 +2222,7 @@ llvm.func @omp_task(%x: i32, %y: i32, %zaddr: !llvm.ptr<i32>) {
   llvm.return
 }
 
-// CHECK: define internal void @[[outlined_fn:.+]]()
+// CHECK: define internal void @[[outlined_fn]](i32 %[[global_tid:[^ ,]+]])
 // CHECK: task.alloca{{.*}}:
 // CHECK:   br label %[[task_body:[^, ]+]]
 // CHECK: [[task_body]]:
@@ -2236,12 +2236,6 @@ llvm.func @omp_task(%x: i32, %y: i32, %zaddr: !llvm.ptr<i32>) {
 // CHECK: [[exit_stub]]:
 // CHECK:   ret void
 
-
-// CHECK: define i32 @[[wrapper_fn]](i32 %{{.+}}) {
-// CHECK:   call void @[[outlined_fn]]()
-// CHECK:   ret i32 0
-// CHECK: }
-
 // -----
 
 // CHECK-LABEL: define void @omp_task_with_deps
@@ -2259,7 +2253,7 @@ llvm.func @omp_task_with_deps(%zaddr: !llvm.ptr<i32>) {
   // CHECK: %[[omp_global_thread_num:.+]] = call i32 @__kmpc_global_thread_num({{.+}})
   // CHECK: %[[task_data:.+]] = call ptr @__kmpc_omp_task_alloc
   // CHECK-SAME: (ptr @{{.+}}, i32 %[[omp_global_thread_num]], i32 1, i64 40,
-  // CHECK-SAME:  i64 0, ptr @[[wrapper_fn:.+]])
+  // CHECK-SAME:  i64 0, ptr @[[outlined_fn:.+]])
   // CHECK: call i32 @__kmpc_omp_task_with_deps(ptr @{{.+}}, i32 %[[omp_global_thread_num]], ptr %[[task_data]], {{.*}})
   omp.task depend(taskdependin -> %zaddr : !llvm.ptr<i32>) {
     %n = llvm.mlir.constant(1 : i64) : i64
@@ -2272,7 +2266,7 @@ llvm.func @omp_task_with_deps(%zaddr: !llvm.ptr<i32>) {
   llvm.return
 }
 
-// CHECK: define internal void @[[outlined_fn:.+]]()
+// CHECK: define internal void @[[outlined_fn]](i32 %[[global_tid:[^ ,]+]])
 // CHECK: task.alloca{{.*}}:
 // CHECK:   br label %[[task_body:[^, ]+]]
 // CHECK: [[task_body]]:
@@ -2286,11 +2280,6 @@ llvm.func @omp_task_with_deps(%zaddr: !llvm.ptr<i32>) {
 // CHECK: [[exit_stub]]:
 // CHECK:   ret void
 
-// CHECK: define i32 @[[wrapper_fn]](i32 %{{.+}}) {
-// CHECK:   call void @[[outlined_fn]]()
-// CHECK:   ret i32 0
-// CHECK: }
-
 // -----
 
 // CHECK-LABEL: define void @omp_task
@@ -2304,7 +2293,7 @@ module attributes {llvm.target_triple = "x86_64-unknown-linux-gnu"} {
     // CHECK: %[[omp_global_thread_num:.+]] = call i32 @__kmpc_global_thread_num({{.+}})
     // CHECK: %[[task_data:.+]] = call ptr @__kmpc_omp_task_alloc
     // CHECK-SAME: (ptr @{{.+}}, i32 %[[omp_global_thread_num]], i32 1, i64 40, i64 16,
-    // CHECK-SAME: ptr @[[wrapper_fn:.+]])
+    // CHECK-SAME: ptr @[[outlined_fn:.+]])
     // CHECK: %[[shareds:.+]] = load ptr, ptr %[[task_data]]
     // CHECK: call void @llvm.memcpy.p0.p0.i64(ptr {{.+}} %[[shareds]], ptr {{.+}}, i64 16, i1 false)
     // CHECK: call i32 @__kmpc_omp_task(ptr @{{.+}}, i32 %[[omp_global_thread_num]], ptr %[[task_data]])
@@ -2321,8 +2310,9 @@ module attributes {llvm.target_triple = "x86_64-unknown-linux-gnu"} {
   }
 }
 
-// CHECK: define internal void @[[outlined_fn:.+]](ptr %[[task_data:.+]])
+// CHECK: define internal void @[[outlined_fn]](i32 %[[global_tid:[^ ,]+]], ptr %[[task_data:.+]])
 // CHECK: task.alloca{{.*}}:
+// CHECK:   %[[shareds:.+]] = load ptr, ptr %[[task_data]]
 // CHECK:   br label %[[task_body:[^, ]+]]
 // CHECK: [[task_body]]:
 // CHECK:   br label %[[task_region:[^, ]+]]
@@ -2333,13 +2323,6 @@ module attributes {llvm.target_triple = "x86_64-unknown-linux-gnu"} {
 // CHECK: [[exit_stub]]:
 // CHECK:   ret void
 
-
-// CHECK: define i32 @[[wrapper_fn]](i32 %{{.+}}, ptr %[[task_data:.+]]) {
-// CHECK:  %[[shareds:.+]] = load ptr, ptr %1, align 8
-// CHECK:   call void @[[outlined_fn]](ptr %[[shareds]])
-// CHECK:   ret i32 0
-// CHECK: }
-
 // -----
 
 llvm.func @par_task_(%arg0: !llvm.ptr<i32> {fir.bindc_name = "a"}) {
@@ -2355,14 +2338,12 @@ llvm.func @par_task_(%arg0: !llvm.ptr<i32> {fir.bindc_name = "a"}) {
 }
 
 // CHECK-LABEL: @par_task_
-// CHECK: %[[TASK_ALLOC:.*]] = call ptr @__kmpc_omp_task_alloc({{.*}}ptr @par_task_..omp_par.wrapper)
+// CHECK: %[[TASK_ALLOC:.*]] = call ptr @__kmpc_omp_task_alloc({{.*}}ptr @[[task_outlined_fn:.+]])
 // CHECK: call i32 @__kmpc_omp_task({{.*}}, ptr %[[TASK_ALLOC]])
-// CHECK-LABEL: define internal void @par_task_..omp_par
+// CHECK: define internal void @[[task_outlined_fn]]
 // CHECK: %[[ARG_ALLOC:.*]] = alloca { ptr }, align 8
-// CHECK: call void ({{.*}}) @__kmpc_fork_call({{.*}}, ptr @par_task_..omp_par..omp_par, ptr %[[ARG_ALLOC]])
-// CHECK: define internal void @par_task_..omp_par..omp_par
-// CHECK: define i32 @par_task_..omp_par.wrapper
-// CHECK: call void @par_task_..omp_par
+// CHECK: call void ({{.*}}) @__kmpc_fork_call({{.*}}, ptr @[[parallel_outlined_fn:.+]], ptr %[[ARG_ALLOC]])
+// CHECK: define internal void @[[parallel_outlined_fn]]
 // -----
 
 llvm.func @foo() -> ()
@@ -2432,7 +2413,7 @@ llvm.func @omp_taskgroup_task(%x: i32, %y: i32, %zaddr: !llvm.ptr<i32>) {
 // CHECK:         br label %[[codeRepl:[^,]+]]
 // CHECK:       [[codeRepl]]:
 // CHECK:         %[[omp_global_thread_num_t1:.+]] = call i32 @__kmpc_global_thread_num(ptr @{{.+}})
-// CHECK:         %[[t1_alloc:.+]] = call ptr @__kmpc_omp_task_alloc(ptr @{{.+}}, i32 %[[omp_global_thread_num_t1]], i32 1, i64 40, i64 0, ptr @omp_taskgroup_task..omp_par.wrapper)
+// CHECK:         %[[t1_alloc:.+]] = call ptr @__kmpc_omp_task_alloc(ptr @{{.+}}, i32 %[[omp_global_thread_num_t1]], i32 1, i64 40, i64 0, ptr @[[outlined_task_fn:.+]])
 // CHECK:         %{{.+}} = call i32 @__kmpc_omp_task(ptr @{{.+}}, i32 %[[omp_global_thread_num_t1]], ptr %[[t1_alloc]])
 // CHECK:         br label %[[task_exit:[^,]+]]
 // CHECK:       [[task_exit]]:
@@ -2445,7 +2426,7 @@ llvm.func @omp_taskgroup_task(%x: i32, %y: i32, %zaddr: !llvm.ptr<i32>) {
 // CHECK:         %[[gep3:.+]] = getelementptr { i32, i32, ptr }, ptr %[[structArg]], i32 0, i32 2
 // CHECK:         store ptr %[[zaddr]], ptr %[[gep3]], align 8
 // CHECK:         %[[omp_global_thread_num_t2:.+]] = call i32 @__kmpc_global_thread_num(ptr @{{.+}})
-// CHECK:         %[[t2_alloc:.+]] = call ptr @__kmpc_omp_task_alloc(ptr @{{.+}}, i32 %[[omp_global_thread_num_t2]], i32 1, i64 40, i64 16, ptr @omp_taskgroup_task..omp_par.1.wrapper)
+// CHECK:         %[[t2_alloc:.+]] = call ptr @__kmpc_omp_task_alloc(ptr @{{.+}}, i32 %[[omp_global_thread_num_t2]], i32 1, i64 40, i64 16, ptr @[[outlined_task_fn:.+]])
 // CHECK:         %[[shareds:.+]] = load ptr, ptr %[[t2_alloc]]
 // CHECK:         call void @llvm.memcpy.p0.p0.i64(ptr align 1 %[[shareds]], ptr align 1 %[[structArg]], i64 16, i1 false)
 // CHECK:         %{{.+}} = call i32 @__kmpc_omp_task(ptr @{{.+}}, i32 %[[omp_global_thread_num_t2]], ptr %[[t2_alloc]])
@@ -2617,7 +2598,7 @@ llvm.func @omp_task_final(%boolexpr: i1) {
 // CHECK:         %[[omp_global_thread_num:.+]] = call i32 @__kmpc_global_thread_num(ptr @{{.+}})
 // CHECK:         %[[final_flag:.+]] = select i1 %[[boolexpr]], i32 2, i32 0
 // CHECK:         %[[task_flags:.+]] = or i32 %[[final_flag]], 1
-// CHECK:         %[[task_data:.+]] = call ptr @__kmpc_omp_task_alloc(ptr @{{.+}}, i32 %[[omp_global_thread_num]], i32 %[[task_flags]], i64 40, i64 0, ptr @omp_task_final..omp_par.wrapper)
+// CHECK:         %[[task_data:.+]] = call ptr @__kmpc_omp_task_alloc(ptr @{{.+}}, i32 %[[omp_global_thread_num]], i32 %[[task_flags]], i64 40, i64 0, ptr @[[task_outlined_fn:.+]])
 // CHECK:         %{{.+}} = call i32 @__kmpc_omp_task(ptr @{{.+}}, i32 %[[omp_global_thread_num]], ptr %[[task_data]])
 // CHECK:         br label %[[task_exit:[^,]+]]
 // CHECK:       [[task_exit]]:
@@ -2648,14 +2629,14 @@ llvm.func @omp_task_if(%boolexpr: i1) {
 // CHECK:         br label %[[codeRepl:[^,]+]]
 // CHECK:       [[codeRepl]]:
 // CHECK:         %[[omp_global_thread_num:.+]] = call i32 @__kmpc_global_thread_num(ptr @{{.+}})
-// CHECK:         %[[task_data:.+]] = call ptr @__kmpc_omp_task_alloc(ptr @{{.+}}, i32 %[[omp_global_thread_num]], i32 1, i64 40, i64 0, ptr @omp_task_if..omp_par.wrapper)
+// CHECK:         %[[task_data:.+]] = call ptr @__kmpc_omp_task_alloc(ptr @{{.+}}, i32 %[[omp_global_thread_num]], i32 1, i64 40, i64 0, ptr @[[task_outlined_fn:.+]])
 // CHECK:         br i1 %[[boolexpr]], label %[[true_label:[^,]+]], label %[[false_label:[^,]+]]
 // CHECK:       [[true_label]]:
 // CHECK:         %{{.+}} = call i32 @__kmpc_omp_task(ptr @{{.+}}, i32 %[[omp_global_thread_num]], ptr %[[task_data]])
 // CHECK:         br label %[[if_else_exit:[^,]+]]
 // CHECK:       [[false_label:[^,]+]]:                                                ; preds = %codeRepl
 // CHECK:         call void @__kmpc_omp_task_begin_if0(ptr @{{.+}}, i32 %[[omp_global_thread_num]], ptr %[[task_data]])
-// CHECK:         %{{.+}} = call i32 @omp_task_if..omp_par.wrapper(i32 %[[omp_global_thread_num]])
+// CHECK:         call void @[[task_outlined_fn]](i32 %[[omp_global_thread_num]])
 // CHECK:         call void @__kmpc_omp_task_complete_if0(ptr @{{.+}}, i32 %[[omp_global_thread_num]], ptr %[[task_data]])
 // CHECK:         br label %[[if_else_exit]]
 // CHECK:       [[if_else_exit]]:

>From a1a9438b5e00170030b419a7736053422745cbc6 Mon Sep 17 00:00:00 2001
From: Shraiysh Vaishay <shraiysh.vaishay at amd.com>
Date: Mon, 2 Oct 2023 09:22:30 -0500
Subject: [PATCH 2/2] Remove wrapper function for teams too.

---
 llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp     | 178 +++++++++---------
 .../Frontend/OpenMPIRBuilderTest.cpp          |  22 +--
 2 files changed, 95 insertions(+), 105 deletions(-)

diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 54012b488c6b671..a5a73bcc10c48e3 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -341,6 +341,44 @@ BasicBlock *llvm::splitBBWithSuffix(IRBuilderBase &Builder, bool CreateBranch,
   return splitBB(Builder, CreateBranch, Old->getName() + Suffix);
 }
 
+// Creates a placeholder i32 at OuterAllocaIP (the alloca itself when AsPtr,
+// otherwise a load of it) plus a fake use at InnerAllocaIP, so callers can
+// model an extra outlined-function argument. Returns the placeholder.
+Value *createFakeIntVal(IRBuilder<> &Builder,
+                        OpenMPIRBuilder::InsertPointTy OuterAllocaIP,
+                        std::stack<Instruction *> &ToBeDeleted,
+                        OpenMPIRBuilder::InsertPointTy InnerAllocaIP,
+                        const Twine &Name = "", bool AsPtr = true) {
+  Builder.restoreIP(OuterAllocaIP);
+  Instruction *FakeVal;
+  AllocaInst *FakeValAddr =
+      Builder.CreateAlloca(Builder.getInt32Ty(), nullptr, Name + ".addr");
+  ToBeDeleted.push(FakeValAddr);
+
+  if (AsPtr)
+    FakeVal = FakeValAddr;
+  else {
+    FakeVal =
+        Builder.CreateLoad(Builder.getInt32Ty(), FakeValAddr, Name + ".val");
+    ToBeDeleted.push(FakeVal);
+  }
+
+  // The placeholder and its fake use exist only so that outlining produces
+  // the corresponding argument; callers erase them later via ToBeDeleted.
+
+  // Fake use of the placeholder, emitted at InnerAllocaIP.
+  Builder.restoreIP(InnerAllocaIP);
+  Instruction *UseFakeVal;
+  if (AsPtr)
+    UseFakeVal =
+        Builder.CreateLoad(Builder.getInt32Ty(), FakeVal, Name + ".use");
+  else
+    UseFakeVal =
+        cast<BinaryOperator>(Builder.CreateAdd(FakeVal, Builder.getInt32(10)));
+  ToBeDeleted.push(UseFakeVal);
+  return FakeVal;
+}
+
 //===----------------------------------------------------------------------===//
 // OpenMPIRBuilderConfig
 //===----------------------------------------------------------------------===//
@@ -1497,13 +1535,6 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc,
                             InsertPointTy AllocaIP, BodyGenCallbackTy BodyGenCB,
                             bool Tied, Value *Final, Value *IfCondition,
                             SmallVector<DependData> Dependencies) {
-  // We create a temporary i32 value that will represent the global tid after
-  // outlining.
-  SmallVector<Instruction *, 4> ToBeDeleted;
-  Builder.restoreIP(AllocaIP);
-  AllocaInst *TIDAddr = Builder.CreateAlloca(Int32, nullptr, "tid.addr");
-  LoadInst *TID = Builder.CreateLoad(Int32, TIDAddr, "tid.addr.use");
-  ToBeDeleted.append({TID, TIDAddr});
 
   if (!updateToLocation(Loc))
     return InsertPointTy();
@@ -1532,19 +1563,24 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc,
   BasicBlock *TaskAllocaBB =
       splitBB(Builder, /*CreateBranch=*/true, "task.alloca");
 
-  // Fake use of TID
-  Builder.SetInsertPoint(TaskAllocaBB, TaskAllocaBB->begin());
-  BinaryOperator *AddInst =
-      dyn_cast<BinaryOperator>(Builder.CreateAdd(TID, Builder.getInt32(10)));
-  ToBeDeleted.push_back(AddInst);
+  InsertPointTy TaskAllocaIP =
+      InsertPointTy(TaskAllocaBB, TaskAllocaBB->begin());
+  InsertPointTy TaskBodyIP = InsertPointTy(TaskBodyBB, TaskBodyBB->begin());
+  BodyGenCB(TaskAllocaIP, TaskBodyIP);
+  Builder.SetInsertPoint(TaskExitBB, TaskExitBB->begin());
 
   OutlineInfo OI;
   OI.EntryBB = TaskAllocaBB;
   OI.OuterAllocaBB = AllocaIP.getBlock();
   OI.ExitBB = TaskExitBB;
-  OI.ExcludeArgsFromAggregate = {TID};
+
+  // Add the thread ID argument.
+  std::stack<Instruction *> ToBeDeleted;
+  OI.ExcludeArgsFromAggregate.push_back(createFakeIntVal(
+      Builder, AllocaIP, ToBeDeleted, TaskAllocaIP, "global.tid", false));
+
   OI.PostOutlineCB = [this, Ident, Tied, Final, IfCondition, Dependencies,
-                      TaskAllocaBB, ToBeDeleted](Function &OutlinedFn) {
+                      TaskAllocaBB, ToBeDeleted](Function &OutlinedFn) mutable {
     // Replace the Stale CI by appropriate RTL function call.
     assert(OutlinedFn.getNumUses() == 1 &&
            "there must be a single user for the outlined function");
@@ -1670,7 +1706,7 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc,
     //    br label %exit
     //  else:
     //    call @__kmpc_omp_task_begin_if0(...)
-    //    call @wrapper_fn(...)
+    //    call @outlined_fn(...)
     //    call @__kmpc_omp_task_complete_if0(...)
     //    br label %exit
     //  exit:
@@ -1725,31 +1761,14 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc,
           Shareds, [Shareds](Use &U) { return U.getUser() != Shareds; });
     }
 
-    // Replace kmpc_global_thread_num() calls with the global thread id
-    // argument.
-    OutlinedFn.getArg(0)->setName("global.tid");
-    FunctionCallee TIDRTLFn =
-        getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_global_thread_num);
-    for (Instruction &Inst : instructions(OutlinedFn)) {
-      CallInst *CI = dyn_cast<CallInst>(&Inst);
-      if (!CI)
-        continue;
-      if (CI->getCalledFunction() == TIDRTLFn.getCallee())
-        CI->replaceAllUsesWith(OutlinedFn.getArg(0));
+    while (!ToBeDeleted.empty()) {
+      ToBeDeleted.top()->eraseFromParent();
+      ToBeDeleted.pop();
     }
-
-    for (Instruction *I : ToBeDeleted)
-      I->eraseFromParent();
   };
 
   addOutlineInfo(std::move(OI));
 
-  InsertPointTy TaskAllocaIP =
-      InsertPointTy(TaskAllocaBB, TaskAllocaBB->begin());
-  InsertPointTy TaskBodyIP = InsertPointTy(TaskBodyBB, TaskBodyBB->begin());
-  BodyGenCB(TaskAllocaIP, TaskBodyIP);
-  Builder.SetInsertPoint(TaskExitBB, TaskExitBB->begin());
-
   return Builder.saveIP();
 }
 
@@ -5740,6 +5759,7 @@ OpenMPIRBuilder::createTeams(const LocationDescription &Loc,
     BasicBlock *BodyBB = splitBB(Builder, /*CreateBranch=*/true, "teams.entry");
     Builder.SetInsertPoint(BodyBB, BodyBB->begin());
   }
+  InsertPointTy OuterAllocaIP(&OuterAllocaBB, OuterAllocaBB.begin());
 
   // The current basic block is split into four basic blocks. After outlining,
   // they will be mapped as follows:
@@ -5763,84 +5783,62 @@ OpenMPIRBuilder::createTeams(const LocationDescription &Loc,
   BasicBlock *AllocaBB =
       splitBB(Builder, /*CreateBranch=*/true, "teams.alloca");
 
+  // Generate the body of teams.
+  InsertPointTy AllocaIP(AllocaBB, AllocaBB->begin());
+  InsertPointTy CodeGenIP(BodyBB, BodyBB->begin());
+  BodyGenCB(AllocaIP, CodeGenIP);
+
   OutlineInfo OI;
   OI.EntryBB = AllocaBB;
   OI.ExitBB = ExitBB;
   OI.OuterAllocaBB = &OuterAllocaBB;
-  OI.PostOutlineCB = [this, Ident](Function &OutlinedFn) {
-    // The input IR here looks like the following-
-    // ```
-    // func @current_fn() {
-    //   outlined_fn(%args)
-    // }
-    // func @outlined_fn(%args) { ... }
-    // ```
-    //
-    // This is changed to the following-
-    //
-    // ```
-    // func @current_fn() {
-    //   runtime_call(..., wrapper_fn, ...)
-    // }
-    // func @wrapper_fn(..., %args) {
-    //   outlined_fn(%args)
-    // }
-    // func @outlined_fn(%args) { ... }
-    // ```
 
+  // Insert fake values for global tid and bound tid.
+  std::stack<Instruction *> ToBeDeleted;
+  OI.ExcludeArgsFromAggregate.push_back(createFakeIntVal(
+      Builder, OuterAllocaIP, ToBeDeleted, AllocaIP, "gid", true));
+  OI.ExcludeArgsFromAggregate.push_back(createFakeIntVal(
+      Builder, OuterAllocaIP, ToBeDeleted, AllocaIP, "tid", true));
+
+  OI.PostOutlineCB = [this, Ident, ToBeDeleted](Function &OutlinedFn) mutable {
     // The stale call instruction will be replaced with a new call instruction
-    // for runtime call with a wrapper function.
+    // for runtime call with the outlined function.
 
     assert(OutlinedFn.getNumUses() == 1 &&
            "there must be a single user for the outlined function");
     CallInst *StaleCI = cast<CallInst>(OutlinedFn.user_back());
+    ToBeDeleted.push(StaleCI);
+
+    assert((OutlinedFn.arg_size() == 2 || OutlinedFn.arg_size() == 3) &&
+           "Outlined function must have two or three arguments only");
 
-    // Create the wrapper function.
-    SmallVector<Type *> WrapperArgTys{Builder.getPtrTy(), Builder.getPtrTy()};
-    for (auto &Arg : OutlinedFn.args())
-      WrapperArgTys.push_back(Arg.getType());
-    FunctionCallee WrapperFuncVal = M.getOrInsertFunction(
-        (Twine(OutlinedFn.getName()) + ".teams").str(),
-        FunctionType::get(Builder.getVoidTy(), WrapperArgTys, false));
-    Function *WrapperFunc = dyn_cast<Function>(WrapperFuncVal.getCallee());
-    WrapperFunc->getArg(0)->setName("global_tid");
-    WrapperFunc->getArg(1)->setName("bound_tid");
-    if (WrapperFunc->arg_size() > 2)
-      WrapperFunc->getArg(2)->setName("data");
-
-    // Emit the body of the wrapper function - just a call to outlined function
-    // and return statement.
-    BasicBlock *WrapperEntryBB =
-        BasicBlock::Create(M.getContext(), "entrybb", WrapperFunc);
-    Builder.SetInsertPoint(WrapperEntryBB);
-    SmallVector<Value *> Args;
-    for (size_t ArgIndex = 2; ArgIndex < WrapperFunc->arg_size(); ArgIndex++)
-      Args.push_back(WrapperFunc->getArg(ArgIndex));
-    Builder.CreateCall(&OutlinedFn, Args);
-    Builder.CreateRetVoid();
-
-    OutlinedFn.addFnAttr(Attribute::AttrKind::AlwaysInline);
+    bool HasShared = OutlinedFn.arg_size() == 3;
+
+    OutlinedFn.getArg(0)->setName("global.tid.ptr");
+    OutlinedFn.getArg(1)->setName("bound.tid.ptr");
+    if (HasShared)
+      OutlinedFn.getArg(2)->setName("data");
 
     // Call to the runtime function for teams in the current function.
     assert(StaleCI && "Error while outlining - no CallInst user found for the "
                       "outlined function.");
     Builder.SetInsertPoint(StaleCI);
-    Args = {Ident, Builder.getInt32(StaleCI->arg_size()), WrapperFunc};
-    for (Use &Arg : StaleCI->args())
-      Args.push_back(Arg);
+    SmallVector<Value *> Args = {Ident, Builder.getInt32(StaleCI->arg_size()),
+                                 &OutlinedFn};
+    if (HasShared)
+      Args.push_back(StaleCI->getArgOperand(2));
     Builder.CreateCall(getOrCreateRuntimeFunctionPtr(
                            omp::RuntimeFunction::OMPRTL___kmpc_fork_teams),
                        Args);
-    StaleCI->eraseFromParent();
+
+    while (!ToBeDeleted.empty()) {
+      ToBeDeleted.top()->eraseFromParent();
+      ToBeDeleted.pop();
+    }
   };
 
   addOutlineInfo(std::move(OI));
 
-  // Generate the body of teams.
-  InsertPointTy AllocaIP(AllocaBB, AllocaBB->begin());
-  InsertPointTy CodeGenIP(BodyBB, BodyBB->begin());
-  BodyGenCB(AllocaIP, CodeGenIP);
-
   Builder.SetInsertPoint(ExitBB, ExitBB->begin());
 
   return Builder.saveIP();
diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
index 643b34270c01693..c4b0389c89c7c60 100644
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -4057,25 +4057,17 @@ TEST_F(OpenMPIRBuilderTest, CreateTeams) {
   ASSERT_NE(SrcSrc, nullptr);
 
   // Verify the outlined function signature.
-  Function *WrapperFn =
+  Function *OutlinedFn =
       dyn_cast<Function>(TeamsForkCall->getArgOperand(2)->stripPointerCasts());
-  ASSERT_NE(WrapperFn, nullptr);
-  EXPECT_FALSE(WrapperFn->isDeclaration());
-  EXPECT_TRUE(WrapperFn->arg_size() >= 3);
-  EXPECT_EQ(WrapperFn->getArg(0)->getType(), Builder.getPtrTy()); // global_tid
-  EXPECT_EQ(WrapperFn->getArg(1)->getType(), Builder.getPtrTy()); // bound_tid
-  EXPECT_EQ(WrapperFn->getArg(2)->getType(),
+  ASSERT_NE(OutlinedFn, nullptr);
+  EXPECT_FALSE(OutlinedFn->isDeclaration());
+  EXPECT_TRUE(OutlinedFn->arg_size() >= 3);
+  EXPECT_EQ(OutlinedFn->getArg(0)->getType(), Builder.getPtrTy()); // global_tid
+  EXPECT_EQ(OutlinedFn->getArg(1)->getType(), Builder.getPtrTy()); // bound_tid
+  EXPECT_EQ(OutlinedFn->getArg(2)->getType(),
             Builder.getPtrTy()); // captured args
 
   // Check for TruncInst and ICmpInst in the outlined function.
-  inst_range Instructions = instructions(WrapperFn);
-  auto OutlinedFnInst = find_if(
-      Instructions, [](Instruction &Inst) { return isa<CallInst>(&Inst); });
-  ASSERT_NE(OutlinedFnInst, Instructions.end());
-  CallInst *OutlinedFnCI = dyn_cast<CallInst>(&*OutlinedFnInst);
-  ASSERT_NE(OutlinedFnCI, nullptr);
-  Function *OutlinedFn = OutlinedFnCI->getCalledFunction();
-
   EXPECT_TRUE(any_of(instructions(OutlinedFn),
                      [](Instruction &inst) { return isa<TruncInst>(&inst); }));
   EXPECT_TRUE(any_of(instructions(OutlinedFn),



More information about the libc-commits mailing list