[llvm] [OpenMP] Move KernelInfoState and AAKernelInfo to OpenMPOpt.h (PR #71878)
via llvm-commits
llvm-commits at lists.llvm.org
Tue Jan 23 10:10:06 PST 2024
https://github.com/nmustakin updated https://github.com/llvm/llvm-project/pull/71878
>From 96f80679140e96b58640c0a14e3d16bb57876804 Mon Sep 17 00:00:00 2001
From: Nafis Mustakin <nmust004 at ucr.edu>
Date: Thu, 9 Nov 2023 15:28:57 -0800
Subject: [PATCH] Move KernelInfoState and AAKernelInfo to
llvm/include/llvm/Transforms/IPO/OpenMPOpt.h
Move functions to AAKernelInfo
Fix formatting issues
Update OpenMPOpt.cpp
Update OpenMPOpt.cpp
---
llvm/include/llvm/Transforms/IPO/OpenMPOpt.h | 218 +++++++++++++++++
llvm/lib/Transforms/IPO/OpenMPOpt.cpp | 232 ++-----------------
2 files changed, 232 insertions(+), 218 deletions(-)
diff --git a/llvm/include/llvm/Transforms/IPO/OpenMPOpt.h b/llvm/include/llvm/Transforms/IPO/OpenMPOpt.h
index 2499c2bbccf4554..5bef03c0a1fce35 100644
--- a/llvm/include/llvm/Transforms/IPO/OpenMPOpt.h
+++ b/llvm/include/llvm/Transforms/IPO/OpenMPOpt.h
@@ -12,6 +12,7 @@
#include "llvm/Analysis/CGSCCPassManager.h"
#include "llvm/Analysis/LazyCallGraph.h"
#include "llvm/IR/PassManager.h"
+#include "llvm/Transforms/IPO/Attributor.h"
namespace llvm {
@@ -62,6 +63,223 @@ class OpenMPOptCGSCCPass : public PassInfoMixin<OpenMPOptCGSCCPass> {
const ThinOrFullLTOPhase LTOPhase = ThinOrFullLTOPhase::None;
};
+template <typename Ty, bool InsertInvalidates = true>
+struct BooleanStateWithSetVector : public BooleanState {
+ bool contains(const Ty &Elem) const { return Set.contains(Elem); }
+ bool insert(const Ty &Elem) {
+ if (InsertInvalidates)
+ BooleanState::indicatePessimisticFixpoint();
+ return Set.insert(Elem);
+ }
+
+ const Ty &operator[](int Idx) const { return Set[Idx]; }
+ bool operator==(const BooleanStateWithSetVector &RHS) const {
+ return BooleanState::operator==(RHS) && Set == RHS.Set;
+ }
+ bool operator!=(const BooleanStateWithSetVector &RHS) const {
+ return !(*this == RHS);
+ }
+
+ bool empty() const { return Set.empty(); }
+ size_t size() const { return Set.size(); }
+
+ /// "Clamp" this state with \p RHS.
+ BooleanStateWithSetVector &operator^=(const BooleanStateWithSetVector &RHS) {
+ BooleanState::operator^=(RHS);
+ Set.insert(RHS.Set.begin(), RHS.Set.end());
+ return *this;
+ }
+
+private:
+ /// A set to keep track of elements.
+ SetVector<Ty> Set;
+
+public:
+ typename decltype(Set)::iterator begin() { return Set.begin(); }
+ typename decltype(Set)::iterator end() { return Set.end(); }
+ typename decltype(Set)::const_iterator begin() const { return Set.begin(); }
+ typename decltype(Set)::const_iterator end() const { return Set.end(); }
+};
+
+template <typename Ty, bool InsertInvalidates = true>
+using BooleanStateWithPtrSetVector =
+ BooleanStateWithSetVector<Ty *, InsertInvalidates>;
+
+struct KernelInfoState : AbstractState {
+ /// Flag to track if we reached a fixpoint.
+ bool IsAtFixpoint = false;
+
+ /// The parallel regions (identified by the outlined parallel functions) that
+ /// can be reached from the associated function.
+ BooleanStateWithPtrSetVector<CallBase, /* InsertInvalidates */ false>
+ ReachedKnownParallelRegions;
+
+ /// State to track what parallel region we might reach.
+ BooleanStateWithPtrSetVector<CallBase> ReachedUnknownParallelRegions;
+
+ /// State to track if we are in SPMD-mode, assumed or known, and why we
+ /// decided we cannot be. If it is assumed, then RequiresFullRuntime should
+ /// also be false.
+ BooleanStateWithPtrSetVector<Instruction, false> SPMDCompatibilityTracker;
+
+ /// The __kmpc_target_init call in this kernel, if any. If we find more than
+ /// one we abort as the kernel is malformed.
+ CallBase *KernelInitCB = nullptr;
+
+ /// The constant kernel environment as taken from and passed to
+ /// __kmpc_target_init.
+ ConstantStruct *KernelEnvC = nullptr;
+
+ /// The __kmpc_target_deinit call in this kernel, if any. If we find more than
+ /// one we abort as the kernel is malformed.
+ CallBase *KernelDeinitCB = nullptr;
+
+ /// Flag to indicate if the associated function is a kernel entry.
+ bool IsKernelEntry = false;
+
+ /// State to track what kernel entries can reach the associated function.
+ BooleanStateWithPtrSetVector<Function, false> ReachingKernelEntries;
+
+ /// State to indicate if we can track parallel level of the associated
+ /// function. We will give up tracking if we encounter unknown caller or the
+ /// caller is __kmpc_parallel_51.
+ BooleanStateWithSetVector<uint8_t> ParallelLevels;
+
+ /// Flag that indicates if the kernel has nested Parallelism
+ bool NestedParallelism = false;
+
+ /// Abstract State interface
+ ///{
+
+ KernelInfoState() = default;
+ KernelInfoState(bool BestState) {
+ if (!BestState)
+ indicatePessimisticFixpoint();
+ }
+
+ /// See AbstractState::isValidState(...)
+ bool isValidState() const override { return true; }
+
+ /// See AbstractState::isAtFixpoint(...)
+ bool isAtFixpoint() const override { return IsAtFixpoint; }
+
+ /// See AbstractState::indicatePessimisticFixpoint(...)
+ ChangeStatus indicatePessimisticFixpoint() override {
+ IsAtFixpoint = true;
+ ParallelLevels.indicatePessimisticFixpoint();
+ ReachingKernelEntries.indicatePessimisticFixpoint();
+ SPMDCompatibilityTracker.indicatePessimisticFixpoint();
+ ReachedKnownParallelRegions.indicatePessimisticFixpoint();
+ ReachedUnknownParallelRegions.indicatePessimisticFixpoint();
+ NestedParallelism = true;
+ return ChangeStatus::CHANGED;
+ }
+
+ /// See AbstractState::indicateOptimisticFixpoint(...)
+ ChangeStatus indicateOptimisticFixpoint() override {
+ IsAtFixpoint = true;
+ ParallelLevels.indicateOptimisticFixpoint();
+ ReachingKernelEntries.indicateOptimisticFixpoint();
+ SPMDCompatibilityTracker.indicateOptimisticFixpoint();
+ ReachedKnownParallelRegions.indicateOptimisticFixpoint();
+ ReachedUnknownParallelRegions.indicateOptimisticFixpoint();
+ return ChangeStatus::UNCHANGED;
+ }
+
+ /// Return the assumed state
+ KernelInfoState &getAssumed() { return *this; }
+ const KernelInfoState &getAssumed() const { return *this; }
+
+ bool operator==(const KernelInfoState &RHS) const {
+ if (SPMDCompatibilityTracker != RHS.SPMDCompatibilityTracker)
+ return false;
+ if (ReachedKnownParallelRegions != RHS.ReachedKnownParallelRegions)
+ return false;
+ if (ReachedUnknownParallelRegions != RHS.ReachedUnknownParallelRegions)
+ return false;
+ if (ReachingKernelEntries != RHS.ReachingKernelEntries)
+ return false;
+ if (ParallelLevels != RHS.ParallelLevels)
+ return false;
+ if (NestedParallelism != RHS.NestedParallelism)
+ return false;
+ return true;
+ }
+
+ /// Returns true if this kernel contains any OpenMP parallel regions.
+ bool mayContainParallelRegion() {
+ return !ReachedKnownParallelRegions.empty() ||
+ !ReachedUnknownParallelRegions.empty();
+ }
+
+ /// Return empty set as the best state of potential values.
+ static KernelInfoState getBestState() { return KernelInfoState(true); }
+
+ static KernelInfoState getBestState(KernelInfoState &KIS) {
+ return getBestState();
+ }
+
+ /// Return full set as the worst state of potential values.
+ static KernelInfoState getWorstState() { return KernelInfoState(false); }
+
+ /// "Clamp" this state with \p KIS.
+ KernelInfoState operator^=(const KernelInfoState &KIS) {
+ // Do not merge two different _init and _deinit call sites.
+ if (KIS.KernelInitCB) {
+ if (KernelInitCB && KernelInitCB != KIS.KernelInitCB)
+ llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt "
+ "assumptions.");
+ KernelInitCB = KIS.KernelInitCB;
+ }
+ if (KIS.KernelDeinitCB) {
+ if (KernelDeinitCB && KernelDeinitCB != KIS.KernelDeinitCB)
+ llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt "
+ "assumptions.");
+ KernelDeinitCB = KIS.KernelDeinitCB;
+ }
+ if (KIS.KernelEnvC) {
+ if (KernelEnvC && KernelEnvC != KIS.KernelEnvC)
+ llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt "
+ "assumptions.");
+ KernelEnvC = KIS.KernelEnvC;
+ }
+ SPMDCompatibilityTracker ^= KIS.SPMDCompatibilityTracker;
+ ReachedKnownParallelRegions ^= KIS.ReachedKnownParallelRegions;
+ ReachedUnknownParallelRegions ^= KIS.ReachedUnknownParallelRegions;
+ NestedParallelism |= KIS.NestedParallelism;
+ return *this;
+ }
+
+ KernelInfoState operator&=(const KernelInfoState &KIS) {
+ return (*this ^= KIS);
+ }
+};
+
+struct AAKernelInfo : public StateWrapper<KernelInfoState, AbstractAttribute> {
+ using Base = StateWrapper<KernelInfoState, AbstractAttribute>;
+ AAKernelInfo(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
+
+ /// Public getter for ReachingKernelEntries
+ virtual BooleanStateWithPtrSetVector<Function, false>
+ getReachingKernels() = 0;
+
+ /// Create an abstract attribute view for the position \p IRP.
+ static AAKernelInfo &createForPosition(const IRPosition &IRP, Attributor &A);
+
+ /// This function should return true if the type of the \p AA is AAKernelInfo
+ static bool classof(const AbstractAttribute *AA) {
+ return (AA->getIdAddr() == &ID);
+ }
+
+ /// See AbstractAttribute::getName()
+ const std::string getName() const override { return "AAKernelInfo"; }
+
+ /// See AbstractAttribute::getIdAddr()
+ const char *getIdAddr() const override { return &ID; }
+
+ static const char ID;
+};
+
} // end namespace llvm
#endif // LLVM_TRANSFORMS_IPO_OPENMPOPT_H
diff --git a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
index 4176d561363fbd9..fe3d664d6fdb0bf 100644
--- a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
+++ b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
@@ -638,200 +638,6 @@ struct OMPInformationCache : public InformationCache {
bool OpenMPPostLink = false;
};
-template <typename Ty, bool InsertInvalidates = true>
-struct BooleanStateWithSetVector : public BooleanState {
- bool contains(const Ty &Elem) const { return Set.contains(Elem); }
- bool insert(const Ty &Elem) {
- if (InsertInvalidates)
- BooleanState::indicatePessimisticFixpoint();
- return Set.insert(Elem);
- }
-
- const Ty &operator[](int Idx) const { return Set[Idx]; }
- bool operator==(const BooleanStateWithSetVector &RHS) const {
- return BooleanState::operator==(RHS) && Set == RHS.Set;
- }
- bool operator!=(const BooleanStateWithSetVector &RHS) const {
- return !(*this == RHS);
- }
-
- bool empty() const { return Set.empty(); }
- size_t size() const { return Set.size(); }
-
- /// "Clamp" this state with \p RHS.
- BooleanStateWithSetVector &operator^=(const BooleanStateWithSetVector &RHS) {
- BooleanState::operator^=(RHS);
- Set.insert(RHS.Set.begin(), RHS.Set.end());
- return *this;
- }
-
-private:
- /// A set to keep track of elements.
- SetVector<Ty> Set;
-
-public:
- typename decltype(Set)::iterator begin() { return Set.begin(); }
- typename decltype(Set)::iterator end() { return Set.end(); }
- typename decltype(Set)::const_iterator begin() const { return Set.begin(); }
- typename decltype(Set)::const_iterator end() const { return Set.end(); }
-};
-
-template <typename Ty, bool InsertInvalidates = true>
-using BooleanStateWithPtrSetVector =
- BooleanStateWithSetVector<Ty *, InsertInvalidates>;
-
-struct KernelInfoState : AbstractState {
- /// Flag to track if we reached a fixpoint.
- bool IsAtFixpoint = false;
-
- /// The parallel regions (identified by the outlined parallel functions) that
- /// can be reached from the associated function.
- BooleanStateWithPtrSetVector<CallBase, /* InsertInvalidates */ false>
- ReachedKnownParallelRegions;
-
- /// State to track what parallel region we might reach.
- BooleanStateWithPtrSetVector<CallBase> ReachedUnknownParallelRegions;
-
- /// State to track if we are in SPMD-mode, assumed or know, and why we decided
- /// we cannot be. If it is assumed, then RequiresFullRuntime should also be
- /// false.
- BooleanStateWithPtrSetVector<Instruction, false> SPMDCompatibilityTracker;
-
- /// The __kmpc_target_init call in this kernel, if any. If we find more than
- /// one we abort as the kernel is malformed.
- CallBase *KernelInitCB = nullptr;
-
- /// The constant kernel environement as taken from and passed to
- /// __kmpc_target_init.
- ConstantStruct *KernelEnvC = nullptr;
-
- /// The __kmpc_target_deinit call in this kernel, if any. If we find more than
- /// one we abort as the kernel is malformed.
- CallBase *KernelDeinitCB = nullptr;
-
- /// Flag to indicate if the associated function is a kernel entry.
- bool IsKernelEntry = false;
-
- /// State to track what kernel entries can reach the associated function.
- BooleanStateWithPtrSetVector<Function, false> ReachingKernelEntries;
-
- /// State to indicate if we can track parallel level of the associated
- /// function. We will give up tracking if we encounter unknown caller or the
- /// caller is __kmpc_parallel_51.
- BooleanStateWithSetVector<uint8_t> ParallelLevels;
-
- /// Flag that indicates if the kernel has nested Parallelism
- bool NestedParallelism = false;
-
- /// Abstract State interface
- ///{
-
- KernelInfoState() = default;
- KernelInfoState(bool BestState) {
- if (!BestState)
- indicatePessimisticFixpoint();
- }
-
- /// See AbstractState::isValidState(...)
- bool isValidState() const override { return true; }
-
- /// See AbstractState::isAtFixpoint(...)
- bool isAtFixpoint() const override { return IsAtFixpoint; }
-
- /// See AbstractState::indicatePessimisticFixpoint(...)
- ChangeStatus indicatePessimisticFixpoint() override {
- IsAtFixpoint = true;
- ParallelLevels.indicatePessimisticFixpoint();
- ReachingKernelEntries.indicatePessimisticFixpoint();
- SPMDCompatibilityTracker.indicatePessimisticFixpoint();
- ReachedKnownParallelRegions.indicatePessimisticFixpoint();
- ReachedUnknownParallelRegions.indicatePessimisticFixpoint();
- NestedParallelism = true;
- return ChangeStatus::CHANGED;
- }
-
- /// See AbstractState::indicateOptimisticFixpoint(...)
- ChangeStatus indicateOptimisticFixpoint() override {
- IsAtFixpoint = true;
- ParallelLevels.indicateOptimisticFixpoint();
- ReachingKernelEntries.indicateOptimisticFixpoint();
- SPMDCompatibilityTracker.indicateOptimisticFixpoint();
- ReachedKnownParallelRegions.indicateOptimisticFixpoint();
- ReachedUnknownParallelRegions.indicateOptimisticFixpoint();
- return ChangeStatus::UNCHANGED;
- }
-
- /// Return the assumed state
- KernelInfoState &getAssumed() { return *this; }
- const KernelInfoState &getAssumed() const { return *this; }
-
- bool operator==(const KernelInfoState &RHS) const {
- if (SPMDCompatibilityTracker != RHS.SPMDCompatibilityTracker)
- return false;
- if (ReachedKnownParallelRegions != RHS.ReachedKnownParallelRegions)
- return false;
- if (ReachedUnknownParallelRegions != RHS.ReachedUnknownParallelRegions)
- return false;
- if (ReachingKernelEntries != RHS.ReachingKernelEntries)
- return false;
- if (ParallelLevels != RHS.ParallelLevels)
- return false;
- if (NestedParallelism != RHS.NestedParallelism)
- return false;
- return true;
- }
-
- /// Returns true if this kernel contains any OpenMP parallel regions.
- bool mayContainParallelRegion() {
- return !ReachedKnownParallelRegions.empty() ||
- !ReachedUnknownParallelRegions.empty();
- }
-
- /// Return empty set as the best state of potential values.
- static KernelInfoState getBestState() { return KernelInfoState(true); }
-
- static KernelInfoState getBestState(KernelInfoState &KIS) {
- return getBestState();
- }
-
- /// Return full set as the worst state of potential values.
- static KernelInfoState getWorstState() { return KernelInfoState(false); }
-
- /// "Clamp" this state with \p KIS.
- KernelInfoState operator^=(const KernelInfoState &KIS) {
- // Do not merge two different _init and _deinit call sites.
- if (KIS.KernelInitCB) {
- if (KernelInitCB && KernelInitCB != KIS.KernelInitCB)
- llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt "
- "assumptions.");
- KernelInitCB = KIS.KernelInitCB;
- }
- if (KIS.KernelDeinitCB) {
- if (KernelDeinitCB && KernelDeinitCB != KIS.KernelDeinitCB)
- llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt "
- "assumptions.");
- KernelDeinitCB = KIS.KernelDeinitCB;
- }
- if (KIS.KernelEnvC) {
- if (KernelEnvC && KernelEnvC != KIS.KernelEnvC)
- llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt "
- "assumptions.");
- KernelEnvC = KIS.KernelEnvC;
- }
- SPMDCompatibilityTracker ^= KIS.SPMDCompatibilityTracker;
- ReachedKnownParallelRegions ^= KIS.ReachedKnownParallelRegions;
- ReachedUnknownParallelRegions ^= KIS.ReachedUnknownParallelRegions;
- NestedParallelism |= KIS.NestedParallelism;
- return *this;
- }
-
- KernelInfoState operator&=(const KernelInfoState &KIS) {
- return (*this ^= KIS);
- }
-
- ///}
-};
-
/// Used to map the values physically (in the IR) stored in an offload
/// array, to a vector in memory.
struct OffloadArray {
@@ -3599,9 +3405,9 @@ struct AAHeapToSharedFunction : public AAHeapToShared {
unsigned SharedMemoryUsed = 0;
};
-struct AAKernelInfo : public StateWrapper<KernelInfoState, AbstractAttribute> {
- using Base = StateWrapper<KernelInfoState, AbstractAttribute>;
- AAKernelInfo(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
+struct AAKernelInfoImpl : AAKernelInfo {
+ AAKernelInfoImpl(const IRPosition &IRP, Attributor &A)
+ : AAKernelInfo(IRP, A) {}
/// The callee value is tracked beyond a simple stripPointerCasts, so we allow
/// unknown callees.
@@ -3637,28 +3443,17 @@ struct AAKernelInfo : public StateWrapper<KernelInfoState, AbstractAttribute> {
", NestedPar: " + (NestedParallelism ? "yes" : "no");
}
- /// Create an abstract attribute biew for the position \p IRP.
- static AAKernelInfo &createForPosition(const IRPosition &IRP, Attributor &A);
-
- /// See AbstractAttribute::getName()
- const std::string getName() const override { return "AAKernelInfo"; }
-
- /// See AbstractAttribute::getIdAddr()
- const char *getIdAddr() const override { return &ID; }
-
- /// This function should return true if the type of the \p AA is AAKernelInfo
- static bool classof(const AbstractAttribute *AA) {
- return (AA->getIdAddr() == &ID);
+ /// Return the ReachingKernelEntries
+ BooleanStateWithPtrSetVector<Function, false> getReachingKernels() override {
+ return ReachingKernelEntries;
}
-
- static const char ID;
};
/// The function kernel info abstract attribute, basically, what can we say
/// about a function with regards to the KernelInfoState.
-struct AAKernelInfoFunction : AAKernelInfo {
+struct AAKernelInfoFunction : AAKernelInfoImpl {
AAKernelInfoFunction(const IRPosition &IRP, Attributor &A)
- : AAKernelInfo(IRP, A) {}
+ : AAKernelInfoImpl(IRP, A) {}
SmallPtrSet<Instruction *, 4> GuardedInstructions;
@@ -4862,13 +4657,13 @@ struct AAKernelInfoFunction : AAKernelInfo {
/// The call site kernel info abstract attribute, basically, what can we say
/// about a call site with regards to the KernelInfoState. For now this simply
/// forwards the information from the callee.
-struct AAKernelInfoCallSite : AAKernelInfo {
+struct AAKernelInfoCallSite : AAKernelInfoImpl {
AAKernelInfoCallSite(const IRPosition &IRP, Attributor &A)
- : AAKernelInfo(IRP, A) {}
+ : AAKernelInfoImpl(IRP, A) {}
/// See AbstractAttribute::initialize(...).
void initialize(Attributor &A) override {
- AAKernelInfo::initialize(A);
+ AAKernelInfoImpl::initialize(A);
CallBase &CB = cast<CallBase>(getAssociatedValue());
auto *AssumptionAA = A.getAAFor<AAAssumptionInfo>(
@@ -4890,8 +4685,9 @@ struct AAKernelInfoCallSite : AAKernelInfo {
// Next we check if we know the callee. If it is a known OpenMP function
// we will handle them explicitly in the switch below. If it is not, we
- // will use an AAKernelInfo object on the callee to gather information and
- // merge that into the current state. The latter happens in the updateImpl.
+ // will use an AAKernelInfo object on the callee to gather information
+ // and merge that into the current state. The latter happens in the
+ // updateImpl.
auto CheckCallee = [&](Function *Callee, unsigned NumCallees) {
auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee);
More information about the llvm-commits
mailing list