[Openmp-commits] [openmp] e97e0a4 - [AbstractAttributor] Fold __kmpc_parallel_level if possible
Shilei Tian via Openmp-commits
openmp-commits at lists.llvm.org
Mon Jul 26 19:46:26 PDT 2021
Author: Shilei Tian
Date: 2021-07-26T22:46:19-04:00
New Revision: e97e0a4fad091474131ad84d6f46009bf84c5b60
URL: https://github.com/llvm/llvm-project/commit/e97e0a4fad091474131ad84d6f46009bf84c5b60
DIFF: https://github.com/llvm/llvm-project/commit/e97e0a4fad091474131ad84d6f46009bf84c5b60.diff
LOG: [AbstractAttributor] Fold __kmpc_parallel_level if possible
Similar to D105787, this patch tries to fold `__kmpc_parallel_level` if possible.
Note that `__kmpc_parallel_level` does not take activeness into account: with
the current `deviceRTLs`, it returns values such as 0, 1, 2, rather than 0, 129,
130, etc., which would additionally encode whether the level is active.
Reviewed By: jdoerfert
Differential Revision: https://reviews.llvm.org/D106154
Added:
llvm/test/Transforms/OpenMP/parallel_level_fold.ll
Modified:
llvm/lib/Transforms/IPO/OpenMPOpt.cpp
openmp/libomptarget/deviceRTLs/common/src/parallel.cu
openmp/libomptarget/deviceRTLs/interface.h
Removed:
################################################################################
diff --git a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
index 5d4f4f47fb9a5..9150d951daa6b 100644
--- a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
+++ b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
@@ -519,6 +519,11 @@ struct KernelInfoState : AbstractState {
/// State to track what kernel entries can reach the associated function.
BooleanStateWithPtrSetVector<Function, false> ReachingKernelEntries;
+ /// State to indicate if we can track the parallel level of the associated
+ /// function. We give up tracking if we encounter an unknown caller or the
+ /// caller is __kmpc_parallel_51.
+ BooleanStateWithSetVector<uint8_t> ParallelLevels;
+
/// Abstract State interface
///{
@@ -3329,8 +3334,10 @@ struct AAKernelInfoFunction : AAKernelInfo {
CheckRWInst, *this, UsedAssumedInformationInCheckRWInst))
SPMDCompatibilityTracker.indicatePessimisticFixpoint();
- if (!IsKernelEntry)
+ if (!IsKernelEntry) {
updateReachingKernelEntries(A);
+ updateParallelLevels(A);
+ }
// Callback to check a call instruction.
bool AllSPMDStatesWereFixed = true;
@@ -3386,6 +3393,49 @@ struct AAKernelInfoFunction : AAKernelInfo {
AllCallSitesKnown))
ReachingKernelEntries.indicatePessimisticFixpoint();
}
+
+ /// Update parallel-level info by combining the states of all callers.
+ void updateParallelLevels(Attributor &A) {
+ auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
+ OMPInformationCache::RuntimeFunctionInfo &Parallel51RFI =
+ OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
+
+ auto PredCallSite = [&](AbstractCallSite ACS) {
+ Function *Caller = ACS.getInstruction()->getFunction();
+
+ assert(Caller && "Caller is nullptr");
+
+ auto &CAA =
+ A.getOrCreateAAFor<AAKernelInfo>(IRPosition::function(*Caller));
+ if (CAA.ParallelLevels.isValidState()) {
+ // Any function that is called by `__kmpc_parallel_51` will not be
+ // folded as the parallel level in the function is updated. Getting it
+ // right would make this analysis depend on the runtime implementation,
+ // and it could silently become wrong if that implementation changes.
+ // As a consequence, we are just conservative here.
+ if (Caller == Parallel51RFI.Declaration) {
+ ParallelLevels.indicatePessimisticFixpoint();
+ return true;
+ }
+
+ ParallelLevels ^= CAA.ParallelLevels;
+
+ return true;
+ }
+
+ // We lost track of the caller's parallel level, so the parallel level of
+ // the associated function cannot be determined either.
+ ParallelLevels.indicatePessimisticFixpoint();
+
+ return true;
+ };
+
+ bool AllCallSitesKnown = true;
+ if (!A.checkForAllCallSites(PredCallSite, *this,
+ true /* RequireAllCallSites */,
+ AllCallSitesKnown))
+ ParallelLevels.indicatePessimisticFixpoint();
+ }
};
/// The call site kernel info abstract attribute, basically, what can we say
@@ -3668,6 +3718,9 @@ struct AAFoldRuntimeCallCallSiteReturned : AAFoldRuntimeCall {
case OMPRTL___kmpc_is_generic_main_thread_id:
Changed |= foldIsGenericMainThread(A);
break;
+ case OMPRTL___kmpc_parallel_level:
+ Changed |= foldParallelLevel(A);
+ break;
default:
llvm_unreachable("Unhandled OpenMP runtime function!");
}
@@ -3782,6 +3835,68 @@ struct AAFoldRuntimeCallCallSiteReturned : AAFoldRuntimeCall {
: ChangeStatus::CHANGED;
}
+ /// Fold __kmpc_parallel_level into a constant (1 or 0) if possible.
+ ChangeStatus foldParallelLevel(Attributor &A) {
+ Optional<Value *> SimplifiedValueBefore = SimplifiedValue;
+
+ auto &CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
+ *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
+
+ if (!CallerKernelInfoAA.ParallelLevels.isValidState())
+ return indicatePessimisticFixpoint();
+
+ if (!CallerKernelInfoAA.ReachingKernelEntries.isValidState())
+ return indicatePessimisticFixpoint();
+
+ if (CallerKernelInfoAA.ReachingKernelEntries.empty()) {
+ assert(!SimplifiedValue.hasValue() &&
+ "SimplifiedValue should keep none at this point");
+ return ChangeStatus::UNCHANGED;
+ }
+
+ unsigned AssumedSPMDCount = 0, KnownSPMDCount = 0;
+ unsigned AssumedNonSPMDCount = 0, KnownNonSPMDCount = 0;
+ for (Kernel K : CallerKernelInfoAA.ReachingKernelEntries) {
+ auto &AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K),
+ DepClassTy::REQUIRED);
+ if (!AA.SPMDCompatibilityTracker.isValidState())
+ return indicatePessimisticFixpoint();
+
+ if (AA.SPMDCompatibilityTracker.isAssumed()) {
+ if (AA.SPMDCompatibilityTracker.isAtFixpoint())
+ ++KnownSPMDCount;
+ else
+ ++AssumedSPMDCount;
+ } else {
+ if (AA.SPMDCompatibilityTracker.isAtFixpoint())
+ ++KnownNonSPMDCount;
+ else
+ ++AssumedNonSPMDCount;
+ }
+ }
+
+ if ((AssumedSPMDCount + KnownSPMDCount) &&
+ (AssumedNonSPMDCount + KnownNonSPMDCount))
+ return indicatePessimisticFixpoint();
+
+ auto &Ctx = getAnchorValue().getContext();
+ // If the caller can only be reached by SPMD kernel entries, the parallel
+ // level is 1, since SPMD mode already counts as one parallel level. If it
+ // can only be reached by non-SPMD kernel entries, the level is 0.
+ if (AssumedSPMDCount || KnownSPMDCount) {
+ assert(KnownNonSPMDCount == 0 && AssumedNonSPMDCount == 0 &&
+ "Expected only SPMD kernels!");
+ SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), 1);
+ } else {
+ assert(KnownSPMDCount == 0 && AssumedSPMDCount == 0 &&
+ "Expected only non-SPMD kernels!");
+ SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), 0);
+ }
+
+ return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
+ : ChangeStatus::CHANGED;
+ }
+
/// An optional value the associated value is assumed to fold to. That is, we
/// assume the associated value (which is a call) can be replaced by this
/// simplified value.
@@ -3832,6 +3947,19 @@ void OpenMPOpt::registerAAs(bool IsModulePass) {
/* UpdateAfterInit */ false);
return false;
});
+
+ auto &ParallelLevelRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_level];
+ ParallelLevelRFI.foreachUse(SCC, [&](Use &U, Function &) {
+ CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &ParallelLevelRFI);
+ if (!CI)
+ return false;
+ A.getOrCreateAAFor<AAFoldRuntimeCall>(
+ IRPosition::callsite_returned(*CI), /* QueryingAA */ nullptr,
+ DepClassTy::NONE, /* ForceUpdate */ false,
+ /* UpdateAfterInit */ false);
+
+ return false;
+ });
}
// Create CallSite AA for all Getters.
diff --git a/llvm/test/Transforms/OpenMP/parallel_level_fold.ll b/llvm/test/Transforms/OpenMP/parallel_level_fold.ll
new file mode 100644
index 0000000000000..29afba97ab56e
--- /dev/null
+++ b/llvm/test/Transforms/OpenMP/parallel_level_fold.ll
@@ -0,0 +1,150 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --function-signature --check-globals
+; RUN: opt -S -passes=openmp-opt < %s | FileCheck %s
+target triple = "nvptx64"
+
+%struct.ident_t = type { i32, i32, i32, i32, i8* }
+
+ at no_spmd_exec_mode = weak constant i8 1
+ at spmd_exec_mode = weak constant i8 0
+ at parallel_exec_mode = weak constant i8 0
+ at G = external global i8
+ at llvm.compiler.used = appending global [3 x i8*] [i8* @no_spmd_exec_mode, i8* @spmd_exec_mode, i8* @parallel_exec_mode], section "llvm.metadata"
+
+;.
+; CHECK: @[[NO_SPMD_EXEC_MODE:[a-zA-Z0-9_$"\\.-]+]] = weak constant i8 1
+; CHECK: @[[SPMD_EXEC_MODE:[a-zA-Z0-9_$"\\.-]+]] = weak constant i8 0
+; CHECK: @[[PARALLEL_EXEC_MODE:[a-zA-Z0-9_$"\\.-]+]] = weak constant i8 0
+; CHECK: @[[G:[a-zA-Z0-9_$"\\.-]+]] = external global i8
+; CHECK: @[[LLVM_COMPILER_USED:[a-zA-Z0-9_$"\\.-]+]] = appending global [3 x i8*] [i8* @no_spmd_exec_mode, i8* @spmd_exec_mode, i8* @parallel_exec_mode], section "llvm.metadata"
+;.
+define weak void @none_spmd() {
+; CHECK-LABEL: define {{[^@]+}}@none_spmd() {
+; CHECK-NEXT: [[I:%.*]] = call i32 @__kmpc_target_init(%struct.ident_t* null, i1 false, i1 false, i1 false)
+; CHECK-NEXT: call void @none_spmd_helper()
+; CHECK-NEXT: call void @mixed_helper()
+; CHECK-NEXT: call void @__kmpc_target_deinit(%struct.ident_t* null, i1 false, i1 false)
+; CHECK-NEXT: ret void
+;
+ %i = call i32 @__kmpc_target_init(%struct.ident_t* null, i1 false, i1 false, i1 false)
+ call void @none_spmd_helper()
+ call void @mixed_helper()
+ call void @__kmpc_target_deinit(%struct.ident_t* null, i1 false, i1 false)
+ ret void
+}
+
+define weak void @spmd() {
+; CHECK-LABEL: define {{[^@]+}}@spmd() {
+; CHECK-NEXT: [[I:%.*]] = call i32 @__kmpc_target_init(%struct.ident_t* null, i1 true, i1 false, i1 false)
+; CHECK-NEXT: call void @spmd_helper()
+; CHECK-NEXT: call void @mixed_helper()
+; CHECK-NEXT: call void @__kmpc_target_deinit(%struct.ident_t* null, i1 true, i1 false)
+; CHECK-NEXT: ret void
+;
+ %i = call i32 @__kmpc_target_init(%struct.ident_t* null, i1 true, i1 false, i1 false)
+ call void @spmd_helper()
+ call void @mixed_helper()
+ call void @__kmpc_target_deinit(%struct.ident_t* null, i1 true, i1 false)
+ ret void
+}
+
+define weak void @parallel() {
+; CHECK-LABEL: define {{[^@]+}}@parallel() {
+; CHECK-NEXT: [[I:%.*]] = call i32 @__kmpc_target_init(%struct.ident_t* align 536870912 null, i1 true, i1 false, i1 false)
+; CHECK-NEXT: call void @spmd_helper()
+; CHECK-NEXT: call void @__kmpc_parallel_51(%struct.ident_t* noalias noundef align 536870912 null, i32 noundef 0, i32 noundef 0, i32 noundef 0, i32 noundef 0, i8* noalias noundef align 536870912 null, i8* noalias noundef align 536870912 null, i8** noalias noundef align 536870912 null, i64 noundef 0)
+; CHECK-NEXT: call void @__kmpc_target_deinit(%struct.ident_t* null, i1 true, i1 false)
+; CHECK-NEXT: ret void
+;
+ %i = call i32 @__kmpc_target_init(%struct.ident_t* null, i1 true, i1 false, i1 false)
+ call void @spmd_helper()
+ call void @__kmpc_parallel_51(%struct.ident_t* null, i32 0, i32 0, i32 0, i32 0, i8* null, i8* null, i8** null, i64 0)
+ call void @__kmpc_target_deinit(%struct.ident_t* null, i1 true, i1 false)
+ ret void
+}
+
+define internal void @mixed_helper() {
+; CHECK-LABEL: define {{[^@]+}}@mixed_helper() {
+; CHECK-NEXT: [[LEVEL:%.*]] = call i8 @__kmpc_parallel_level()
+; CHECK-NEXT: store i8 [[LEVEL]], i8* @G, align 1
+; CHECK-NEXT: ret void
+;
+ %level = call i8 @__kmpc_parallel_level()
+ store i8 %level, i8* @G
+ ret void
+}
+
+define internal void @none_spmd_helper() {
+; CHECK-LABEL: define {{[^@]+}}@none_spmd_helper() {
+; CHECK-NEXT: [[LEVEL12:%.*]] = call i8 @__kmpc_parallel_level()
+; CHECK-NEXT: [[C:%.*]] = icmp eq i8 [[LEVEL12]], 0
+; CHECK-NEXT: br i1 [[C]], label [[T:%.*]], label [[F:%.*]]
+; CHECK: t:
+; CHECK-NEXT: call void @foo()
+; CHECK-NEXT: ret void
+; CHECK: f:
+; CHECK-NEXT: call void @bar()
+; CHECK-NEXT: ret void
+;
+ %level12 = call i8 @__kmpc_parallel_level()
+ %c = icmp eq i8 %level12, 0
+ br i1 %c, label %t, label %f
+t:
+ call void @foo()
+ ret void
+f:
+ call void @bar()
+ ret void
+}
+
+define internal void @spmd_helper() {
+; CHECK-LABEL: define {{[^@]+}}@spmd_helper() {
+; CHECK-NEXT: store i8 1, i8* @G, align 1
+; CHECK-NEXT: ret void
+;
+ %level = call i8 @__kmpc_parallel_level()
+ store i8 %level, i8* @G
+ ret void
+}
+
+define internal void @__kmpc_parallel_51(%struct.ident_t*, i32, i32, i32, i32, i8*, i8*, i8**, i64) {
+; CHECK-LABEL: define {{[^@]+}}@__kmpc_parallel_51
+; CHECK-SAME: (%struct.ident_t* noalias nocapture nofree readnone align 536870912 [[TMP0:%.*]], i32 [[TMP1:%.*]], i32 [[TMP2:%.*]], i32 [[TMP3:%.*]], i32 [[TMP4:%.*]], i8* noalias nocapture nofree readnone align 536870912 [[TMP5:%.*]], i8* noalias nocapture nofree readnone align 536870912 [[TMP6:%.*]], i8** noalias nocapture nofree readnone align 536870912 [[TMP7:%.*]], i64 [[TMP8:%.*]]) {
+; CHECK-NEXT: call void @parallel_helper()
+; CHECK-NEXT: ret void
+;
+ call void @parallel_helper()
+ ret void
+}
+
+define internal void @parallel_helper() {
+; CHECK-LABEL: define {{[^@]+}}@parallel_helper() {
+; CHECK-NEXT: [[LEVEL:%.*]] = call i8 @__kmpc_parallel_level()
+; CHECK-NEXT: store i8 [[LEVEL]], i8* @G, align 1
+; CHECK-NEXT: ret void
+;
+ %level = call i8 @__kmpc_parallel_level()
+ store i8 %level, i8* @G
+ ret void
+}
+
+declare void @foo()
+declare void @bar()
+declare i8 @__kmpc_parallel_level()
+declare i32 @__kmpc_target_init(%struct.ident_t*, i1 zeroext, i1 zeroext, i1 zeroext) #1
+declare void @__kmpc_target_deinit(%struct.ident_t* nocapture readnone, i1 zeroext, i1 zeroext) #1
+
+!llvm.module.flags = !{!0, !1}
+!nvvm.annotations = !{!2, !3, !4}
+
+!0 = !{i32 7, !"openmp", i32 50}
+!1 = !{i32 7, !"openmp-device", i32 50}
+!2 = !{void ()* @none_spmd, !"kernel", i32 1}
+!3 = !{void ()* @spmd, !"kernel", i32 1}
+!4 = !{void ()* @parallel, !"kernel", i32 1}
+;.
+; CHECK: [[META0:![0-9]+]] = !{i32 7, !"openmp", i32 50}
+; CHECK: [[META1:![0-9]+]] = !{i32 7, !"openmp-device", i32 50}
+; CHECK: [[META2:![0-9]+]] = !{void ()* @none_spmd, !"kernel", i32 1}
+; CHECK: [[META3:![0-9]+]] = !{void ()* @spmd, !"kernel", i32 1}
+; CHECK: [[META4:![0-9]+]] = !{void ()* @parallel, !"kernel", i32 1}
+;.
diff --git a/openmp/libomptarget/deviceRTLs/common/src/parallel.cu b/openmp/libomptarget/deviceRTLs/common/src/parallel.cu
index 2656f3e48ce3f..b12d5ccb5a17e 100644
--- a/openmp/libomptarget/deviceRTLs/common/src/parallel.cu
+++ b/openmp/libomptarget/deviceRTLs/common/src/parallel.cu
@@ -239,7 +239,7 @@ EXTERN void __kmpc_end_serialized_parallel(kmp_Ident *loc,
currTaskDescr->RestoreLoopData();
}
-EXTERN uint8_t __kmpc_parallel_level() {
+NOINLINE EXTERN uint8_t __kmpc_parallel_level() {
return parallelLevel[GetWarpId()] & (OMP_ACTIVE_PARALLEL_LEVEL - 1);
}
@@ -282,11 +282,11 @@ EXTERN void __kmpc_push_proc_bind(kmp_Ident *loc, uint32_t tid, int proc_bind) {
// parallel interface
////////////////////////////////////////////////////////////////////////////////
-EXTERN void __kmpc_parallel_51(kmp_Ident *ident, kmp_int32 global_tid,
- kmp_int32 if_expr, kmp_int32 num_threads,
- int proc_bind, void *fn, void *wrapper_fn,
- void **args, size_t nargs) {
-
+NOINLINE EXTERN void __kmpc_parallel_51(kmp_Ident *ident, kmp_int32 global_tid,
+ kmp_int32 if_expr,
+ kmp_int32 num_threads, int proc_bind,
+ void *fn, void *wrapper_fn, void **args,
+ size_t nargs) {
// Handle the serialized case first, same for SPMD/non-SPMD except that in
// SPMD mode we already incremented the parallel level counter, account for
// that.
diff --git a/openmp/libomptarget/deviceRTLs/interface.h b/openmp/libomptarget/deviceRTLs/interface.h
index ade1bfe2222eb..2e80dc3a82ac9 100644
--- a/openmp/libomptarget/deviceRTLs/interface.h
+++ b/openmp/libomptarget/deviceRTLs/interface.h
@@ -222,7 +222,7 @@ EXTERN void __kmpc_push_num_threads(kmp_Ident *loc, int32_t global_tid,
int32_t num_threads);
EXTERN void __kmpc_serialized_parallel(kmp_Ident *loc, uint32_t global_tid);
EXTERN void __kmpc_end_serialized_parallel(kmp_Ident *loc, uint32_t global_tid);
-EXTERN uint8_t __kmpc_parallel_level();
+NOINLINE EXTERN uint8_t __kmpc_parallel_level();
// proc bind
EXTERN void __kmpc_push_proc_bind(kmp_Ident *loc, uint32_t global_tid,
@@ -441,10 +441,11 @@ EXTERN void __kmpc_get_shared_variables(void ***GlobalArgs);
/// \param wrapper_fn The worker wrapper function of fn.
/// \param args The pointer array of arguments to fn.
/// \param nargs The number of arguments to fn.
-EXTERN void __kmpc_parallel_51(ident_t *ident, kmp_int32 global_tid,
- kmp_int32 if_expr, kmp_int32 num_threads,
- int proc_bind, void *fn, void *wrapper_fn,
- void **args, size_t nargs);
+NOINLINE EXTERN void __kmpc_parallel_51(ident_t *ident, kmp_int32 global_tid,
+ kmp_int32 if_expr,
+ kmp_int32 num_threads, int proc_bind,
+ void *fn, void *wrapper_fn, void **args,
+ size_t nargs);
// SPMD execution mode interrogation function.
EXTERN int8_t __kmpc_is_spmd_exec_mode();
More information about the Openmp-commits
mailing list