[Mlir-commits] [mlir] [OpenMPIRBuilder] Remove wrapper function in `createTask` (PR #67723)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Oct 2 07:23:25 PDT 2023
https://github.com/shraiysh updated https://github.com/llvm/llvm-project/pull/67723
>From 6aabc3c10ea2d587120b74966b7ce96f9b8167af Mon Sep 17 00:00:00 2001
From: Shraiysh Vaishay <shraiysh.vaishay at amd.com>
Date: Thu, 28 Sep 2023 13:35:07 -0500
Subject: [PATCH 1/2] [OpenMPIRBuilder] Remove wrapper function in `createTask`
This patch removes the wrapper function in
`OpenMPIRBuilder::createTask`. The outlined function now directly has the
form expected by the runtime library calls. This also adds the global
thread ID as an argument of the outlined function; it should be used in place
of calls to `__kmpc_global_thread_num()` inside the outlined function.
---
llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp | 106 ++++++++----------
.../Frontend/OpenMPIRBuilderTest.cpp | 56 +++++----
mlir/test/Target/LLVMIR/openmp-llvm.mlir | 51 +++------
3 files changed, 99 insertions(+), 114 deletions(-)
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 9c70d384e55db2b..54012b488c6b671 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -35,6 +35,7 @@
#include "llvm/IR/Function.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/InstIterator.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/Metadata.h"
@@ -1496,6 +1497,14 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc,
InsertPointTy AllocaIP, BodyGenCallbackTy BodyGenCB,
bool Tied, Value *Final, Value *IfCondition,
SmallVector<DependData> Dependencies) {
+ // We create a temporary i32 value that will represent the global tid after
+ // outlining.
+ SmallVector<Instruction *, 4> ToBeDeleted;
+ Builder.restoreIP(AllocaIP);
+ AllocaInst *TIDAddr = Builder.CreateAlloca(Int32, nullptr, "tid.addr");
+ LoadInst *TID = Builder.CreateLoad(Int32, TIDAddr, "tid.addr.use");
+ ToBeDeleted.append({TID, TIDAddr});
+
if (!updateToLocation(Loc))
return InsertPointTy();
@@ -1523,41 +1532,27 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc,
BasicBlock *TaskAllocaBB =
splitBB(Builder, /*CreateBranch=*/true, "task.alloca");
+ // Fake use of TID
+ Builder.SetInsertPoint(TaskAllocaBB, TaskAllocaBB->begin());
+ BinaryOperator *AddInst =
+ dyn_cast<BinaryOperator>(Builder.CreateAdd(TID, Builder.getInt32(10)));
+ ToBeDeleted.push_back(AddInst);
+
OutlineInfo OI;
OI.EntryBB = TaskAllocaBB;
OI.OuterAllocaBB = AllocaIP.getBlock();
OI.ExitBB = TaskExitBB;
- OI.PostOutlineCB = [this, Ident, Tied, Final, IfCondition,
- Dependencies](Function &OutlinedFn) {
- // The input IR here looks like the following-
- // ```
- // func @current_fn() {
- // outlined_fn(%args)
- // }
- // func @outlined_fn(%args) { ... }
- // ```
- //
- // This is changed to the following-
- //
- // ```
- // func @current_fn() {
- // runtime_call(..., wrapper_fn, ...)
- // }
- // func @wrapper_fn(..., %args) {
- // outlined_fn(%args)
- // }
- // func @outlined_fn(%args) { ... }
- // ```
-
- // The stale call instruction will be replaced with a new call instruction
- // for runtime call with a wrapper function.
+ OI.ExcludeArgsFromAggregate = {TID};
+ OI.PostOutlineCB = [this, Ident, Tied, Final, IfCondition, Dependencies,
+ TaskAllocaBB, ToBeDeleted](Function &OutlinedFn) {
+ // Replace the Stale CI by appropriate RTL function call.
assert(OutlinedFn.getNumUses() == 1 &&
"there must be a single user for the outlined function");
CallInst *StaleCI = cast<CallInst>(OutlinedFn.user_back());
// HasShareds is true if any variables are captured in the outlined region,
// false otherwise.
- bool HasShareds = StaleCI->arg_size() > 0;
+ bool HasShareds = StaleCI->arg_size() > 1;
Builder.SetInsertPoint(StaleCI);
// Gather the arguments for emitting the runtime call for
@@ -1595,7 +1590,7 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc,
Value *SharedsSize = Builder.getInt64(0);
if (HasShareds) {
AllocaInst *ArgStructAlloca =
- dyn_cast<AllocaInst>(StaleCI->getArgOperand(0));
+ dyn_cast<AllocaInst>(StaleCI->getArgOperand(1));
assert(ArgStructAlloca &&
"Unable to find the alloca instruction corresponding to arguments "
"for extracted function");
@@ -1606,31 +1601,17 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc,
SharedsSize =
Builder.getInt64(M.getDataLayout().getTypeStoreSize(ArgStructType));
}
-
- // Argument - task_entry (the wrapper function)
- // If the outlined function has some captured variables (i.e. HasShareds is
- // true), then the wrapper function will have an additional argument (the
- // struct containing captured variables). Otherwise, no such argument will
- // be present.
- SmallVector<Type *> WrapperArgTys{Builder.getInt32Ty()};
- if (HasShareds)
- WrapperArgTys.push_back(OutlinedFn.getArg(0)->getType());
- FunctionCallee WrapperFuncVal = M.getOrInsertFunction(
- (Twine(OutlinedFn.getName()) + ".wrapper").str(),
- FunctionType::get(Builder.getInt32Ty(), WrapperArgTys, false));
- Function *WrapperFunc = dyn_cast<Function>(WrapperFuncVal.getCallee());
-
// Emit the @__kmpc_omp_task_alloc runtime call
// The runtime call returns a pointer to an area where the task captured
// variables must be copied before the task is run (TaskData)
CallInst *TaskData = Builder.CreateCall(
TaskAllocFn, {/*loc_ref=*/Ident, /*gtid=*/ThreadID, /*flags=*/Flags,
/*sizeof_task=*/TaskSize, /*sizeof_shared=*/SharedsSize,
- /*task_func=*/WrapperFunc});
+ /*task_func=*/&OutlinedFn});
// Copy the arguments for outlined function
if (HasShareds) {
- Value *Shareds = StaleCI->getArgOperand(0);
+ Value *Shareds = StaleCI->getArgOperand(1);
Align Alignment = TaskData->getPointerAlignment(M.getDataLayout());
Value *TaskShareds = Builder.CreateLoad(VoidPtr, TaskData);
Builder.CreateMemCpy(TaskShareds, Alignment, Shareds, Alignment,
@@ -1697,10 +1678,9 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc,
if (IfCondition) {
// `SplitBlockAndInsertIfThenElse` requires the block to have a
// terminator.
- BasicBlock *NewBasicBlock =
- splitBB(Builder, /*CreateBranch=*/true, "if.end");
+ splitBB(Builder, /*CreateBranch=*/true, "if.end");
Instruction *IfTerminator =
- NewBasicBlock->getSinglePredecessor()->getTerminator();
+ Builder.GetInsertPoint()->getParent()->getTerminator();
Instruction *ThenTI = IfTerminator, *ElseTI = nullptr;
Builder.SetInsertPoint(IfTerminator);
SplitBlockAndInsertIfThenElse(IfCondition, IfTerminator, &ThenTI,
@@ -1711,10 +1691,12 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc,
Function *TaskCompleteFn =
getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_complete_if0);
Builder.CreateCall(TaskBeginFn, {Ident, ThreadID, TaskData});
+ CallInst *CI = nullptr;
if (HasShareds)
- Builder.CreateCall(WrapperFunc, {ThreadID, TaskData});
+ CI = Builder.CreateCall(&OutlinedFn, {ThreadID, TaskData});
else
- Builder.CreateCall(WrapperFunc, {ThreadID});
+ CI = Builder.CreateCall(&OutlinedFn, {ThreadID});
+ CI->setDebugLoc(StaleCI->getDebugLoc());
Builder.CreateCall(TaskCompleteFn, {Ident, ThreadID, TaskData});
Builder.SetInsertPoint(ThenTI);
}
@@ -1736,18 +1718,28 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc,
StaleCI->eraseFromParent();
- // Emit the body for wrapper function
- BasicBlock *WrapperEntryBB =
- BasicBlock::Create(M.getContext(), "", WrapperFunc);
- Builder.SetInsertPoint(WrapperEntryBB);
+ Builder.SetInsertPoint(TaskAllocaBB, TaskAllocaBB->begin());
if (HasShareds) {
- llvm::Value *Shareds =
- Builder.CreateLoad(VoidPtr, WrapperFunc->getArg(1));
- Builder.CreateCall(&OutlinedFn, {Shareds});
- } else {
- Builder.CreateCall(&OutlinedFn);
+ LoadInst *Shareds = Builder.CreateLoad(VoidPtr, OutlinedFn.getArg(1));
+ OutlinedFn.getArg(1)->replaceUsesWithIf(
+ Shareds, [Shareds](Use &U) { return U.getUser() != Shareds; });
+ }
+
+ // Replace kmpc_global_thread_num() calls with the global thread id
+ // argument.
+ OutlinedFn.getArg(0)->setName("global.tid");
+ FunctionCallee TIDRTLFn =
+ getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_global_thread_num);
+ for (Instruction &Inst : instructions(OutlinedFn)) {
+ CallInst *CI = dyn_cast<CallInst>(&Inst);
+ if (!CI)
+ continue;
+ if (CI->getCalledFunction() == TIDRTLFn.getCallee())
+ CI->replaceAllUsesWith(OutlinedFn.getArg(0));
}
- Builder.CreateRet(Builder.getInt32(0));
+
+ for (Instruction *I : ToBeDeleted)
+ I->eraseFromParent();
};
addOutlineInfo(std::move(OI));
diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
index fd524f6067ee0ea..643b34270c01693 100644
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -5486,25 +5486,28 @@ TEST_F(OpenMPIRBuilderTest, CreateTask) {
24); // 64-bit pointer + 128-bit integer
// Verify Wrapper function
- Function *WrapperFunc =
+ Function *OutlinedFn =
dyn_cast<Function>(TaskAllocCall->getArgOperand(5)->stripPointerCasts());
- ASSERT_NE(WrapperFunc, nullptr);
+ ASSERT_NE(OutlinedFn, nullptr);
- LoadInst *SharedsLoad = dyn_cast<LoadInst>(WrapperFunc->begin()->begin());
+ LoadInst *SharedsLoad = dyn_cast<LoadInst>(OutlinedFn->begin()->begin());
ASSERT_NE(SharedsLoad, nullptr);
- EXPECT_EQ(SharedsLoad->getPointerOperand(), WrapperFunc->getArg(1));
-
- EXPECT_FALSE(WrapperFunc->isDeclaration());
- CallInst *OutlinedFnCall =
- dyn_cast<CallInst>(++WrapperFunc->begin()->begin());
- ASSERT_NE(OutlinedFnCall, nullptr);
- EXPECT_EQ(WrapperFunc->getArg(0)->getType(), Builder.getInt32Ty());
- EXPECT_EQ(OutlinedFnCall->getArgOperand(0),
- WrapperFunc->getArg(1)->uses().begin()->getUser());
+ EXPECT_EQ(SharedsLoad->getPointerOperand(), OutlinedFn->getArg(1));
+
+ EXPECT_FALSE(OutlinedFn->isDeclaration());
+ EXPECT_EQ(OutlinedFn->getArg(0)->getType(), Builder.getInt32Ty());
+
+ // Verify that the data argument is used only once, and only by the load
+ // instruction whose result is then used for accessing shared data.
+ Value *DataPtr = OutlinedFn->getArg(1);
+ EXPECT_EQ(DataPtr->getNumUses(), 1);
+ EXPECT_TRUE(isa<LoadInst>(DataPtr->uses().begin()->getUser()));
+ Value *Data = DataPtr->uses().begin()->getUser();
+ EXPECT_TRUE(all_of(Data->uses(), [](Use &U) {
+ return isa<GetElementPtrInst>(U.getUser());
+ }));
// Verify the presence of `trunc` and `icmp` instructions in Outlined function
- Function *OutlinedFn = OutlinedFnCall->getCalledFunction();
- ASSERT_NE(OutlinedFn, nullptr);
EXPECT_TRUE(any_of(instructions(OutlinedFn),
[](Instruction &inst) { return isa<TruncInst>(&inst); }));
EXPECT_TRUE(any_of(instructions(OutlinedFn),
@@ -5547,6 +5550,14 @@ TEST_F(OpenMPIRBuilderTest, CreateTaskNoArgs) {
Builder.CreateRetVoid();
EXPECT_FALSE(verifyModule(*M, &errs()));
+
+ // Check that the outlined function has only one argument.
+ CallInst *TaskAllocCall = dyn_cast<CallInst>(
+ OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_alloc)
+ ->user_back());
+ Function *OutlinedFn = dyn_cast<Function>(TaskAllocCall->getArgOperand(5));
+ ASSERT_NE(OutlinedFn, nullptr);
+ ASSERT_EQ(OutlinedFn->arg_size(), 1);
}
TEST_F(OpenMPIRBuilderTest, CreateTaskUntied) {
@@ -5658,8 +5669,8 @@ TEST_F(OpenMPIRBuilderTest, CreateTaskFinal) {
F->setName("func");
IRBuilder<> Builder(BB);
auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {};
- IRBuilderBase::InsertPoint AllocaIP = Builder.saveIP();
BasicBlock *BodyBB = splitBB(Builder, /*CreateBranch=*/true, "alloca.split");
+ IRBuilderBase::InsertPoint AllocaIP = Builder.saveIP();
Builder.SetInsertPoint(BodyBB);
Value *Final = Builder.CreateICmp(
CmpInst::Predicate::ICMP_EQ, F->getArg(0),
@@ -5711,8 +5722,8 @@ TEST_F(OpenMPIRBuilderTest, CreateTaskIfCondition) {
F->setName("func");
IRBuilder<> Builder(BB);
auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {};
- IRBuilderBase::InsertPoint AllocaIP = Builder.saveIP();
BasicBlock *BodyBB = splitBB(Builder, /*CreateBranch=*/true, "alloca.split");
+ IRBuilderBase::InsertPoint AllocaIP = Builder.saveIP();
Builder.SetInsertPoint(BodyBB);
Value *IfCondition = Builder.CreateICmp(
CmpInst::Predicate::ICMP_EQ, F->getArg(0),
@@ -5758,15 +5769,16 @@ TEST_F(OpenMPIRBuilderTest, CreateTaskIfCondition) {
->user_back());
ASSERT_NE(TaskBeginIfCall, nullptr);
ASSERT_NE(TaskCompleteCall, nullptr);
- Function *WrapperFunc =
+ Function *OulinedFn =
dyn_cast<Function>(TaskAllocCall->getArgOperand(5)->stripPointerCasts());
- ASSERT_NE(WrapperFunc, nullptr);
- CallInst *WrapperFuncCall = dyn_cast<CallInst>(WrapperFunc->user_back());
- ASSERT_NE(WrapperFuncCall, nullptr);
+ ASSERT_NE(OulinedFn, nullptr);
+ CallInst *OulinedFnCall = dyn_cast<CallInst>(OulinedFn->user_back());
+ ASSERT_NE(OulinedFnCall, nullptr);
EXPECT_EQ(TaskBeginIfCall->getParent(),
IfConditionBranchInst->getSuccessor(1));
- EXPECT_EQ(TaskBeginIfCall->getNextNonDebugInstruction(), WrapperFuncCall);
- EXPECT_EQ(WrapperFuncCall->getNextNonDebugInstruction(), TaskCompleteCall);
+
+ EXPECT_EQ(TaskBeginIfCall->getNextNonDebugInstruction(), OulinedFnCall);
+ EXPECT_EQ(OulinedFnCall->getNextNonDebugInstruction(), TaskCompleteCall);
}
TEST_F(OpenMPIRBuilderTest, CreateTaskgroup) {
diff --git a/mlir/test/Target/LLVMIR/openmp-llvm.mlir b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
index 28b0113a19d61b8..2cd561cb021075f 100644
--- a/mlir/test/Target/LLVMIR/openmp-llvm.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
@@ -2209,7 +2209,7 @@ llvm.func @omp_task(%x: i32, %y: i32, %zaddr: !llvm.ptr<i32>) {
// CHECK: %[[omp_global_thread_num:.+]] = call i32 @__kmpc_global_thread_num({{.+}})
// CHECK: %[[task_data:.+]] = call ptr @__kmpc_omp_task_alloc
// CHECK-SAME: (ptr @{{.+}}, i32 %[[omp_global_thread_num]], i32 1, i64 40,
- // CHECK-SAME: i64 0, ptr @[[wrapper_fn:.+]])
+ // CHECK-SAME: i64 0, ptr @[[outlined_fn:.+]])
// CHECK: call i32 @__kmpc_omp_task(ptr @{{.+}}, i32 %[[omp_global_thread_num]], ptr %[[task_data]])
omp.task {
%n = llvm.mlir.constant(1 : i64) : i64
@@ -2222,7 +2222,7 @@ llvm.func @omp_task(%x: i32, %y: i32, %zaddr: !llvm.ptr<i32>) {
llvm.return
}
-// CHECK: define internal void @[[outlined_fn:.+]]()
+// CHECK: define internal void @[[outlined_fn]](i32 %[[global_tid:[^ ,]+]])
// CHECK: task.alloca{{.*}}:
// CHECK: br label %[[task_body:[^, ]+]]
// CHECK: [[task_body]]:
@@ -2236,12 +2236,6 @@ llvm.func @omp_task(%x: i32, %y: i32, %zaddr: !llvm.ptr<i32>) {
// CHECK: [[exit_stub]]:
// CHECK: ret void
-
-// CHECK: define i32 @[[wrapper_fn]](i32 %{{.+}}) {
-// CHECK: call void @[[outlined_fn]]()
-// CHECK: ret i32 0
-// CHECK: }
-
// -----
// CHECK-LABEL: define void @omp_task_with_deps
@@ -2259,7 +2253,7 @@ llvm.func @omp_task_with_deps(%zaddr: !llvm.ptr<i32>) {
// CHECK: %[[omp_global_thread_num:.+]] = call i32 @__kmpc_global_thread_num({{.+}})
// CHECK: %[[task_data:.+]] = call ptr @__kmpc_omp_task_alloc
// CHECK-SAME: (ptr @{{.+}}, i32 %[[omp_global_thread_num]], i32 1, i64 40,
- // CHECK-SAME: i64 0, ptr @[[wrapper_fn:.+]])
+ // CHECK-SAME: i64 0, ptr @[[outlined_fn:.+]])
// CHECK: call i32 @__kmpc_omp_task_with_deps(ptr @{{.+}}, i32 %[[omp_global_thread_num]], ptr %[[task_data]], {{.*}})
omp.task depend(taskdependin -> %zaddr : !llvm.ptr<i32>) {
%n = llvm.mlir.constant(1 : i64) : i64
@@ -2272,7 +2266,7 @@ llvm.func @omp_task_with_deps(%zaddr: !llvm.ptr<i32>) {
llvm.return
}
-// CHECK: define internal void @[[outlined_fn:.+]]()
+// CHECK: define internal void @[[outlined_fn]](i32 %[[global_tid:[^ ,]+]])
// CHECK: task.alloca{{.*}}:
// CHECK: br label %[[task_body:[^, ]+]]
// CHECK: [[task_body]]:
@@ -2286,11 +2280,6 @@ llvm.func @omp_task_with_deps(%zaddr: !llvm.ptr<i32>) {
// CHECK: [[exit_stub]]:
// CHECK: ret void
-// CHECK: define i32 @[[wrapper_fn]](i32 %{{.+}}) {
-// CHECK: call void @[[outlined_fn]]()
-// CHECK: ret i32 0
-// CHECK: }
-
// -----
// CHECK-LABEL: define void @omp_task
@@ -2304,7 +2293,7 @@ module attributes {llvm.target_triple = "x86_64-unknown-linux-gnu"} {
// CHECK: %[[omp_global_thread_num:.+]] = call i32 @__kmpc_global_thread_num({{.+}})
// CHECK: %[[task_data:.+]] = call ptr @__kmpc_omp_task_alloc
// CHECK-SAME: (ptr @{{.+}}, i32 %[[omp_global_thread_num]], i32 1, i64 40, i64 16,
- // CHECK-SAME: ptr @[[wrapper_fn:.+]])
+ // CHECK-SAME: ptr @[[outlined_fn:.+]])
// CHECK: %[[shareds:.+]] = load ptr, ptr %[[task_data]]
// CHECK: call void @llvm.memcpy.p0.p0.i64(ptr {{.+}} %[[shareds]], ptr {{.+}}, i64 16, i1 false)
// CHECK: call i32 @__kmpc_omp_task(ptr @{{.+}}, i32 %[[omp_global_thread_num]], ptr %[[task_data]])
@@ -2321,8 +2310,9 @@ module attributes {llvm.target_triple = "x86_64-unknown-linux-gnu"} {
}
}
-// CHECK: define internal void @[[outlined_fn:.+]](ptr %[[task_data:.+]])
+// CHECK: define internal void @[[outlined_fn]](i32 %[[global_tid:[^ ,]+]], ptr %[[task_data:.+]])
// CHECK: task.alloca{{.*}}:
+// CHECK: %[[shareds:.+]] = load ptr, ptr %[[task_data]]
// CHECK: br label %[[task_body:[^, ]+]]
// CHECK: [[task_body]]:
// CHECK: br label %[[task_region:[^, ]+]]
@@ -2333,13 +2323,6 @@ module attributes {llvm.target_triple = "x86_64-unknown-linux-gnu"} {
// CHECK: [[exit_stub]]:
// CHECK: ret void
-
-// CHECK: define i32 @[[wrapper_fn]](i32 %{{.+}}, ptr %[[task_data:.+]]) {
-// CHECK: %[[shareds:.+]] = load ptr, ptr %1, align 8
-// CHECK: call void @[[outlined_fn]](ptr %[[shareds]])
-// CHECK: ret i32 0
-// CHECK: }
-
// -----
llvm.func @par_task_(%arg0: !llvm.ptr<i32> {fir.bindc_name = "a"}) {
@@ -2355,14 +2338,12 @@ llvm.func @par_task_(%arg0: !llvm.ptr<i32> {fir.bindc_name = "a"}) {
}
// CHECK-LABEL: @par_task_
-// CHECK: %[[TASK_ALLOC:.*]] = call ptr @__kmpc_omp_task_alloc({{.*}}ptr @par_task_..omp_par.wrapper)
+// CHECK: %[[TASK_ALLOC:.*]] = call ptr @__kmpc_omp_task_alloc({{.*}}ptr @[[task_outlined_fn:.+]])
// CHECK: call i32 @__kmpc_omp_task({{.*}}, ptr %[[TASK_ALLOC]])
-// CHECK-LABEL: define internal void @par_task_..omp_par
+// CHECK: define internal void @[[task_outlined_fn]]
// CHECK: %[[ARG_ALLOC:.*]] = alloca { ptr }, align 8
-// CHECK: call void ({{.*}}) @__kmpc_fork_call({{.*}}, ptr @par_task_..omp_par..omp_par, ptr %[[ARG_ALLOC]])
-// CHECK: define internal void @par_task_..omp_par..omp_par
-// CHECK: define i32 @par_task_..omp_par.wrapper
-// CHECK: call void @par_task_..omp_par
+// CHECK: call void ({{.*}}) @__kmpc_fork_call({{.*}}, ptr @[[parallel_outlined_fn:.+]], ptr %[[ARG_ALLOC]])
+// CHECK: define internal void @[[parallel_outlined_fn]]
// -----
llvm.func @foo() -> ()
@@ -2432,7 +2413,7 @@ llvm.func @omp_taskgroup_task(%x: i32, %y: i32, %zaddr: !llvm.ptr<i32>) {
// CHECK: br label %[[codeRepl:[^,]+]]
// CHECK: [[codeRepl]]:
// CHECK: %[[omp_global_thread_num_t1:.+]] = call i32 @__kmpc_global_thread_num(ptr @{{.+}})
-// CHECK: %[[t1_alloc:.+]] = call ptr @__kmpc_omp_task_alloc(ptr @{{.+}}, i32 %[[omp_global_thread_num_t1]], i32 1, i64 40, i64 0, ptr @omp_taskgroup_task..omp_par.wrapper)
+// CHECK: %[[t1_alloc:.+]] = call ptr @__kmpc_omp_task_alloc(ptr @{{.+}}, i32 %[[omp_global_thread_num_t1]], i32 1, i64 40, i64 0, ptr @[[outlined_task_fn:.+]])
// CHECK: %{{.+}} = call i32 @__kmpc_omp_task(ptr @{{.+}}, i32 %[[omp_global_thread_num_t1]], ptr %[[t1_alloc]])
// CHECK: br label %[[task_exit:[^,]+]]
// CHECK: [[task_exit]]:
@@ -2445,7 +2426,7 @@ llvm.func @omp_taskgroup_task(%x: i32, %y: i32, %zaddr: !llvm.ptr<i32>) {
// CHECK: %[[gep3:.+]] = getelementptr { i32, i32, ptr }, ptr %[[structArg]], i32 0, i32 2
// CHECK: store ptr %[[zaddr]], ptr %[[gep3]], align 8
// CHECK: %[[omp_global_thread_num_t2:.+]] = call i32 @__kmpc_global_thread_num(ptr @{{.+}})
-// CHECK: %[[t2_alloc:.+]] = call ptr @__kmpc_omp_task_alloc(ptr @{{.+}}, i32 %[[omp_global_thread_num_t2]], i32 1, i64 40, i64 16, ptr @omp_taskgroup_task..omp_par.1.wrapper)
+// CHECK: %[[t2_alloc:.+]] = call ptr @__kmpc_omp_task_alloc(ptr @{{.+}}, i32 %[[omp_global_thread_num_t2]], i32 1, i64 40, i64 16, ptr @[[outlined_task_fn:.+]])
// CHECK: %[[shareds:.+]] = load ptr, ptr %[[t2_alloc]]
// CHECK: call void @llvm.memcpy.p0.p0.i64(ptr align 1 %[[shareds]], ptr align 1 %[[structArg]], i64 16, i1 false)
// CHECK: %{{.+}} = call i32 @__kmpc_omp_task(ptr @{{.+}}, i32 %[[omp_global_thread_num_t2]], ptr %[[t2_alloc]])
@@ -2617,7 +2598,7 @@ llvm.func @omp_task_final(%boolexpr: i1) {
// CHECK: %[[omp_global_thread_num:.+]] = call i32 @__kmpc_global_thread_num(ptr @{{.+}})
// CHECK: %[[final_flag:.+]] = select i1 %[[boolexpr]], i32 2, i32 0
// CHECK: %[[task_flags:.+]] = or i32 %[[final_flag]], 1
-// CHECK: %[[task_data:.+]] = call ptr @__kmpc_omp_task_alloc(ptr @{{.+}}, i32 %[[omp_global_thread_num]], i32 %[[task_flags]], i64 40, i64 0, ptr @omp_task_final..omp_par.wrapper)
+// CHECK: %[[task_data:.+]] = call ptr @__kmpc_omp_task_alloc(ptr @{{.+}}, i32 %[[omp_global_thread_num]], i32 %[[task_flags]], i64 40, i64 0, ptr @[[task_outlined_fn:.+]])
// CHECK: %{{.+}} = call i32 @__kmpc_omp_task(ptr @{{.+}}, i32 %[[omp_global_thread_num]], ptr %[[task_data]])
// CHECK: br label %[[task_exit:[^,]+]]
// CHECK: [[task_exit]]:
@@ -2648,14 +2629,14 @@ llvm.func @omp_task_if(%boolexpr: i1) {
// CHECK: br label %[[codeRepl:[^,]+]]
// CHECK: [[codeRepl]]:
// CHECK: %[[omp_global_thread_num:.+]] = call i32 @__kmpc_global_thread_num(ptr @{{.+}})
-// CHECK: %[[task_data:.+]] = call ptr @__kmpc_omp_task_alloc(ptr @{{.+}}, i32 %[[omp_global_thread_num]], i32 1, i64 40, i64 0, ptr @omp_task_if..omp_par.wrapper)
+// CHECK: %[[task_data:.+]] = call ptr @__kmpc_omp_task_alloc(ptr @{{.+}}, i32 %[[omp_global_thread_num]], i32 1, i64 40, i64 0, ptr @[[task_outlined_fn:.+]])
// CHECK: br i1 %[[boolexpr]], label %[[true_label:[^,]+]], label %[[false_label:[^,]+]]
// CHECK: [[true_label]]:
// CHECK: %{{.+}} = call i32 @__kmpc_omp_task(ptr @{{.+}}, i32 %[[omp_global_thread_num]], ptr %[[task_data]])
// CHECK: br label %[[if_else_exit:[^,]+]]
// CHECK: [[false_label:[^,]+]]: ; preds = %codeRepl
// CHECK: call void @__kmpc_omp_task_begin_if0(ptr @{{.+}}, i32 %[[omp_global_thread_num]], ptr %[[task_data]])
-// CHECK: %{{.+}} = call i32 @omp_task_if..omp_par.wrapper(i32 %[[omp_global_thread_num]])
+// CHECK: call void @[[task_outlined_fn]](i32 %[[omp_global_thread_num]])
// CHECK: call void @__kmpc_omp_task_complete_if0(ptr @{{.+}}, i32 %[[omp_global_thread_num]], ptr %[[task_data]])
// CHECK: br label %[[if_else_exit]]
// CHECK: [[if_else_exit]]:
>From a1a9438b5e00170030b419a7736053422745cbc6 Mon Sep 17 00:00:00 2001
From: Shraiysh Vaishay <shraiysh.vaishay at amd.com>
Date: Mon, 2 Oct 2023 09:22:30 -0500
Subject: [PATCH 2/2] Remove wrapper function for teams too.
---
llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp | 178 +++++++++---------
.../Frontend/OpenMPIRBuilderTest.cpp | 22 +--
2 files changed, 95 insertions(+), 105 deletions(-)
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 54012b488c6b671..a5a73bcc10c48e3 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -341,6 +341,44 @@ BasicBlock *llvm::splitBBWithSuffix(IRBuilderBase &Builder, bool CreateBranch,
return splitBB(Builder, CreateBranch, Old->getName() + Suffix);
}
+// Creates a fake i32 at OuterAllocaIP (the alloca itself if AsPtr, otherwise a
+// load of it) plus a fake use at InnerAllocaIP, pushing every new instruction
+// onto ToBeDeleted. Returns the fake value, modeling an extra outlined-fn arg.
+Value *createFakeIntVal(IRBuilder<> &Builder,
+ OpenMPIRBuilder::InsertPointTy OuterAllocaIP,
+ std::stack<Instruction *> &ToBeDeleted,
+ OpenMPIRBuilder::InsertPointTy InnerAllocaIP,
+ const Twine &Name = "", bool AsPtr = true) {
+ Builder.restoreIP(OuterAllocaIP);
+ Instruction *FakeVal;
+ AllocaInst *FakeValAddr =
+ Builder.CreateAlloca(Builder.getInt32Ty(), nullptr, Name + ".addr");
+ ToBeDeleted.push(FakeValAddr);
+
+ if (AsPtr)
+ FakeVal = FakeValAddr;
+ else {
+ FakeVal =
+ Builder.CreateLoad(Builder.getInt32Ty(), FakeValAddr, Name + ".val");
+ ToBeDeleted.push(FakeVal);
+ }
+
+ // The fake value and its use exist only for modeling purposes; each is pushed
+ // onto ToBeDeleted so that the caller can erase them later.
+
+ // Fake use of the value at InnerAllocaIP: a load if AsPtr, otherwise an add.
+ Builder.restoreIP(InnerAllocaIP);
+ Instruction *UseFakeVal;
+ if (AsPtr)
+ UseFakeVal =
+ Builder.CreateLoad(Builder.getInt32Ty(), FakeVal, Name + ".use");
+ else
+ UseFakeVal =
+ cast<BinaryOperator>(Builder.CreateAdd(FakeVal, Builder.getInt32(10)));
+ ToBeDeleted.push(UseFakeVal);
+ return FakeVal;
+}
+
//===----------------------------------------------------------------------===//
// OpenMPIRBuilderConfig
//===----------------------------------------------------------------------===//
@@ -1497,13 +1535,6 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc,
InsertPointTy AllocaIP, BodyGenCallbackTy BodyGenCB,
bool Tied, Value *Final, Value *IfCondition,
SmallVector<DependData> Dependencies) {
- // We create a temporary i32 value that will represent the global tid after
- // outlining.
- SmallVector<Instruction *, 4> ToBeDeleted;
- Builder.restoreIP(AllocaIP);
- AllocaInst *TIDAddr = Builder.CreateAlloca(Int32, nullptr, "tid.addr");
- LoadInst *TID = Builder.CreateLoad(Int32, TIDAddr, "tid.addr.use");
- ToBeDeleted.append({TID, TIDAddr});
if (!updateToLocation(Loc))
return InsertPointTy();
@@ -1532,19 +1563,24 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc,
BasicBlock *TaskAllocaBB =
splitBB(Builder, /*CreateBranch=*/true, "task.alloca");
- // Fake use of TID
- Builder.SetInsertPoint(TaskAllocaBB, TaskAllocaBB->begin());
- BinaryOperator *AddInst =
- dyn_cast<BinaryOperator>(Builder.CreateAdd(TID, Builder.getInt32(10)));
- ToBeDeleted.push_back(AddInst);
+ InsertPointTy TaskAllocaIP =
+ InsertPointTy(TaskAllocaBB, TaskAllocaBB->begin());
+ InsertPointTy TaskBodyIP = InsertPointTy(TaskBodyBB, TaskBodyBB->begin());
+ BodyGenCB(TaskAllocaIP, TaskBodyIP);
+ Builder.SetInsertPoint(TaskExitBB, TaskExitBB->begin());
OutlineInfo OI;
OI.EntryBB = TaskAllocaBB;
OI.OuterAllocaBB = AllocaIP.getBlock();
OI.ExitBB = TaskExitBB;
- OI.ExcludeArgsFromAggregate = {TID};
+
+ // Add the thread ID argument.
+ std::stack<Instruction *> ToBeDeleted;
+ OI.ExcludeArgsFromAggregate.push_back(createFakeIntVal(
+ Builder, AllocaIP, ToBeDeleted, TaskAllocaIP, "global.tid", false));
+
OI.PostOutlineCB = [this, Ident, Tied, Final, IfCondition, Dependencies,
- TaskAllocaBB, ToBeDeleted](Function &OutlinedFn) {
+ TaskAllocaBB, ToBeDeleted](Function &OutlinedFn) mutable {
// Replace the Stale CI by appropriate RTL function call.
assert(OutlinedFn.getNumUses() == 1 &&
"there must be a single user for the outlined function");
@@ -1670,7 +1706,7 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc,
// br label %exit
// else:
// call @__kmpc_omp_task_begin_if0(...)
- // call @wrapper_fn(...)
+ // call @outlined_fn(...)
// call @__kmpc_omp_task_complete_if0(...)
// br label %exit
// exit:
@@ -1725,31 +1761,14 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc,
Shareds, [Shareds](Use &U) { return U.getUser() != Shareds; });
}
- // Replace kmpc_global_thread_num() calls with the global thread id
- // argument.
- OutlinedFn.getArg(0)->setName("global.tid");
- FunctionCallee TIDRTLFn =
- getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_global_thread_num);
- for (Instruction &Inst : instructions(OutlinedFn)) {
- CallInst *CI = dyn_cast<CallInst>(&Inst);
- if (!CI)
- continue;
- if (CI->getCalledFunction() == TIDRTLFn.getCallee())
- CI->replaceAllUsesWith(OutlinedFn.getArg(0));
+ while (!ToBeDeleted.empty()) {
+ ToBeDeleted.top()->eraseFromParent();
+ ToBeDeleted.pop();
}
-
- for (Instruction *I : ToBeDeleted)
- I->eraseFromParent();
};
addOutlineInfo(std::move(OI));
- InsertPointTy TaskAllocaIP =
- InsertPointTy(TaskAllocaBB, TaskAllocaBB->begin());
- InsertPointTy TaskBodyIP = InsertPointTy(TaskBodyBB, TaskBodyBB->begin());
- BodyGenCB(TaskAllocaIP, TaskBodyIP);
- Builder.SetInsertPoint(TaskExitBB, TaskExitBB->begin());
-
return Builder.saveIP();
}
@@ -5740,6 +5759,7 @@ OpenMPIRBuilder::createTeams(const LocationDescription &Loc,
BasicBlock *BodyBB = splitBB(Builder, /*CreateBranch=*/true, "teams.entry");
Builder.SetInsertPoint(BodyBB, BodyBB->begin());
}
+ InsertPointTy OuterAllocaIP(&OuterAllocaBB, OuterAllocaBB.begin());
// The current basic block is split into four basic blocks. After outlining,
// they will be mapped as follows:
@@ -5763,84 +5783,62 @@ OpenMPIRBuilder::createTeams(const LocationDescription &Loc,
BasicBlock *AllocaBB =
splitBB(Builder, /*CreateBranch=*/true, "teams.alloca");
+ // Generate the body of teams.
+ InsertPointTy AllocaIP(AllocaBB, AllocaBB->begin());
+ InsertPointTy CodeGenIP(BodyBB, BodyBB->begin());
+ BodyGenCB(AllocaIP, CodeGenIP);
+
OutlineInfo OI;
OI.EntryBB = AllocaBB;
OI.ExitBB = ExitBB;
OI.OuterAllocaBB = &OuterAllocaBB;
- OI.PostOutlineCB = [this, Ident](Function &OutlinedFn) {
- // The input IR here looks like the following-
- // ```
- // func @current_fn() {
- // outlined_fn(%args)
- // }
- // func @outlined_fn(%args) { ... }
- // ```
- //
- // This is changed to the following-
- //
- // ```
- // func @current_fn() {
- // runtime_call(..., wrapper_fn, ...)
- // }
- // func @wrapper_fn(..., %args) {
- // outlined_fn(%args)
- // }
- // func @outlined_fn(%args) { ... }
- // ```
+ // Insert fake values for global tid and bound tid.
+ std::stack<Instruction *> ToBeDeleted;
+ OI.ExcludeArgsFromAggregate.push_back(createFakeIntVal(
+ Builder, OuterAllocaIP, ToBeDeleted, AllocaIP, "gid", true));
+ OI.ExcludeArgsFromAggregate.push_back(createFakeIntVal(
+ Builder, OuterAllocaIP, ToBeDeleted, AllocaIP, "tid", true));
+
+ OI.PostOutlineCB = [this, Ident, ToBeDeleted](Function &OutlinedFn) mutable {
// The stale call instruction will be replaced with a new call instruction
- // for runtime call with a wrapper function.
+ // for runtime call with the outlined function.
assert(OutlinedFn.getNumUses() == 1 &&
"there must be a single user for the outlined function");
CallInst *StaleCI = cast<CallInst>(OutlinedFn.user_back());
+ ToBeDeleted.push(StaleCI);
+
+ assert((OutlinedFn.arg_size() == 2 || OutlinedFn.arg_size() == 3) &&
+ "Outlined function must have two or three arguments only");
- // Create the wrapper function.
- SmallVector<Type *> WrapperArgTys{Builder.getPtrTy(), Builder.getPtrTy()};
- for (auto &Arg : OutlinedFn.args())
- WrapperArgTys.push_back(Arg.getType());
- FunctionCallee WrapperFuncVal = M.getOrInsertFunction(
- (Twine(OutlinedFn.getName()) + ".teams").str(),
- FunctionType::get(Builder.getVoidTy(), WrapperArgTys, false));
- Function *WrapperFunc = dyn_cast<Function>(WrapperFuncVal.getCallee());
- WrapperFunc->getArg(0)->setName("global_tid");
- WrapperFunc->getArg(1)->setName("bound_tid");
- if (WrapperFunc->arg_size() > 2)
- WrapperFunc->getArg(2)->setName("data");
-
- // Emit the body of the wrapper function - just a call to outlined function
- // and return statement.
- BasicBlock *WrapperEntryBB =
- BasicBlock::Create(M.getContext(), "entrybb", WrapperFunc);
- Builder.SetInsertPoint(WrapperEntryBB);
- SmallVector<Value *> Args;
- for (size_t ArgIndex = 2; ArgIndex < WrapperFunc->arg_size(); ArgIndex++)
- Args.push_back(WrapperFunc->getArg(ArgIndex));
- Builder.CreateCall(&OutlinedFn, Args);
- Builder.CreateRetVoid();
-
- OutlinedFn.addFnAttr(Attribute::AttrKind::AlwaysInline);
+ bool HasShared = OutlinedFn.arg_size() == 3;
+
+ OutlinedFn.getArg(0)->setName("global.tid.ptr");
+ OutlinedFn.getArg(1)->setName("bound.tid.ptr");
+ if (HasShared)
+ OutlinedFn.getArg(2)->setName("data");
// Call to the runtime function for teams in the current function.
assert(StaleCI && "Error while outlining - no CallInst user found for the "
"outlined function.");
Builder.SetInsertPoint(StaleCI);
- Args = {Ident, Builder.getInt32(StaleCI->arg_size()), WrapperFunc};
- for (Use &Arg : StaleCI->args())
- Args.push_back(Arg);
+ SmallVector<Value *> Args = {Ident, Builder.getInt32(StaleCI->arg_size()),
+ &OutlinedFn};
+ if (HasShared)
+ Args.push_back(StaleCI->getArgOperand(2));
Builder.CreateCall(getOrCreateRuntimeFunctionPtr(
omp::RuntimeFunction::OMPRTL___kmpc_fork_teams),
Args);
- StaleCI->eraseFromParent();
+
+ while (!ToBeDeleted.empty()) {
+ ToBeDeleted.top()->eraseFromParent();
+ ToBeDeleted.pop();
+ }
};
addOutlineInfo(std::move(OI));
- // Generate the body of teams.
- InsertPointTy AllocaIP(AllocaBB, AllocaBB->begin());
- InsertPointTy CodeGenIP(BodyBB, BodyBB->begin());
- BodyGenCB(AllocaIP, CodeGenIP);
-
Builder.SetInsertPoint(ExitBB, ExitBB->begin());
return Builder.saveIP();
diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
index 643b34270c01693..c4b0389c89c7c60 100644
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -4057,25 +4057,17 @@ TEST_F(OpenMPIRBuilderTest, CreateTeams) {
ASSERT_NE(SrcSrc, nullptr);
// Verify the outlined function signature.
- Function *WrapperFn =
+ Function *OutlinedFn =
dyn_cast<Function>(TeamsForkCall->getArgOperand(2)->stripPointerCasts());
- ASSERT_NE(WrapperFn, nullptr);
- EXPECT_FALSE(WrapperFn->isDeclaration());
- EXPECT_TRUE(WrapperFn->arg_size() >= 3);
- EXPECT_EQ(WrapperFn->getArg(0)->getType(), Builder.getPtrTy()); // global_tid
- EXPECT_EQ(WrapperFn->getArg(1)->getType(), Builder.getPtrTy()); // bound_tid
- EXPECT_EQ(WrapperFn->getArg(2)->getType(),
+ ASSERT_NE(OutlinedFn, nullptr);
+ EXPECT_FALSE(OutlinedFn->isDeclaration());
+ EXPECT_TRUE(OutlinedFn->arg_size() >= 3);
+ EXPECT_EQ(OutlinedFn->getArg(0)->getType(), Builder.getPtrTy()); // global_tid
+ EXPECT_EQ(OutlinedFn->getArg(1)->getType(), Builder.getPtrTy()); // bound_tid
+ EXPECT_EQ(OutlinedFn->getArg(2)->getType(),
Builder.getPtrTy()); // captured args
// Check for TruncInst and ICmpInst in the outlined function.
- inst_range Instructions = instructions(WrapperFn);
- auto OutlinedFnInst = find_if(
- Instructions, [](Instruction &Inst) { return isa<CallInst>(&Inst); });
- ASSERT_NE(OutlinedFnInst, Instructions.end());
- CallInst *OutlinedFnCI = dyn_cast<CallInst>(&*OutlinedFnInst);
- ASSERT_NE(OutlinedFnCI, nullptr);
- Function *OutlinedFn = OutlinedFnCI->getCalledFunction();
-
EXPECT_TRUE(any_of(instructions(OutlinedFn),
[](Instruction &inst) { return isa<TruncInst>(&inst); }));
EXPECT_TRUE(any_of(instructions(OutlinedFn),
More information about the Mlir-commits
mailing list