[llvm] [OpenMP] Move KernelInfoState and AAKernelInfo to OpenMPOpt.h (PR #71878)
Johannes Doerfert via llvm-commits
llvm-commits at lists.llvm.org
Mon Dec 11 14:16:08 PST 2023
https://github.com/jdoerfert updated https://github.com/llvm/llvm-project/pull/71878
>From ce96bdb47d3274f3c11a20b6572c89bf4f467c48 Mon Sep 17 00:00:00 2001
From: Nafis Mustakin <nmust004 at ucr.edu>
Date: Thu, 9 Nov 2023 15:28:57 -0800
Subject: [PATCH 1/4] Move KernelInfoState and AAKernelInfo to
llvm/include/llvm/Transforms/IPO/OpenMPOpt.h
---
llvm/include/llvm/Transforms/IPO/OpenMPOpt.h | 202 ++++++++++++++
llvm/lib/Transforms/IPO/OpenMPOpt.cpp | 274 +++----------------
2 files changed, 246 insertions(+), 230 deletions(-)
diff --git a/llvm/include/llvm/Transforms/IPO/OpenMPOpt.h b/llvm/include/llvm/Transforms/IPO/OpenMPOpt.h
index 2499c2bbccf455..049c1458bb9d63 100644
--- a/llvm/include/llvm/Transforms/IPO/OpenMPOpt.h
+++ b/llvm/include/llvm/Transforms/IPO/OpenMPOpt.h
@@ -12,6 +12,7 @@
#include "llvm/Analysis/CGSCCPassManager.h"
#include "llvm/Analysis/LazyCallGraph.h"
#include "llvm/IR/PassManager.h"
+#include "llvm/Transforms/IPO/Attributor.h"
namespace llvm {
@@ -62,6 +63,207 @@ class OpenMPOptCGSCCPass : public PassInfoMixin<OpenMPOptCGSCCPass> {
const ThinOrFullLTOPhase LTOPhase = ThinOrFullLTOPhase::None;
};
+template <typename Ty, bool InsertInvalidates = true>
+struct BooleanStateWithSetVector : public BooleanState {
+ bool contains(const Ty &Elem) const { return Set.contains(Elem); }
+ bool insert(const Ty &Elem) {
+ if (InsertInvalidates)
+ BooleanState::indicatePessimisticFixpoint();
+ return Set.insert(Elem);
+ }
+
+ const Ty &operator[](int Idx) const { return Set[Idx]; }
+ bool operator==(const BooleanStateWithSetVector &RHS) const {
+ return BooleanState::operator==(RHS) && Set == RHS.Set;
+ }
+ bool operator!=(const BooleanStateWithSetVector &RHS) const {
+ return !(*this == RHS);
+ }
+
+ bool empty() const { return Set.empty(); }
+ size_t size() const { return Set.size(); }
+
+ /// "Clamp" this state with \p RHS.
+ BooleanStateWithSetVector &operator^=(const BooleanStateWithSetVector &RHS) {
+ BooleanState::operator^=(RHS);
+ Set.insert(RHS.Set.begin(), RHS.Set.end());
+ return *this;
+ }
+
+private:
+ /// A set to keep track of elements.
+ SetVector<Ty> Set;
+
+public:
+ typename decltype(Set)::iterator begin() { return Set.begin(); }
+ typename decltype(Set)::iterator end() { return Set.end(); }
+ typename decltype(Set)::const_iterator begin() const { return Set.begin(); }
+ typename decltype(Set)::const_iterator end() const { return Set.end(); }
+};
+
+template <typename Ty, bool InsertInvalidates = true>
+using BooleanStateWithPtrSetVector =
+ BooleanStateWithSetVector<Ty *, InsertInvalidates>;
+
+struct KernelInfoState : AbstractState {
+ /// Flag to track if we reached a fixpoint.
+ bool IsAtFixpoint = false;
+
+ /// The parallel regions (identified by the outlined parallel functions) that
+ /// can be reached from the associated function.
+ BooleanStateWithPtrSetVector<CallBase, /* InsertInvalidates */ false>
+ ReachedKnownParallelRegions;
+
+ /// State to track which parallel regions we might reach.
+ BooleanStateWithPtrSetVector<CallBase> ReachedUnknownParallelRegions;
+
+ /// State to track if we are in SPMD-mode, assumed or known, and why we decided
+ /// we cannot be. If it is assumed, then RequiresFullRuntime should also be
+ /// false.
+ BooleanStateWithPtrSetVector<Instruction, false> SPMDCompatibilityTracker;
+
+ /// The __kmpc_target_init call in this kernel, if any. If we find more than
+ /// one we abort as the kernel is malformed.
+ CallBase *KernelInitCB = nullptr;
+
+ /// The constant kernel environment as taken from and passed to
+ /// __kmpc_target_init.
+ ConstantStruct *KernelEnvC = nullptr;
+
+ /// The __kmpc_target_deinit call in this kernel, if any. If we find more than
+ /// one we abort as the kernel is malformed.
+ CallBase *KernelDeinitCB = nullptr;
+
+ /// Flag to indicate if the associated function is a kernel entry.
+ bool IsKernelEntry = false;
+
+ /// State to track what kernel entries can reach the associated function.
+ BooleanStateWithPtrSetVector<Function, false> ReachingKernelEntries;
+
+ /// State to indicate if we can track the parallel level of the associated
+ /// function. We will give up tracking if we encounter an unknown caller or
+ /// the caller is __kmpc_parallel_51.
+ BooleanStateWithSetVector<uint8_t> ParallelLevels;
+
+ /// Flag that indicates if the kernel has nested parallelism.
+ bool NestedParallelism = false;
+
+ /// Abstract State interface
+ ///{
+
+ KernelInfoState() = default;
+ KernelInfoState(bool BestState) {
+ if (!BestState)
+ indicatePessimisticFixpoint();
+ }
+
+ /// See AbstractState::isValidState(...)
+ bool isValidState() const override { return true; }
+
+ /// See AbstractState::isAtFixpoint(...)
+ bool isAtFixpoint() const override { return IsAtFixpoint; }
+
+ /// See AbstractState::indicatePessimisticFixpoint(...)
+ ChangeStatus indicatePessimisticFixpoint() override {
+ IsAtFixpoint = true;
+ ParallelLevels.indicatePessimisticFixpoint();
+ ReachingKernelEntries.indicatePessimisticFixpoint();
+ SPMDCompatibilityTracker.indicatePessimisticFixpoint();
+ ReachedKnownParallelRegions.indicatePessimisticFixpoint();
+ ReachedUnknownParallelRegions.indicatePessimisticFixpoint();
+ NestedParallelism = true;
+ return ChangeStatus::CHANGED;
+ }
+
+ /// See AbstractState::indicateOptimisticFixpoint(...)
+ ChangeStatus indicateOptimisticFixpoint() override {
+ IsAtFixpoint = true;
+ ParallelLevels.indicateOptimisticFixpoint();
+ ReachingKernelEntries.indicateOptimisticFixpoint();
+ SPMDCompatibilityTracker.indicateOptimisticFixpoint();
+ ReachedKnownParallelRegions.indicateOptimisticFixpoint();
+ ReachedUnknownParallelRegions.indicateOptimisticFixpoint();
+ return ChangeStatus::UNCHANGED;
+ }
+
+ /// Return the assumed state
+ KernelInfoState &getAssumed() { return *this; }
+ const KernelInfoState &getAssumed() const { return *this; }
+
+ bool operator==(const KernelInfoState &RHS) const {
+ if (SPMDCompatibilityTracker != RHS.SPMDCompatibilityTracker)
+ return false;
+ if (ReachedKnownParallelRegions != RHS.ReachedKnownParallelRegions)
+ return false;
+ if (ReachedUnknownParallelRegions != RHS.ReachedUnknownParallelRegions)
+ return false;
+ if (ReachingKernelEntries != RHS.ReachingKernelEntries)
+ return false;
+ if (ParallelLevels != RHS.ParallelLevels)
+ return false;
+ if (NestedParallelism != RHS.NestedParallelism)
+ return false;
+ return true;
+ }
+
+ /// Returns true if this kernel contains any OpenMP parallel regions.
+ bool mayContainParallelRegion() {
+ return !ReachedKnownParallelRegions.empty() ||
+ !ReachedUnknownParallelRegions.empty();
+ }
+
+ /// Return empty set as the best state of potential values.
+ static KernelInfoState getBestState() { return KernelInfoState(true); }
+
+ static KernelInfoState getBestState(KernelInfoState &KIS) {
+ return getBestState();
+ }
+
+ /// Return full set as the worst state of potential values.
+ static KernelInfoState getWorstState() { return KernelInfoState(false); }
+
+ /// "Clamp" this state with \p KIS.
+ KernelInfoState operator^=(const KernelInfoState &KIS) {
+ // Do not merge two different _init and _deinit call sites.
+ if (KIS.KernelInitCB) {
+ if (KernelInitCB && KernelInitCB != KIS.KernelInitCB)
+ llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt "
+ "assumptions.");
+ KernelInitCB = KIS.KernelInitCB;
+ }
+ if (KIS.KernelDeinitCB) {
+ if (KernelDeinitCB && KernelDeinitCB != KIS.KernelDeinitCB)
+ llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt "
+ "assumptions.");
+ KernelDeinitCB = KIS.KernelDeinitCB;
+ }
+ if (KIS.KernelEnvC) {
+ if (KernelEnvC && KernelEnvC != KIS.KernelEnvC)
+ llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt "
+ "assumptions.");
+ KernelEnvC = KIS.KernelEnvC;
+ }
+ SPMDCompatibilityTracker ^= KIS.SPMDCompatibilityTracker;
+ ReachedKnownParallelRegions ^= KIS.ReachedKnownParallelRegions;
+ ReachedUnknownParallelRegions ^= KIS.ReachedUnknownParallelRegions;
+ NestedParallelism |= KIS.NestedParallelism;
+ return *this;
+ }
+
+ KernelInfoState operator&=(const KernelInfoState &KIS) {
+ return (*this ^= KIS);
+ }
+
+ ///}
+};
+
+struct AAKernelInfo : public StateWrapper<KernelInfoState, AbstractAttribute> {
+ using Base = StateWrapper<KernelInfoState, AbstractAttribute>;
+ AAKernelInfo(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
+
+ /// Public getter for ReachingKernelEntries
+ virtual BooleanStateWithPtrSetVector<Function, false>
+ getReachingKernels() = 0;
+};
+
} // end namespace llvm
#endif // LLVM_TRANSFORMS_IPO_OPENMPOPT_H
diff --git a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
index 5b42f215fb40ca..4dd32f236c03f3 100644
--- a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
+++ b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
@@ -638,200 +638,6 @@ struct OMPInformationCache : public InformationCache {
bool OpenMPPostLink = false;
};
-template <typename Ty, bool InsertInvalidates = true>
-struct BooleanStateWithSetVector : public BooleanState {
- bool contains(const Ty &Elem) const { return Set.contains(Elem); }
- bool insert(const Ty &Elem) {
- if (InsertInvalidates)
- BooleanState::indicatePessimisticFixpoint();
- return Set.insert(Elem);
- }
-
- const Ty &operator[](int Idx) const { return Set[Idx]; }
- bool operator==(const BooleanStateWithSetVector &RHS) const {
- return BooleanState::operator==(RHS) && Set == RHS.Set;
- }
- bool operator!=(const BooleanStateWithSetVector &RHS) const {
- return !(*this == RHS);
- }
-
- bool empty() const { return Set.empty(); }
- size_t size() const { return Set.size(); }
-
- /// "Clamp" this state with \p RHS.
- BooleanStateWithSetVector &operator^=(const BooleanStateWithSetVector &RHS) {
- BooleanState::operator^=(RHS);
- Set.insert(RHS.Set.begin(), RHS.Set.end());
- return *this;
- }
-
-private:
- /// A set to keep track of elements.
- SetVector<Ty> Set;
-
-public:
- typename decltype(Set)::iterator begin() { return Set.begin(); }
- typename decltype(Set)::iterator end() { return Set.end(); }
- typename decltype(Set)::const_iterator begin() const { return Set.begin(); }
- typename decltype(Set)::const_iterator end() const { return Set.end(); }
-};
-
-template <typename Ty, bool InsertInvalidates = true>
-using BooleanStateWithPtrSetVector =
- BooleanStateWithSetVector<Ty *, InsertInvalidates>;
-
-struct KernelInfoState : AbstractState {
- /// Flag to track if we reached a fixpoint.
- bool IsAtFixpoint = false;
-
- /// The parallel regions (identified by the outlined parallel functions) that
- /// can be reached from the associated function.
- BooleanStateWithPtrSetVector<CallBase, /* InsertInvalidates */ false>
- ReachedKnownParallelRegions;
-
- /// State to track which parallel regions we might reach.
- BooleanStateWithPtrSetVector<CallBase> ReachedUnknownParallelRegions;
-
- /// State to track if we are in SPMD-mode, assumed or known, and why we decided
- /// we cannot be. If it is assumed, then RequiresFullRuntime should also be
- /// false.
- BooleanStateWithPtrSetVector<Instruction, false> SPMDCompatibilityTracker;
-
- /// The __kmpc_target_init call in this kernel, if any. If we find more than
- /// one we abort as the kernel is malformed.
- CallBase *KernelInitCB = nullptr;
-
- /// The constant kernel environment as taken from and passed to
- /// __kmpc_target_init.
- ConstantStruct *KernelEnvC = nullptr;
-
- /// The __kmpc_target_deinit call in this kernel, if any. If we find more than
- /// one we abort as the kernel is malformed.
- CallBase *KernelDeinitCB = nullptr;
-
- /// Flag to indicate if the associated function is a kernel entry.
- bool IsKernelEntry = false;
-
- /// State to track what kernel entries can reach the associated function.
- BooleanStateWithPtrSetVector<Function, false> ReachingKernelEntries;
-
- /// State to indicate if we can track the parallel level of the associated
- /// function. We will give up tracking if we encounter an unknown caller or
- /// the caller is __kmpc_parallel_51.
- BooleanStateWithSetVector<uint8_t> ParallelLevels;
-
- /// Flag that indicates if the kernel has nested parallelism.
- bool NestedParallelism = false;
-
- /// Abstract State interface
- ///{
-
- KernelInfoState() = default;
- KernelInfoState(bool BestState) {
- if (!BestState)
- indicatePessimisticFixpoint();
- }
-
- /// See AbstractState::isValidState(...)
- bool isValidState() const override { return true; }
-
- /// See AbstractState::isAtFixpoint(...)
- bool isAtFixpoint() const override { return IsAtFixpoint; }
-
- /// See AbstractState::indicatePessimisticFixpoint(...)
- ChangeStatus indicatePessimisticFixpoint() override {
- IsAtFixpoint = true;
- ParallelLevels.indicatePessimisticFixpoint();
- ReachingKernelEntries.indicatePessimisticFixpoint();
- SPMDCompatibilityTracker.indicatePessimisticFixpoint();
- ReachedKnownParallelRegions.indicatePessimisticFixpoint();
- ReachedUnknownParallelRegions.indicatePessimisticFixpoint();
- NestedParallelism = true;
- return ChangeStatus::CHANGED;
- }
-
- /// See AbstractState::indicateOptimisticFixpoint(...)
- ChangeStatus indicateOptimisticFixpoint() override {
- IsAtFixpoint = true;
- ParallelLevels.indicateOptimisticFixpoint();
- ReachingKernelEntries.indicateOptimisticFixpoint();
- SPMDCompatibilityTracker.indicateOptimisticFixpoint();
- ReachedKnownParallelRegions.indicateOptimisticFixpoint();
- ReachedUnknownParallelRegions.indicateOptimisticFixpoint();
- return ChangeStatus::UNCHANGED;
- }
-
- /// Return the assumed state
- KernelInfoState &getAssumed() { return *this; }
- const KernelInfoState &getAssumed() const { return *this; }
-
- bool operator==(const KernelInfoState &RHS) const {
- if (SPMDCompatibilityTracker != RHS.SPMDCompatibilityTracker)
- return false;
- if (ReachedKnownParallelRegions != RHS.ReachedKnownParallelRegions)
- return false;
- if (ReachedUnknownParallelRegions != RHS.ReachedUnknownParallelRegions)
- return false;
- if (ReachingKernelEntries != RHS.ReachingKernelEntries)
- return false;
- if (ParallelLevels != RHS.ParallelLevels)
- return false;
- if (NestedParallelism != RHS.NestedParallelism)
- return false;
- return true;
- }
-
- /// Returns true if this kernel contains any OpenMP parallel regions.
- bool mayContainParallelRegion() {
- return !ReachedKnownParallelRegions.empty() ||
- !ReachedUnknownParallelRegions.empty();
- }
-
- /// Return empty set as the best state of potential values.
- static KernelInfoState getBestState() { return KernelInfoState(true); }
-
- static KernelInfoState getBestState(KernelInfoState &KIS) {
- return getBestState();
- }
-
- /// Return full set as the worst state of potential values.
- static KernelInfoState getWorstState() { return KernelInfoState(false); }
-
- /// "Clamp" this state with \p KIS.
- KernelInfoState operator^=(const KernelInfoState &KIS) {
- // Do not merge two different _init and _deinit call sites.
- if (KIS.KernelInitCB) {
- if (KernelInitCB && KernelInitCB != KIS.KernelInitCB)
- llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt "
- "assumptions.");
- KernelInitCB = KIS.KernelInitCB;
- }
- if (KIS.KernelDeinitCB) {
- if (KernelDeinitCB && KernelDeinitCB != KIS.KernelDeinitCB)
- llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt "
- "assumptions.");
- KernelDeinitCB = KIS.KernelDeinitCB;
- }
- if (KIS.KernelEnvC) {
- if (KernelEnvC && KernelEnvC != KIS.KernelEnvC)
- llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt "
- "assumptions.");
- KernelEnvC = KIS.KernelEnvC;
- }
- SPMDCompatibilityTracker ^= KIS.SPMDCompatibilityTracker;
- ReachedKnownParallelRegions ^= KIS.ReachedKnownParallelRegions;
- ReachedUnknownParallelRegions ^= KIS.ReachedUnknownParallelRegions;
- NestedParallelism |= KIS.NestedParallelism;
- return *this;
- }
-
- KernelInfoState operator&=(const KernelInfoState &KIS) {
- return (*this ^= KIS);
- }
-
- ///}
-};
-
/// Used to map the values physically (in the IR) stored in an offload
/// array, to a vector in memory.
struct OffloadArray {
@@ -3596,9 +3402,9 @@ struct AAHeapToSharedFunction : public AAHeapToShared {
unsigned SharedMemoryUsed = 0;
};
-struct AAKernelInfo : public StateWrapper<KernelInfoState, AbstractAttribute> {
- using Base = StateWrapper<KernelInfoState, AbstractAttribute>;
- AAKernelInfo(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
+struct AAKernelInfoImpl : AAKernelInfo {
+ AAKernelInfoImpl(const IRPosition &IRP, Attributor &A)
+ : AAKernelInfo(IRP, A) {}
/// The callee value is tracked beyond a simple stripPointerCasts, so we allow
/// unknown callees.
@@ -3635,27 +3441,34 @@ struct AAKernelInfo : public StateWrapper<KernelInfoState, AbstractAttribute> {
}
/// Create an abstract attribute view for the position \p IRP.
- static AAKernelInfo &createForPosition(const IRPosition &IRP, Attributor &A);
+ static AAKernelInfoImpl &createForPosition(const IRPosition &IRP,
+ Attributor &A);
/// See AbstractAttribute::getName()
- const std::string getName() const override { return "AAKernelInfo"; }
+ const std::string getName() const override { return "AAKernelInfoImpl"; }
/// See AbstractAttribute::getIdAddr()
const char *getIdAddr() const override { return &ID; }
- /// This function should return true if the type of the \p AA is AAKernelInfo
+ /// This function should return true if the type of the \p AA is
+ /// AAKernelInfoImpl
static bool classof(const AbstractAttribute *AA) {
return (AA->getIdAddr() == &ID);
}
static const char ID;
+
+ /// Return the ReachingKernelEntries
+ BooleanStateWithPtrSetVector<Function, false> getReachingKernels() override {
+ return ReachingKernelEntries;
+ }
};
/// The function kernel info abstract attribute, basically, what can we say
/// about a function with regards to the KernelInfoState.
-struct AAKernelInfoFunction : AAKernelInfo {
+struct AAKernelInfoFunction : AAKernelInfoImpl {
AAKernelInfoFunction(const IRPosition &IRP, Attributor &A)
- : AAKernelInfo(IRP, A) {}
+ : AAKernelInfoImpl(IRP, A) {}
SmallPtrSet<Instruction *, 4> GuardedInstructions;
@@ -3815,7 +3628,7 @@ struct AAKernelInfoFunction : AAKernelInfo {
};
// Add a dependence to ensure updates if the state changes.
- auto AddDependence = [](Attributor &A, const AAKernelInfo *KI,
+ auto AddDependence = [](Attributor &A, const AAKernelInfoImpl *KI,
const AbstractAttribute *QueryingAA) {
if (QueryingAA) {
A.recordDependence(*KI, *QueryingAA, DepClassTy::OPTIONAL);
@@ -4119,10 +3932,10 @@ struct AAKernelInfoFunction : AAKernelInfo {
for (Instruction *GuardedI : SPMDCompatibilityTracker) {
BasicBlock *BB = GuardedI->getParent();
- auto *CalleeAA = A.lookupAAFor<AAKernelInfo>(
+ auto *CalleeAA = A.lookupAAFor<AAKernelInfoImpl>(
IRPosition::function(*GuardedI->getFunction()), nullptr,
DepClassTy::NONE);
- assert(CalleeAA != nullptr && "Expected Callee AAKernelInfo");
+ assert(CalleeAA != nullptr && "Expected Callee AAKernelInfoImpl");
auto &CalleeAAFunction = *cast<AAKernelInfoFunction>(CalleeAA);
// Continue if instruction is already guarded.
if (CalleeAAFunction.getGuardedInstructions().contains(GuardedI))
@@ -4724,7 +4537,7 @@ struct AAKernelInfoFunction : AAKernelInfo {
// we cannot fix the internal spmd-zation state either.
int SPMD = 0, Generic = 0;
for (auto *Kernel : ReachingKernelEntries) {
- auto *CBAA = A.getAAFor<AAKernelInfo>(
+ auto *CBAA = A.getAAFor<AAKernelInfoImpl>(
*this, IRPosition::function(*Kernel), DepClassTy::OPTIONAL);
if (CBAA && CBAA->SPMDCompatibilityTracker.isValidState() &&
CBAA->SPMDCompatibilityTracker.isAssumed())
@@ -4745,7 +4558,7 @@ struct AAKernelInfoFunction : AAKernelInfo {
bool AllSPMDStatesWereFixed = true;
auto CheckCallInst = [&](Instruction &I) {
auto &CB = cast<CallBase>(I);
- auto *CBAA = A.getAAFor<AAKernelInfo>(
+ auto *CBAA = A.getAAFor<AAKernelInfoImpl>(
*this, IRPosition::callsite_function(CB), DepClassTy::OPTIONAL);
if (!CBAA)
return false;
@@ -4794,7 +4607,7 @@ struct AAKernelInfoFunction : AAKernelInfo {
assert(Caller && "Caller is nullptr");
- auto *CAA = A.getOrCreateAAFor<AAKernelInfo>(
+ auto *CAA = A.getOrCreateAAFor<AAKernelInfoImpl>(
IRPosition::function(*Caller), this, DepClassTy::REQUIRED);
if (CAA && CAA->ReachingKernelEntries.isValidState()) {
ReachingKernelEntries ^= CAA->ReachingKernelEntries;
@@ -4826,7 +4639,7 @@ struct AAKernelInfoFunction : AAKernelInfo {
assert(Caller && "Caller is nullptr");
auto *CAA =
- A.getOrCreateAAFor<AAKernelInfo>(IRPosition::function(*Caller));
+ A.getOrCreateAAFor<AAKernelInfoImpl>(IRPosition::function(*Caller));
if (CAA && CAA->ParallelLevels.isValidState()) {
// Any function that is called by `__kmpc_parallel_51` will not be
// folded as the parallel level in the function is updated. In order to
@@ -4861,13 +4674,13 @@ struct AAKernelInfoFunction : AAKernelInfo {
/// The call site kernel info abstract attribute, basically, what can we say
/// about a call site with regards to the KernelInfoState. For now this simply
/// forwards the information from the callee.
-struct AAKernelInfoCallSite : AAKernelInfo {
+struct AAKernelInfoCallSite : AAKernelInfoImpl {
AAKernelInfoCallSite(const IRPosition &IRP, Attributor &A)
- : AAKernelInfo(IRP, A) {}
+ : AAKernelInfoImpl(IRP, A) {}
/// See AbstractAttribute::initialize(...).
void initialize(Attributor &A) override {
- AAKernelInfo::initialize(A);
+ AAKernelInfoImpl::initialize(A);
CallBase &CB = cast<CallBase>(getAssociatedValue());
auto *AssumptionAA = A.getAAFor<AAAssumptionInfo>(
@@ -4889,8 +4702,9 @@ struct AAKernelInfoCallSite : AAKernelInfo {
// Next we check if we know the callee. If it is a known OpenMP function
// we will handle them explicitly in the switch below. If it is not, we
- // will use an AAKernelInfo object on the callee to gather information and
- // merge that into the current state. The latter happens in the updateImpl.
+ // will use an AAKernelInfoImpl object on the callee to gather information
+ // and merge that into the current state. The latter happens in the
+ // updateImpl.
auto CheckCallee = [&](Function *Callee, unsigned NumCallees) {
auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee);
@@ -5054,12 +4868,12 @@ struct AAKernelInfoCallSite : AAKernelInfo {
auto CheckCallee = [&](Function *F, int NumCallees) {
const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(F);
- // If F is not a runtime function, propagate the AAKernelInfo of the
+ // If F is not a runtime function, propagate the AAKernelInfoImpl of the
// callee.
if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {
const IRPosition &FnPos = IRPosition::function(*F);
auto *FnAA =
- A.getAAFor<AAKernelInfo>(*this, FnPos, DepClassTy::REQUIRED);
+ A.getAAFor<AAKernelInfoImpl>(*this, FnPos, DepClassTy::REQUIRED);
if (!FnAA)
return indicatePessimisticFixpoint();
if (getState() == FnAA->getState())
@@ -5148,7 +4962,7 @@ struct AAKernelInfoCallSite : AAKernelInfo {
ReachedKnownParallelRegions.insert(&CB);
/// Check nested parallelism
- auto *FnAA = A.getAAFor<AAKernelInfo>(
+ auto *FnAA = A.getAAFor<AAKernelInfoImpl>(
*this, IRPosition::function(*ParallelRegion), DepClassTy::OPTIONAL);
NestedParallelism |= !FnAA || !FnAA->getState().isValidState() ||
!FnAA->ReachedKnownParallelRegions.empty() ||
@@ -5305,7 +5119,7 @@ struct AAFoldRuntimeCallCallSiteReturned : AAFoldRuntimeCall {
unsigned AssumedSPMDCount = 0, KnownSPMDCount = 0;
unsigned AssumedNonSPMDCount = 0, KnownNonSPMDCount = 0;
- auto *CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
+ auto *CallerKernelInfoAA = A.getAAFor<AAKernelInfoImpl>(
*this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
if (!CallerKernelInfoAA ||
@@ -5313,8 +5127,8 @@ struct AAFoldRuntimeCallCallSiteReturned : AAFoldRuntimeCall {
return indicatePessimisticFixpoint();
for (Kernel K : CallerKernelInfoAA->ReachingKernelEntries) {
- auto *AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K),
- DepClassTy::REQUIRED);
+ auto *AA = A.getAAFor<AAKernelInfoImpl>(*this, IRPosition::function(*K),
+ DepClassTy::REQUIRED);
if (!AA || !AA->isValidState()) {
SimplifiedValue = nullptr;
@@ -5366,7 +5180,7 @@ struct AAFoldRuntimeCallCallSiteReturned : AAFoldRuntimeCall {
ChangeStatus foldParallelLevel(Attributor &A) {
std::optional<Value *> SimplifiedValueBefore = SimplifiedValue;
- auto *CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
+ auto *CallerKernelInfoAA = A.getAAFor<AAKernelInfoImpl>(
*this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
if (!CallerKernelInfoAA ||
@@ -5385,8 +5199,8 @@ struct AAFoldRuntimeCallCallSiteReturned : AAFoldRuntimeCall {
unsigned AssumedSPMDCount = 0, KnownSPMDCount = 0;
unsigned AssumedNonSPMDCount = 0, KnownNonSPMDCount = 0;
for (Kernel K : CallerKernelInfoAA->ReachingKernelEntries) {
- auto *AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K),
- DepClassTy::REQUIRED);
+ auto *AA = A.getAAFor<AAKernelInfoImpl>(*this, IRPosition::function(*K),
+ DepClassTy::REQUIRED);
if (!AA || !AA->SPMDCompatibilityTracker.isValidState())
return indicatePessimisticFixpoint();
@@ -5429,7 +5243,7 @@ struct AAFoldRuntimeCallCallSiteReturned : AAFoldRuntimeCall {
int32_t CurrentAttrValue = -1;
std::optional<Value *> SimplifiedValueBefore = SimplifiedValue;
- auto *CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
+ auto *CallerKernelInfoAA = A.getAAFor<AAKernelInfoImpl>(
*this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
if (!CallerKernelInfoAA ||
@@ -5486,12 +5300,12 @@ void OpenMPOpt::registerAAs(bool IsModulePass) {
return;
if (IsModulePass) {
- // Ensure we create the AAKernelInfo AAs first and without triggering an
+ // Ensure we create the AAKernelInfoImpl AAs first and without triggering an
// update. This will make sure we register all value simplification
// callbacks before any other AA has the chance to create an AAValueSimplify
// or similar.
auto CreateKernelInfoCB = [&](Use &, Function &Kernel) {
- A.getOrCreateAAFor<AAKernelInfo>(
+ A.getOrCreateAAFor<AAKernelInfoImpl>(
IRPosition::function(Kernel), /* QueryingAA */ nullptr,
DepClassTy::NONE, /* ForceUpdate */ false,
/* UpdateAfterInit */ false);
@@ -5594,7 +5408,7 @@ void OpenMPOpt::registerAAsForFunction(Attributor &A, const Function &F) {
}
const char AAICVTracker::ID = 0;
-const char AAKernelInfo::ID = 0;
+const char AAKernelInfoImpl::ID = 0;
const char AAExecutionDomain::ID = 0;
const char AAHeapToShared::ID = 0;
const char AAFoldRuntimeCall::ID = 0;
@@ -5667,9 +5481,9 @@ AAHeapToShared &AAHeapToShared::createForPosition(const IRPosition &IRP,
return *AA;
}
-AAKernelInfo &AAKernelInfo::createForPosition(const IRPosition &IRP,
- Attributor &A) {
- AAKernelInfo *AA = nullptr;
+AAKernelInfoImpl &AAKernelInfoImpl::createForPosition(const IRPosition &IRP,
+ Attributor &A) {
+ AAKernelInfoImpl *AA = nullptr;
switch (IRP.getPositionKind()) {
case IRPosition::IRP_INVALID:
case IRPosition::IRP_FLOAT:
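For readers skimming the patch, here is a minimal sketch (not part of the change itself) of the clamp semantics of the moved BooleanStateWithSetVector: operator^= unions the element sets while merging the underlying BooleanState, and insert() pessimizes the state whenever InsertInvalidates is true. It assumes a tree with this patch applied and LLVM's headers on the include path.

  #include "llvm/Transforms/IPO/OpenMPOpt.h"
  #include <cassert>
  using namespace llvm;

  void clampSketch() {
    // InsertInvalidates defaults to true, so insert() pessimizes the state.
    BooleanStateWithSetVector<unsigned> A, B;
    A.insert(1);
    B.insert(2);
    // "Clamp" A with B: the element sets are unioned and the boolean states
    // are merged via BooleanState::operator^=.
    A ^= B;
    assert(A.contains(1) && A.contains(2) && A.size() == 2);
  }

Note that operator^= returns *this by reference, matching the clamp convention of the other Attributor states.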
>From f8ae5184867db8cf2776165c0697bd4fd33c1c31 Mon Sep 17 00:00:00 2001
From: Nafis Mustakin <nmust004 at ucr.edu>
Date: Mon, 13 Nov 2023 09:49:17 -0800
Subject: [PATCH 2/4] Move functions to AAKernelInfo
---
llvm/include/llvm/Transforms/IPO/OpenMPOpt.h | 18 +++++++
llvm/lib/Transforms/IPO/OpenMPOpt.cpp | 56 +++++++-------------
2 files changed, 37 insertions(+), 37 deletions(-)
diff --git a/llvm/include/llvm/Transforms/IPO/OpenMPOpt.h b/llvm/include/llvm/Transforms/IPO/OpenMPOpt.h
index 049c1458bb9d63..04d40ab00439e1 100644
--- a/llvm/include/llvm/Transforms/IPO/OpenMPOpt.h
+++ b/llvm/include/llvm/Transforms/IPO/OpenMPOpt.h
@@ -262,6 +262,24 @@ struct AAKernelInfo : public StateWrapper<KernelInfoState, AbstractAttribute> {
/// Public getter for ReachingKernelEntries
virtual BooleanStateWithPtrSetVector<Function, false>
getReachingKernels() = 0;
+
+ /// Create an abstract attribute view for the position \p IRP.
+ static AAKernelInfo &createForPosition(const IRPosition &IRP,
+ Attributor &A);
+
+ /// This function should return true if the type of the \p AA is AAKernelInfo
+ static bool classof(const AbstractAttribute *AA) {
+ return (AA->getIdAddr() == &ID);
+ }
+
+ /// See AbstractAttribute::getName()
+ const std::string getName() const override { return "AAKernelInfo"; }
+
+ /// See AbstractAttribute::getIdAddr()
+ const char *getIdAddr() const override { return &ID; }
+
+ static const char ID;
+
};
} // end namespace llvm
diff --git a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
index 4dd32f236c03f3..1cf89d63c4e998 100644
--- a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
+++ b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
@@ -3440,24 +3440,6 @@ struct AAKernelInfoImpl : AAKernelInfo {
", NestedPar: " + (NestedParallelism ? "yes" : "no");
}
- /// Create an abstract attribute view for the position \p IRP.
- static AAKernelInfoImpl &createForPosition(const IRPosition &IRP,
- Attributor &A);
-
- /// See AbstractAttribute::getName()
- const std::string getName() const override { return "AAKernelInfoImpl"; }
-
- /// See AbstractAttribute::getIdAddr()
- const char *getIdAddr() const override { return &ID; }
-
- /// This function should return true if the type of the \p AA is
- /// AAKernelInfoImpl
- static bool classof(const AbstractAttribute *AA) {
- return (AA->getIdAddr() == &ID);
- }
-
- static const char ID;
-
/// Return the ReachingKernelEntries
BooleanStateWithPtrSetVector<Function, false> getReachingKernels() override {
return ReachingKernelEntries;
@@ -3932,10 +3914,10 @@ struct AAKernelInfoFunction : AAKernelInfoImpl {
for (Instruction *GuardedI : SPMDCompatibilityTracker) {
BasicBlock *BB = GuardedI->getParent();
- auto *CalleeAA = A.lookupAAFor<AAKernelInfoImpl>(
+ auto *CalleeAA = A.lookupAAFor<AAKernelInfo>(
IRPosition::function(*GuardedI->getFunction()), nullptr,
DepClassTy::NONE);
- assert(CalleeAA != nullptr && "Expected Callee AAKernelInfoImpl");
+ assert(CalleeAA != nullptr && "Expected Callee AAKernelInfo");
auto &CalleeAAFunction = *cast<AAKernelInfoFunction>(CalleeAA);
// Continue if instruction is already guarded.
if (CalleeAAFunction.getGuardedInstructions().contains(GuardedI))
@@ -4537,7 +4519,7 @@ struct AAKernelInfoFunction : AAKernelInfoImpl {
// we cannot fix the internal spmd-zation state either.
int SPMD = 0, Generic = 0;
for (auto *Kernel : ReachingKernelEntries) {
- auto *CBAA = A.getAAFor<AAKernelInfoImpl>(
+ auto *CBAA = A.getAAFor<AAKernelInfo>(
*this, IRPosition::function(*Kernel), DepClassTy::OPTIONAL);
if (CBAA && CBAA->SPMDCompatibilityTracker.isValidState() &&
CBAA->SPMDCompatibilityTracker.isAssumed())
@@ -4558,7 +4540,7 @@ struct AAKernelInfoFunction : AAKernelInfoImpl {
bool AllSPMDStatesWereFixed = true;
auto CheckCallInst = [&](Instruction &I) {
auto &CB = cast<CallBase>(I);
- auto *CBAA = A.getAAFor<AAKernelInfoImpl>(
+ auto *CBAA = A.getAAFor<AAKernelInfo>(
*this, IRPosition::callsite_function(CB), DepClassTy::OPTIONAL);
if (!CBAA)
return false;
@@ -4607,7 +4589,7 @@ struct AAKernelInfoFunction : AAKernelInfoImpl {
assert(Caller && "Caller is nullptr");
- auto *CAA = A.getOrCreateAAFor<AAKernelInfoImpl>(
+ auto *CAA = A.getOrCreateAAFor<AAKernelInfo>(
IRPosition::function(*Caller), this, DepClassTy::REQUIRED);
if (CAA && CAA->ReachingKernelEntries.isValidState()) {
ReachingKernelEntries ^= CAA->ReachingKernelEntries;
@@ -4639,7 +4621,7 @@ struct AAKernelInfoFunction : AAKernelInfoImpl {
assert(Caller && "Caller is nullptr");
auto *CAA =
- A.getOrCreateAAFor<AAKernelInfoImpl>(IRPosition::function(*Caller));
+ A.getOrCreateAAFor<AAKernelInfo>(IRPosition::function(*Caller));
if (CAA && CAA->ParallelLevels.isValidState()) {
// Any function that is called by `__kmpc_parallel_51` will not be
// folded as the parallel level in the function is updated. In order to
@@ -4868,12 +4850,12 @@ struct AAKernelInfoCallSite : AAKernelInfoImpl {
auto CheckCallee = [&](Function *F, int NumCallees) {
const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(F);
- // If F is not a runtime function, propagate the AAKernelInfoImpl of the
+ // If F is not a runtime function, propagate the AAKernelInfo of the
// callee.
if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {
const IRPosition &FnPos = IRPosition::function(*F);
auto *FnAA =
- A.getAAFor<AAKernelInfoImpl>(*this, FnPos, DepClassTy::REQUIRED);
+ A.getAAFor<AAKernelInfo>(*this, FnPos, DepClassTy::REQUIRED);
if (!FnAA)
return indicatePessimisticFixpoint();
if (getState() == FnAA->getState())
@@ -4962,7 +4944,7 @@ struct AAKernelInfoCallSite : AAKernelInfoImpl {
ReachedKnownParallelRegions.insert(&CB);
/// Check nested parallelism
- auto *FnAA = A.getAAFor<AAKernelInfoImpl>(
+ auto *FnAA = A.getAAFor<AAKernelInfo>(
*this, IRPosition::function(*ParallelRegion), DepClassTy::OPTIONAL);
NestedParallelism |= !FnAA || !FnAA->getState().isValidState() ||
!FnAA->ReachedKnownParallelRegions.empty() ||
@@ -5119,7 +5101,7 @@ struct AAFoldRuntimeCallCallSiteReturned : AAFoldRuntimeCall {
unsigned AssumedSPMDCount = 0, KnownSPMDCount = 0;
unsigned AssumedNonSPMDCount = 0, KnownNonSPMDCount = 0;
- auto *CallerKernelInfoAA = A.getAAFor<AAKernelInfoImpl>(
+ auto *CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
*this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
if (!CallerKernelInfoAA ||
@@ -5127,7 +5109,7 @@ struct AAFoldRuntimeCallCallSiteReturned : AAFoldRuntimeCall {
return indicatePessimisticFixpoint();
for (Kernel K : CallerKernelInfoAA->ReachingKernelEntries) {
- auto *AA = A.getAAFor<AAKernelInfoImpl>(*this, IRPosition::function(*K),
+ auto *AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K),
DepClassTy::REQUIRED);
if (!AA || !AA->isValidState()) {
@@ -5180,7 +5162,7 @@ struct AAFoldRuntimeCallCallSiteReturned : AAFoldRuntimeCall {
ChangeStatus foldParallelLevel(Attributor &A) {
std::optional<Value *> SimplifiedValueBefore = SimplifiedValue;
- auto *CallerKernelInfoAA = A.getAAFor<AAKernelInfoImpl>(
+ auto *CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
*this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
if (!CallerKernelInfoAA ||
@@ -5199,7 +5181,7 @@ struct AAFoldRuntimeCallCallSiteReturned : AAFoldRuntimeCall {
unsigned AssumedSPMDCount = 0, KnownSPMDCount = 0;
unsigned AssumedNonSPMDCount = 0, KnownNonSPMDCount = 0;
for (Kernel K : CallerKernelInfoAA->ReachingKernelEntries) {
- auto *AA = A.getAAFor<AAKernelInfoImpl>(*this, IRPosition::function(*K),
+ auto *AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K),
DepClassTy::REQUIRED);
if (!AA || !AA->SPMDCompatibilityTracker.isValidState())
return indicatePessimisticFixpoint();
@@ -5243,7 +5225,7 @@ struct AAFoldRuntimeCallCallSiteReturned : AAFoldRuntimeCall {
int32_t CurrentAttrValue = -1;
std::optional<Value *> SimplifiedValueBefore = SimplifiedValue;
- auto *CallerKernelInfoAA = A.getAAFor<AAKernelInfoImpl>(
+ auto *CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
*this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
if (!CallerKernelInfoAA ||
@@ -5300,12 +5282,12 @@ void OpenMPOpt::registerAAs(bool IsModulePass) {
return;
if (IsModulePass) {
- // Ensure we create the AAKernelInfoImpl AAs first and without triggering an
+ // Ensure we create the AAKernelInfo AAs first and without triggering an
// update. This will make sure we register all value simplification
// callbacks before any other AA has the chance to create an AAValueSimplify
// or similar.
auto CreateKernelInfoCB = [&](Use &, Function &Kernel) {
- A.getOrCreateAAFor<AAKernelInfoImpl>(
+ A.getOrCreateAAFor<AAKernelInfo>(
IRPosition::function(Kernel), /* QueryingAA */ nullptr,
DepClassTy::NONE, /* ForceUpdate */ false,
/* UpdateAfterInit */ false);
@@ -5408,7 +5390,7 @@ void OpenMPOpt::registerAAsForFunction(Attributor &A, const Function &F) {
}
const char AAICVTracker::ID = 0;
-const char AAKernelInfoImpl::ID = 0;
+const char AAKernelInfo::ID = 0;
const char AAExecutionDomain::ID = 0;
const char AAHeapToShared::ID = 0;
const char AAFoldRuntimeCall::ID = 0;
@@ -5481,9 +5463,9 @@ AAHeapToShared &AAHeapToShared::createForPosition(const IRPosition &IRP,
return *AA;
}
-AAKernelInfoImpl &AAKernelInfoImpl::createForPosition(const IRPosition &IRP,
+AAKernelInfo &AAKernelInfo::createForPosition(const IRPosition &IRP,
Attributor &A) {
- AAKernelInfoImpl *AA = nullptr;
+ AAKernelInfo *AA = nullptr;
switch (IRP.getPositionKind()) {
case IRPosition::IRP_INVALID:
case IRPosition::IRP_FLOAT:
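Since this second patch moves createForPosition, classof, getName, getIdAddr, and the ID tag onto AAKernelInfo in the header, passes outside OpenMPOpt.cpp can now query the attribute through an Attributor. A minimal sketch (not part of the patch), assuming an already-seeded Attributor A and a kernel function F; it reads the public state the same way the pass does internally:

  #include "llvm/Transforms/IPO/OpenMPOpt.h"
  using namespace llvm;

  bool isAssumedSPMDKernel(Attributor &A, Function &F) {
    const auto *KI =
        A.getOrCreateAAFor<AAKernelInfo>(IRPosition::function(F));
    if (!KI || !KI->SPMDCompatibilityTracker.isValidState())
      return false;
    // Same check the pass uses when counting SPMD vs. generic kernels.
    return KI->SPMDCompatibilityTracker.isAssumed();
  }

The new virtual getReachingKernels() likewise exposes ReachingKernelEntries to external callers, though as declared in this revision it is non-const and returns the set by value.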
>From e805aaf48fddaed8ce7832a5960b6f00e394faa9 Mon Sep 17 00:00:00 2001
From: Nafis Mustakin <nmust004 at ucr.edu>
Date: Mon, 13 Nov 2023 09:54:49 -0800
Subject: [PATCH 3/4] Fix formatting issues
---
llvm/include/llvm/Transforms/IPO/OpenMPOpt.h | 6 ++----
llvm/lib/Transforms/IPO/OpenMPOpt.cpp | 6 +++---
2 files changed, 5 insertions(+), 7 deletions(-)
diff --git a/llvm/include/llvm/Transforms/IPO/OpenMPOpt.h b/llvm/include/llvm/Transforms/IPO/OpenMPOpt.h
index 04d40ab00439e1..5bef03c0a1fce3 100644
--- a/llvm/include/llvm/Transforms/IPO/OpenMPOpt.h
+++ b/llvm/include/llvm/Transforms/IPO/OpenMPOpt.h
@@ -262,10 +262,9 @@ struct AAKernelInfo : public StateWrapper<KernelInfoState, AbstractAttribute> {
/// Public getter for ReachingKernelEntries
virtual BooleanStateWithPtrSetVector<Function, false>
getReachingKernels() = 0;
-
+
/// Create an abstract attribute view for the position \p IRP.
- static AAKernelInfo &createForPosition(const IRPosition &IRP,
- Attributor &A);
+ static AAKernelInfo &createForPosition(const IRPosition &IRP, Attributor &A);
/// This function should return true if the type of the \p AA is AAKernelInfo
static bool classof(const AbstractAttribute *AA) {
@@ -279,7 +278,6 @@ struct AAKernelInfo : public StateWrapper<KernelInfoState, AbstractAttribute> {
const char *getIdAddr() const override { return &ID; }
static const char ID;
-
};
} // end namespace llvm
diff --git a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
index 1cf89d63c4e998..0494b2d717711c 100644
--- a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
+++ b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
@@ -5110,7 +5110,7 @@ struct AAFoldRuntimeCallCallSiteReturned : AAFoldRuntimeCall {
for (Kernel K : CallerKernelInfoAA->ReachingKernelEntries) {
auto *AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K),
- DepClassTy::REQUIRED);
+ DepClassTy::REQUIRED);
if (!AA || !AA->isValidState()) {
SimplifiedValue = nullptr;
@@ -5182,7 +5182,7 @@ struct AAFoldRuntimeCallCallSiteReturned : AAFoldRuntimeCall {
unsigned AssumedNonSPMDCount = 0, KnownNonSPMDCount = 0;
for (Kernel K : CallerKernelInfoAA->ReachingKernelEntries) {
auto *AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K),
- DepClassTy::REQUIRED);
+ DepClassTy::REQUIRED);
if (!AA || !AA->SPMDCompatibilityTracker.isValidState())
return indicatePessimisticFixpoint();
@@ -5464,7 +5464,7 @@ AAHeapToShared &AAHeapToShared::createForPosition(const IRPosition &IRP,
}
AAKernelInfo &AAKernelInfo::createForPosition(const IRPosition &IRP,
- Attributor &A) {
+ Attributor &A) {
AAKernelInfo *AA = nullptr;
switch (IRP.getPositionKind()) {
case IRPosition::IRP_INVALID:
>From 7611103ea9ad54baa217d1fa64e6e7f8ced80e91 Mon Sep 17 00:00:00 2001
From: Johannes Doerfert <johannesdoerfert at gmail.com>
Date: Mon, 11 Dec 2023 14:15:59 -0800
Subject: [PATCH 4/4] Update OpenMPOpt.cpp
---
llvm/lib/Transforms/IPO/OpenMPOpt.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
index 0494b2d717711c..1e404e34942cd8 100644
--- a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
+++ b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
@@ -3610,7 +3610,7 @@ struct AAKernelInfoFunction : AAKernelInfoImpl {
};
// Add a dependence to ensure updates if the state changes.
- auto AddDependence = [](Attributor &A, const AAKernelInfoImpl *KI,
+ auto AddDependence = [](Attributor &A, const AAKernelInfo *KI,
const AbstractAttribute *QueryingAA) {
if (QueryingAA) {
A.recordDependence(*KI, *QueryingAA, DepClassTy::OPTIONAL);