[llvm] b72f1ec - [openmp][mlir] Lower parallel if to new fork_call_if function.
David Truby via llvm-commits
llvm-commits at lists.llvm.org
Fri Dec 9 06:25:33 PST 2022
Author: David Truby
Date: 2022-12-09T14:23:27Z
New Revision: b72f1ec9fbb14cd7d2f5112d2c52ef5cdd1aa94a
URL: https://github.com/llvm/llvm-project/commit/b72f1ec9fbb14cd7d2f5112d2c52ef5cdd1aa94a
DIFF: https://github.com/llvm/llvm-project/commit/b72f1ec9fbb14cd7d2f5112d2c52ef5cdd1aa94a.diff
LOG: [openmp][mlir] Lower parallel if to new fork_call_if function.
This patch adds a new runtime function `fork_call_if` and uses that
to lower parallel if statements when going through OpenMPIRBuilder.
This fixes an issue where the OpenMPIRBuilder passes all arguments to
fork_call as a struct but this struct is not filled correctly in the
non-if branch by handling the fork inside the runtime.
Differential Revision: https://reviews.llvm.org/D138495
Added:
Modified:
llvm/include/llvm/Frontend/OpenMP/OMPKinds.def
llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
mlir/test/Target/LLVMIR/openmp-llvm.mlir
openmp/runtime/src/kmp.h
openmp/runtime/src/kmp_csupport.cpp
openmp/runtime/test/lit.cfg
openmp/runtime/test/parallel/omp_parallel_if.c
Removed:
################################################################################
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def b/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def
index 87ac4d59d59b9..bef838afb9f8e 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def
@@ -203,6 +203,8 @@ __OMP_RTL(__kmpc_flush, false, Void, IdentPtr)
__OMP_RTL(__kmpc_global_thread_num, false, Int32, IdentPtr)
__OMP_RTL(__kmpc_get_hardware_thread_id_in_block, false, Int32, )
__OMP_RTL(__kmpc_fork_call, true, Void, IdentPtr, Int32, ParallelTaskPtr)
+__OMP_RTL(__kmpc_fork_call_if, false, Void, IdentPtr, Int32, ParallelTaskPtr,
+ Int32, VoidPtr)
__OMP_RTL(__kmpc_omp_taskwait, false, Int32, IdentPtr, Int32)
__OMP_RTL(__kmpc_omp_taskwait_51, false, Int32, IdentPtr, Int32, Int32)
__OMP_RTL(__kmpc_omp_taskyield, false, Int32, IdentPtr, Int32, /* Int */ Int32)
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 5e08c299c96df..f002644739185 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -914,34 +914,21 @@ IRBuilder<>::InsertPoint OpenMPIRBuilder::createParallel(
AllocaInst *TIDAddr = Builder.CreateAlloca(Int32, nullptr, "tid.addr");
AllocaInst *ZeroAddr = Builder.CreateAlloca(Int32, nullptr, "zero.addr");
- // If there is an if condition we actually use the TIDAddr and ZeroAddr in the
- // program, otherwise we only need them for modeling purposes to get the
- // associated arguments in the outlined function. In the former case,
- // initialize the allocas properly, in the latter case, delete them later.
- if (IfCondition) {
- Builder.CreateStore(Constant::getNullValue(Int32), TIDAddr);
- Builder.CreateStore(Constant::getNullValue(Int32), ZeroAddr);
- } else {
- ToBeDeleted.push_back(TIDAddr);
- ToBeDeleted.push_back(ZeroAddr);
- }
+ // We only need TIDAddr and ZeroAddr for modeling purposes to get the
+ // associated arguments in the outlined function, so we delete them later.
+ ToBeDeleted.push_back(TIDAddr);
+ ToBeDeleted.push_back(ZeroAddr);
// Create an artificial insertion point that will also ensure the blocks we
// are about to split are not degenerated.
auto *UI = new UnreachableInst(Builder.getContext(), InsertBB);
- Instruction *ThenTI = UI, *ElseTI = nullptr;
- if (IfCondition)
- SplitBlockAndInsertIfThenElse(IfCondition, UI, &ThenTI, &ElseTI);
-
- BasicBlock *ThenBB = ThenTI->getParent();
- BasicBlock *PRegEntryBB = ThenBB->splitBasicBlock(ThenTI, "omp.par.entry");
- BasicBlock *PRegBodyBB =
- PRegEntryBB->splitBasicBlock(ThenTI, "omp.par.region");
+ BasicBlock *EntryBB = UI->getParent();
+ BasicBlock *PRegEntryBB = EntryBB->splitBasicBlock(UI, "omp.par.entry");
+ BasicBlock *PRegBodyBB = PRegEntryBB->splitBasicBlock(UI, "omp.par.region");
BasicBlock *PRegPreFiniBB =
- PRegBodyBB->splitBasicBlock(ThenTI, "omp.par.pre_finalize");
- BasicBlock *PRegExitBB =
- PRegPreFiniBB->splitBasicBlock(ThenTI, "omp.par.exit");
+ PRegBodyBB->splitBasicBlock(UI, "omp.par.pre_finalize");
+ BasicBlock *PRegExitBB = PRegPreFiniBB->splitBasicBlock(UI, "omp.par.exit");
auto FiniCBWrapper = [&](InsertPointTy IP) {
// Hide "open-ended" blocks from the given FiniCB by setting the right jump
@@ -975,7 +962,7 @@ IRBuilder<>::InsertPoint OpenMPIRBuilder::createParallel(
Builder.CreateLoad(Int32, ZeroAddr, "zero.addr.use");
ToBeDeleted.push_back(ZeroAddrUse);
- // ThenBB
+ // EntryBB
// |
// V
// PRegionEntryBB <- Privatization allocas are placed here.
@@ -998,8 +985,12 @@ IRBuilder<>::InsertPoint OpenMPIRBuilder::createParallel(
BodyGenCB(InnerAllocaIP, CodeGenIP);
LLVM_DEBUG(dbgs() << "After body codegen: " << *OuterFn << "\n");
+ FunctionCallee RTLFn;
+ if (IfCondition)
+ RTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_fork_call_if);
+ else
+ RTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_fork_call);
- FunctionCallee RTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_fork_call);
if (auto *F = dyn_cast<llvm::Function>(RTLFn.getCallee())) {
if (!F->hasMetadata(llvm::LLVMContext::MD_callback)) {
llvm::LLVMContext &Ctx = F->getContext();
@@ -1034,15 +1025,30 @@ IRBuilder<>::InsertPoint OpenMPIRBuilder::createParallel(
CI->getParent()->setName("omp_parallel");
Builder.SetInsertPoint(CI);
- // Build call __kmpc_fork_call(Ident, n, microtask, var1, .., varn);
+ // Build call __kmpc_fork_call[_if](Ident, n, microtask, var1, .., varn);
Value *ForkCallArgs[] = {
Ident, Builder.getInt32(NumCapturedVars),
Builder.CreateBitCast(&OutlinedFn, ParallelTaskPtr)};
SmallVector<Value *, 16> RealArgs;
RealArgs.append(std::begin(ForkCallArgs), std::end(ForkCallArgs));
+ if (IfCondition) {
+ Value *Cond = Builder.CreateSExtOrTrunc(IfCondition,
+ Type::getInt32Ty(M.getContext()));
+ RealArgs.push_back(Cond);
+ }
RealArgs.append(CI->arg_begin() + /* tid & bound tid */ 2, CI->arg_end());
+ // __kmpc_fork_call_if always expects a void ptr as the last argument
+ // If there are no arguments, pass a null pointer.
+ auto PtrTy = Type::getInt8PtrTy(M.getContext());
+ if (IfCondition && NumCapturedVars == 0) {
+ llvm::Value *Void = ConstantPointerNull::get(PtrTy);
+ RealArgs.push_back(Void);
+ }
+ if (IfCondition && RealArgs.back()->getType() != PtrTy)
+ RealArgs.back() = Builder.CreateBitCast(RealArgs.back(), PtrTy);
+
Builder.CreateCall(RTLFn, RealArgs);
LLVM_DEBUG(dbgs() << "With fork_call placed: "
@@ -1055,35 +1061,7 @@ IRBuilder<>::InsertPoint OpenMPIRBuilder::createParallel(
Function::arg_iterator OutlinedAI = OutlinedFn.arg_begin();
Builder.CreateStore(Builder.CreateLoad(Int32, OutlinedAI), PrivTIDAddr);
- // If no "if" clause was present we do not need the call created during
- // outlining, otherwise we reuse it in the serialized parallel region.
- if (!ElseTI) {
- CI->eraseFromParent();
- } else {
-
- // If an "if" clause was present we are now generating the serialized
- // version into the "else" branch.
- Builder.SetInsertPoint(ElseTI);
-
- // Build calls __kmpc_serialized_parallel(&Ident, GTid);
- Value *SerializedParallelCallArgs[] = {Ident, ThreadID};
- Builder.CreateCall(
- getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_serialized_parallel),
- SerializedParallelCallArgs);
-
- // OutlinedFn(&gtid, &zero, CapturedStruct);
- CI->removeFromParent();
- Builder.Insert(CI);
-
- // __kmpc_end_serialized_parallel(&Ident, GTid);
- Value *EndArgs[] = {Ident, ThreadID};
- Builder.CreateCall(
- getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_end_serialized_parallel),
- EndArgs);
-
- LLVM_DEBUG(dbgs() << "With serialized parallel region: "
- << *Builder.GetInsertBlock()->getParent() << "\n");
- }
+ CI->eraseFromParent();
for (Instruction *I : ToBeDeleted)
I->eraseFromParent();
diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
index c5a9e6e29f3e6..30d8aee0b4e87 100644
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -986,38 +986,22 @@ TEST_F(OpenMPIRBuilderTest, ParallelIfCond) {
EXPECT_EQ(OutlinedFn->arg_size(), 3U);
EXPECT_EQ(&OutlinedFn->getEntryBlock(), PrivAI->getParent());
- ASSERT_EQ(OutlinedFn->getNumUses(), 2U);
+ ASSERT_EQ(OutlinedFn->getNumUses(), 1U);
- CallInst *DirectCI = nullptr;
CallInst *ForkCI = nullptr;
for (User *Usr : OutlinedFn->users()) {
- if (isa<CallInst>(Usr)) {
- ASSERT_EQ(DirectCI, nullptr);
- DirectCI = cast<CallInst>(Usr);
- } else {
- ASSERT_TRUE(isa<ConstantExpr>(Usr));
- ASSERT_EQ(Usr->getNumUses(), 1U);
- ASSERT_TRUE(isa<CallInst>(Usr->user_back()));
- ForkCI = cast<CallInst>(Usr->user_back());
- }
+ ASSERT_TRUE(isa<ConstantExpr>(Usr));
+ ASSERT_EQ(Usr->getNumUses(), 1U);
+ ASSERT_TRUE(isa<CallInst>(Usr->user_back()));
+ ForkCI = cast<CallInst>(Usr->user_back());
}
- EXPECT_EQ(ForkCI->getCalledFunction()->getName(), "__kmpc_fork_call");
- EXPECT_EQ(ForkCI->arg_size(), 4U);
+ EXPECT_EQ(ForkCI->getCalledFunction()->getName(), "__kmpc_fork_call_if");
+ EXPECT_EQ(ForkCI->arg_size(), 5U);
EXPECT_TRUE(isa<GlobalVariable>(ForkCI->getArgOperand(0)));
EXPECT_EQ(ForkCI->getArgOperand(1),
ConstantInt::get(Type::getInt32Ty(Ctx), 1));
- Value *StoredForkArg =
- findStoredValueInAggregateAt(Ctx, ForkCI->getArgOperand(3), 0);
- EXPECT_EQ(StoredForkArg, F->arg_begin());
-
- EXPECT_EQ(DirectCI->getCalledFunction(), OutlinedFn);
- EXPECT_EQ(DirectCI->arg_size(), 3U);
- EXPECT_TRUE(isa<AllocaInst>(DirectCI->getArgOperand(0)));
- EXPECT_TRUE(isa<AllocaInst>(DirectCI->getArgOperand(1)));
- Value *StoredDirectArg =
- findStoredValueInAggregateAt(Ctx, DirectCI->getArgOperand(2), 0);
- EXPECT_EQ(StoredDirectArg, F->arg_begin());
+ EXPECT_EQ(ForkCI->getArgOperand(3)->getType(), Type::getInt32Ty(Ctx));
}
TEST_F(OpenMPIRBuilderTest, ParallelCancelBarrier) {
diff --git a/mlir/test/Target/LLVMIR/openmp-llvm.mlir b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
index 702b8f309c619..94e37c1e035ff 100644
--- a/mlir/test/Target/LLVMIR/openmp-llvm.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
@@ -151,33 +151,19 @@ llvm.func @test_omp_parallel_num_threads_3() -> () {
// CHECK: define void @test_omp_parallel_if_1(i32 %[[IF_VAR_1:.*]])
llvm.func @test_omp_parallel_if_1(%arg0: i32) -> () {
-// Check that the allocas are emitted by the OpenMPIRBuilder at the top of the
-// function, before the condition. Allocas are only emitted by the builder when
-// the `if` clause is present. We match specific SSA value names since LLVM
-// actually produces those names.
-// CHECK: %tid.addr{{.*}} = alloca i32
-// CHECK: %zero.addr{{.*}} = alloca i32
-
-// CHECK: %[[IF_COND_VAR_1:.*]] = icmp slt i32 %[[IF_VAR_1]], 0
%0 = llvm.mlir.constant(0 : index) : i32
%1 = llvm.icmp "slt" %arg0, %0 : i32
+// CHECK: %[[IF_COND_VAR_1:.*]] = icmp slt i32 %[[IF_VAR_1]], 0
+
// CHECK: %[[GTN_IF_1:.*]] = call i32 @__kmpc_global_thread_num(ptr @[[SI_VAR_IF_1:.*]])
-// CHECK: br i1 %[[IF_COND_VAR_1]], label %[[IF_COND_TRUE_BLOCK_1:.*]], label %[[IF_COND_FALSE_BLOCK_1:.*]]
-// CHECK: [[IF_COND_TRUE_BLOCK_1]]:
// CHECK: br label %[[OUTLINED_CALL_IF_BLOCK_1:.*]]
// CHECK: [[OUTLINED_CALL_IF_BLOCK_1]]:
-// CHECK: call void {{.*}} @__kmpc_fork_call(ptr @[[SI_VAR_IF_1]], {{.*}} @[[OMP_OUTLINED_FN_IF_1:.*]])
+// CHECK: %[[I32_IF_COND_VAR_1:.*]] = sext i1 %[[IF_COND_VAR_1]] to i32
+// CHECK: call void @__kmpc_fork_call_if(ptr @[[SI_VAR_IF_1]], i32 0, ptr @[[OMP_OUTLINED_FN_IF_1:.*]], i32 %[[I32_IF_COND_VAR_1]], ptr null)
// CHECK: br label %[[OUTLINED_EXIT_IF_1:.*]]
// CHECK: [[OUTLINED_EXIT_IF_1]]:
-// CHECK: br label %[[OUTLINED_EXIT_IF_2:.*]]
-// CHECK: [[OUTLINED_EXIT_IF_2]]:
// CHECK: br label %[[RETURN_BLOCK_IF_1:.*]]
-// CHECK: [[IF_COND_FALSE_BLOCK_1]]:
-// CHECK: call void @__kmpc_serialized_parallel(ptr @[[SI_VAR_IF_1]], i32 %[[GTN_IF_1]])
-// CHECK: call void @[[OMP_OUTLINED_FN_IF_1]]
-// CHECK: call void @__kmpc_end_serialized_parallel(ptr @[[SI_VAR_IF_1]], i32 %[[GTN_IF_1]])
-// CHECK: br label %[[RETURN_BLOCK_IF_1]]
omp.parallel if(%1 : i1) {
omp.barrier
omp.terminator
@@ -193,58 +179,6 @@ llvm.func @test_omp_parallel_if_1(%arg0: i32) -> () {
// -----
-// CHECK-LABEL: @test_nested_alloca_ip
-llvm.func @test_nested_alloca_ip(%arg0: i32) -> () {
-
- // Check that the allocas are emitted by the OpenMPIRBuilder at the top of
- // the function, before the condition. Allocas are only emitted by the
- // builder when the `if` clause is present. We match specific SSA value names
- // since LLVM actually produces those names and ensure they come before the
- // "icmp" that is the first operation we emit.
- // CHECK: %tid.addr{{.*}} = alloca i32
- // CHECK: %zero.addr{{.*}} = alloca i32
- // CHECK: icmp slt i32 %{{.*}}, 0
- %0 = llvm.mlir.constant(0 : index) : i32
- %1 = llvm.icmp "slt" %arg0, %0 : i32
-
- omp.parallel if(%1 : i1) {
- // The "parallel" operation will be outlined, check the the function is
- // produced. Inside that function, further allocas should be placed before
- // another "icmp".
- // CHECK: define
- // CHECK: %tid.addr{{.*}} = alloca i32
- // CHECK: %zero.addr{{.*}} = alloca i32
- // CHECK: icmp slt i32 %{{.*}}, 1
- %2 = llvm.mlir.constant(1 : index) : i32
- %3 = llvm.icmp "slt" %arg0, %2 : i32
-
- omp.parallel if(%3 : i1) {
- // One more nesting level.
- // CHECK: define
- // CHECK: %tid.addr{{.*}} = alloca i32
- // CHECK: %zero.addr{{.*}} = alloca i32
- // CHECK: icmp slt i32 %{{.*}}, 2
-
- %4 = llvm.mlir.constant(2 : index) : i32
- %5 = llvm.icmp "slt" %arg0, %4 : i32
-
- omp.parallel if(%5 : i1) {
- omp.barrier
- omp.terminator
- }
-
- omp.barrier
- omp.terminator
- }
- omp.barrier
- omp.terminator
- }
-
- llvm.return
-}
-
-// -----
-
// CHECK-LABEL: define void @test_omp_parallel_3()
llvm.func @test_omp_parallel_3() -> () {
// CHECK: [[OMP_THREAD_3_1:%.*]] = call i32 @__kmpc_global_thread_num(ptr @{{[0-9]+}})
diff --git a/openmp/runtime/src/kmp.h b/openmp/runtime/src/kmp.h
index f2d030f6b2dc2..99bd39b916142 100644
--- a/openmp/runtime/src/kmp.h
+++ b/openmp/runtime/src/kmp.h
@@ -3901,6 +3901,9 @@ KMP_EXPORT kmp_int32 __kmpc_bound_num_threads(ident_t *);
KMP_EXPORT kmp_int32 __kmpc_ok_to_fork(ident_t *);
KMP_EXPORT void __kmpc_fork_call(ident_t *, kmp_int32 nargs,
kmpc_micro microtask, ...);
+KMP_EXPORT void __kmpc_fork_call_if(ident_t *loc, kmp_int32 nargs,
+ kmpc_micro microtask, kmp_int32 cond,
+ void *args);
KMP_EXPORT void __kmpc_serialized_parallel(ident_t *, kmp_int32 global_tid);
KMP_EXPORT void __kmpc_end_serialized_parallel(ident_t *, kmp_int32 global_tid);
diff --git a/openmp/runtime/src/kmp_csupport.cpp b/openmp/runtime/src/kmp_csupport.cpp
index 97b15be967f0c..64b9d162c8921 100644
--- a/openmp/runtime/src/kmp_csupport.cpp
+++ b/openmp/runtime/src/kmp_csupport.cpp
@@ -330,6 +330,37 @@ void __kmpc_fork_call(ident_t *loc, kmp_int32 argc, kmpc_micro microtask, ...) {
#endif // KMP_STATS_ENABLED
}
+/*!
+@ingroup PARALLEL
+@param loc source location information
+@param argc number of arguments forwarded to __kmpc_fork_call when forking
+@param microtask pointer to callback routine consisting of outlined parallel
+construct
+@param cond condition for running in parallel
+@param args struct of pointers to shared variables that aren't global
+
+Perform a fork only if the condition is true. Otherwise the microtask is
+invoked directly by the caller between __kmpc_serialized_parallel and
+__kmpc_end_serialized_parallel.
+*/
+void __kmpc_fork_call_if(ident_t *loc, kmp_int32 argc, kmpc_micro microtask,
+                         kmp_int32 cond, void *args) {
+  int gtid = __kmp_entry_gtid();
+  // Stand-in for the bound thread id handed to the microtask when it is run
+  // without forking.
+  int zero = 0;
+  if (cond) {
+    // A null args pointer means there is nothing to forward, so no trailing
+    // variadic argument is passed.
+    if (args)
+      __kmpc_fork_call(loc, argc, microtask, args);
+    else
+      __kmpc_fork_call(loc, argc, microtask);
+  } else {
+    __kmpc_serialized_parallel(loc, gtid);
+
+    // Run the outlined region directly instead of forking.
+    if (args)
+      microtask(&gtid, &zero, args);
+    else
+      microtask(&gtid, &zero);
+
+    __kmpc_end_serialized_parallel(loc, gtid);
+  }
+}
+
/*!
@ingroup PARALLEL
@param loc source location information
diff --git a/openmp/runtime/test/lit.cfg b/openmp/runtime/test/lit.cfg
index c1cf24a42283e..f49f39adee545 100644
--- a/openmp/runtime/test/lit.cfg
+++ b/openmp/runtime/test/lit.cfg
@@ -133,6 +133,8 @@ if 'INTEL_LICENSE_FILE' in os.environ:
# substitutions
config.substitutions.append(("%libomp-compile-and-run", \
"%libomp-compile && %libomp-run"))
+config.substitutions.append(("%libomp-irbuilder-compile-and-run", \
+ "%libomp-irbuilder-compile && %libomp-run"))
config.substitutions.append(("%libomp-c99-compile-and-run", \
"%libomp-c99-compile && %libomp-run"))
config.substitutions.append(("%libomp-cxx-compile-and-run", \
@@ -143,6 +145,8 @@ config.substitutions.append(("%libomp-cxx-compile", \
"%clangXX %openmp_flags %flags -std=c++17 %s -o %t" + libs))
config.substitutions.append(("%libomp-compile", \
"%clang %openmp_flags %flags %s -o %t" + libs))
+config.substitutions.append(("%libomp-irbuilder-compile", \
+ "%clang %openmp_flags %flags -fopenmp-enable-irbuilder %s -o %t" + libs))
config.substitutions.append(("%libomp-c99-compile", \
"%clang %openmp_flags %flags -std=c99 %s -o %t" + libs))
config.substitutions.append(("%libomp-run", "%t"))
diff --git a/openmp/runtime/test/parallel/omp_parallel_if.c b/openmp/runtime/test/parallel/omp_parallel_if.c
index abbf3cd48995e..7a924020f3349 100644
--- a/openmp/runtime/test/parallel/omp_parallel_if.c
+++ b/openmp/runtime/test/parallel/omp_parallel_if.c
@@ -1,4 +1,5 @@
// RUN: %libomp-compile-and-run
+// RUN: %libomp-irbuilder-compile-and-run
#include <stdio.h>
#include "omp_testsuite.h"
More information about the llvm-commits
mailing list