[llvm] [OpenMP] Move KernelInfoState and AAKernelInfo to OpenMPOpt.h (PR #71878)

via llvm-commits llvm-commits at lists.llvm.org
Thu Nov 9 15:46:30 PST 2023


https://github.com/nmustakin created https://github.com/llvm/llvm-project/pull/71878

Barriers are currently modeled as accessing any and all memory, and that memory effect is propagated to the kernel. To refine this behavior, we need to accumulate the memory effects of all reaching kernels into the memory effect of the barrier. To make the `ReachingKernelEntries` of `AAKernelInfo` available outside of `OpenMPOpt.cpp`, the `KernelInfoState` and `AAKernelInfo` definitions have been moved to `llvm/include/llvm/Transforms/IPO/OpenMPOpt.h`, and the implementation that was previously named `AAKernelInfo` in `llvm/lib/Transforms/IPO/OpenMPOpt.cpp` has been renamed to `struct AAKernelInfoImpl : AAKernelInfo`.

>From ce96bdb47d3274f3c11a20b6572c89bf4f467c48 Mon Sep 17 00:00:00 2001
From: Nafis Mustakin <nmust004 at ucr.edu>
Date: Thu, 9 Nov 2023 15:28:57 -0800
Subject: [PATCH] Move KernelInfoState and AAKernelInfo to
 llvm/include/llvm/Transforms/IPO/OpenMPOpt.h

---
 llvm/include/llvm/Transforms/IPO/OpenMPOpt.h | 202 ++++++++++++++
 llvm/lib/Transforms/IPO/OpenMPOpt.cpp        | 274 +++----------------
 2 files changed, 246 insertions(+), 230 deletions(-)

diff --git a/llvm/include/llvm/Transforms/IPO/OpenMPOpt.h b/llvm/include/llvm/Transforms/IPO/OpenMPOpt.h
index 2499c2bbccf4554..049c1458bb9d630 100644
--- a/llvm/include/llvm/Transforms/IPO/OpenMPOpt.h
+++ b/llvm/include/llvm/Transforms/IPO/OpenMPOpt.h
@@ -12,6 +12,7 @@
 #include "llvm/Analysis/CGSCCPassManager.h"
 #include "llvm/Analysis/LazyCallGraph.h"
 #include "llvm/IR/PassManager.h"
+#include "llvm/Transforms/IPO/Attributor.h"
 
 namespace llvm {
 
@@ -62,6 +63,207 @@ class OpenMPOptCGSCCPass : public PassInfoMixin<OpenMPOptCGSCCPass> {
   const ThinOrFullLTOPhase LTOPhase = ThinOrFullLTOPhase::None;
 };
 
+template <typename Ty, bool InsertInvalidates = true>
+struct BooleanStateWithSetVector : public BooleanState {
+  bool contains(const Ty &Elem) const { return Set.contains(Elem); }
+  bool insert(const Ty &Elem) {
+    if (InsertInvalidates)
+      BooleanState::indicatePessimisticFixpoint();
+    return Set.insert(Elem);
+  }
+
+  const Ty &operator[](int Idx) const { return Set[Idx]; }
+  bool operator==(const BooleanStateWithSetVector &RHS) const {
+    return BooleanState::operator==(RHS) && Set == RHS.Set;
+  }
+  bool operator!=(const BooleanStateWithSetVector &RHS) const {
+    return !(*this == RHS);
+  }
+
+  bool empty() const { return Set.empty(); }
+  size_t size() const { return Set.size(); }
+
+  /// "Clamp" this state with \p RHS.
+  BooleanStateWithSetVector &operator^=(const BooleanStateWithSetVector &RHS) {
+    BooleanState::operator^=(RHS);
+    Set.insert(RHS.Set.begin(), RHS.Set.end());
+    return *this;
+  }
+
+private:
+  /// A set to keep track of elements.
+  SetVector<Ty> Set;
+
+public:
+  typename decltype(Set)::iterator begin() { return Set.begin(); }
+  typename decltype(Set)::iterator end() { return Set.end(); }
+  typename decltype(Set)::const_iterator begin() const { return Set.begin(); }
+  typename decltype(Set)::const_iterator end() const { return Set.end(); }
+};
+
+template <typename Ty, bool InsertInvalidates = true>
+using BooleanStateWithPtrSetVector =
+    BooleanStateWithSetVector<Ty *, InsertInvalidates>;
+
+struct KernelInfoState : AbstractState {
+  /// Flag to track if we reached a fixpoint.
+  bool IsAtFixpoint = false;
+
+  /// The parallel regions (identified by the outlined parallel functions) that
+  /// can be reached from the associated function.
+  BooleanStateWithPtrSetVector<CallBase, /* InsertInvalidates */ false>
+      ReachedKnownParallelRegions;
+
+  /// State to track what parallel region we might reach.
+  BooleanStateWithPtrSetVector<CallBase> ReachedUnknownParallelRegions;
+
+  /// State to track if we are in SPMD-mode, assumed or know, and why we decided
+  /// we cannot be. If it is assumed, then RequiresFullRuntime should also be
+  /// false.
+  BooleanStateWithPtrSetVector<Instruction, false> SPMDCompatibilityTracker;
+
+  /// The __kmpc_target_init call in this kernel, if any. If we find more than
+  /// one we abort as the kernel is malformed.
+  CallBase *KernelInitCB = nullptr;
+
+  /// The constant kernel environement as taken from and passed to
+  /// __kmpc_target_init.
+  ConstantStruct *KernelEnvC = nullptr;
+
+  /// The __kmpc_target_deinit call in this kernel, if any. If we find more than
+  /// one we abort as the kernel is malformed.
+  CallBase *KernelDeinitCB = nullptr;
+
+  /// Flag to indicate if the associated function is a kernel entry.
+  bool IsKernelEntry = false;
+
+  /// State to track what kernel entries can reach the associated function.
+  BooleanStateWithPtrSetVector<Function, false> ReachingKernelEntries;
+
+  /// State to indicate if we can track parallel level of the associated
+  /// function. We will give up tracking if we encounter unknown caller or the
+  /// caller is __kmpc_parallel_51.
+  BooleanStateWithSetVector<uint8_t> ParallelLevels;
+
+  /// Flag that indicates if the kernel has nested Parallelism
+  bool NestedParallelism = false;
+
+  /// Abstract State interface
+  ///{
+
+  KernelInfoState() = default;
+  KernelInfoState(bool BestState) {
+    if (!BestState)
+      indicatePessimisticFixpoint();
+  }
+
+  /// See AbstractState::isValidState(...)
+  bool isValidState() const override { return true; }
+
+  /// See AbstractState::isAtFixpoint(...)
+  bool isAtFixpoint() const override { return IsAtFixpoint; }
+
+  /// See AbstractState::indicatePessimisticFixpoint(...)
+  ChangeStatus indicatePessimisticFixpoint() override {
+    IsAtFixpoint = true;
+    ParallelLevels.indicatePessimisticFixpoint();
+    ReachingKernelEntries.indicatePessimisticFixpoint();
+    SPMDCompatibilityTracker.indicatePessimisticFixpoint();
+    ReachedKnownParallelRegions.indicatePessimisticFixpoint();
+    ReachedUnknownParallelRegions.indicatePessimisticFixpoint();
+    NestedParallelism = true;
+    return ChangeStatus::CHANGED;
+  }
+
+  /// See AbstractState::indicateOptimisticFixpoint(...)
+  ChangeStatus indicateOptimisticFixpoint() override {
+    IsAtFixpoint = true;
+    ParallelLevels.indicateOptimisticFixpoint();
+    ReachingKernelEntries.indicateOptimisticFixpoint();
+    SPMDCompatibilityTracker.indicateOptimisticFixpoint();
+    ReachedKnownParallelRegions.indicateOptimisticFixpoint();
+    ReachedUnknownParallelRegions.indicateOptimisticFixpoint();
+    return ChangeStatus::UNCHANGED;
+  }
+
+  /// Return the assumed state
+  KernelInfoState &getAssumed() { return *this; }
+  const KernelInfoState &getAssumed() const { return *this; }
+
+  bool operator==(const KernelInfoState &RHS) const {
+    if (SPMDCompatibilityTracker != RHS.SPMDCompatibilityTracker)
+      return false;
+    if (ReachedKnownParallelRegions != RHS.ReachedKnownParallelRegions)
+      return false;
+    if (ReachedUnknownParallelRegions != RHS.ReachedUnknownParallelRegions)
+      return false;
+    if (ReachingKernelEntries != RHS.ReachingKernelEntries)
+      return false;
+    if (ParallelLevels != RHS.ParallelLevels)
+      return false;
+    if (NestedParallelism != RHS.NestedParallelism)
+      return false;
+    return true;
+  }
+
+  /// Returns true if this kernel contains any OpenMP parallel regions.
+  bool mayContainParallelRegion() {
+    return !ReachedKnownParallelRegions.empty() ||
+           !ReachedUnknownParallelRegions.empty();
+  }
+
+  /// Return empty set as the best state of potential values.
+  static KernelInfoState getBestState() { return KernelInfoState(true); }
+
+  static KernelInfoState getBestState(KernelInfoState &KIS) {
+    return getBestState();
+  }
+
+  /// Return full set as the worst state of potential values.
+  static KernelInfoState getWorstState() { return KernelInfoState(false); }
+
+  /// "Clamp" this state with \p KIS.
+  KernelInfoState operator^=(const KernelInfoState &KIS) {
+    // Do not merge two different _init and _deinit call sites.
+    if (KIS.KernelInitCB) {
+      if (KernelInitCB && KernelInitCB != KIS.KernelInitCB)
+        llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt "
+                         "assumptions.");
+      KernelInitCB = KIS.KernelInitCB;
+    }
+    if (KIS.KernelDeinitCB) {
+      if (KernelDeinitCB && KernelDeinitCB != KIS.KernelDeinitCB)
+        llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt "
+                         "assumptions.");
+      KernelDeinitCB = KIS.KernelDeinitCB;
+    }
+    if (KIS.KernelEnvC) {
+      if (KernelEnvC && KernelEnvC != KIS.KernelEnvC)
+        llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt "
+                         "assumptions.");
+      KernelEnvC = KIS.KernelEnvC;
+    }
+    SPMDCompatibilityTracker ^= KIS.SPMDCompatibilityTracker;
+    ReachedKnownParallelRegions ^= KIS.ReachedKnownParallelRegions;
+    ReachedUnknownParallelRegions ^= KIS.ReachedUnknownParallelRegions;
+    NestedParallelism |= KIS.NestedParallelism;
+    return *this;
+  }
+
+  KernelInfoState operator&=(const KernelInfoState &KIS) {
+    return (*this ^= KIS);
+  }
+};
+
+struct AAKernelInfo : public StateWrapper<KernelInfoState, AbstractAttribute> {
+  using Base = StateWrapper<KernelInfoState, AbstractAttribute>;
+  AAKernelInfo(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
+
+  /// Public getter for ReachingKernelEntries
+  virtual BooleanStateWithPtrSetVector<Function, false>
+  getReachingKernels() = 0;
+};
+
 } // end namespace llvm
 
 #endif // LLVM_TRANSFORMS_IPO_OPENMPOPT_H
diff --git a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
index 5b42f215fb40ca0..4dd32f236c03f3d 100644
--- a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
+++ b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
@@ -638,200 +638,6 @@ struct OMPInformationCache : public InformationCache {
   bool OpenMPPostLink = false;
 };
 
-template <typename Ty, bool InsertInvalidates = true>
-struct BooleanStateWithSetVector : public BooleanState {
-  bool contains(const Ty &Elem) const { return Set.contains(Elem); }
-  bool insert(const Ty &Elem) {
-    if (InsertInvalidates)
-      BooleanState::indicatePessimisticFixpoint();
-    return Set.insert(Elem);
-  }
-
-  const Ty &operator[](int Idx) const { return Set[Idx]; }
-  bool operator==(const BooleanStateWithSetVector &RHS) const {
-    return BooleanState::operator==(RHS) && Set == RHS.Set;
-  }
-  bool operator!=(const BooleanStateWithSetVector &RHS) const {
-    return !(*this == RHS);
-  }
-
-  bool empty() const { return Set.empty(); }
-  size_t size() const { return Set.size(); }
-
-  /// "Clamp" this state with \p RHS.
-  BooleanStateWithSetVector &operator^=(const BooleanStateWithSetVector &RHS) {
-    BooleanState::operator^=(RHS);
-    Set.insert(RHS.Set.begin(), RHS.Set.end());
-    return *this;
-  }
-
-private:
-  /// A set to keep track of elements.
-  SetVector<Ty> Set;
-
-public:
-  typename decltype(Set)::iterator begin() { return Set.begin(); }
-  typename decltype(Set)::iterator end() { return Set.end(); }
-  typename decltype(Set)::const_iterator begin() const { return Set.begin(); }
-  typename decltype(Set)::const_iterator end() const { return Set.end(); }
-};
-
-template <typename Ty, bool InsertInvalidates = true>
-using BooleanStateWithPtrSetVector =
-    BooleanStateWithSetVector<Ty *, InsertInvalidates>;
-
-struct KernelInfoState : AbstractState {
-  /// Flag to track if we reached a fixpoint.
-  bool IsAtFixpoint = false;
-
-  /// The parallel regions (identified by the outlined parallel functions) that
-  /// can be reached from the associated function.
-  BooleanStateWithPtrSetVector<CallBase, /* InsertInvalidates */ false>
-      ReachedKnownParallelRegions;
-
-  /// State to track what parallel region we might reach.
-  BooleanStateWithPtrSetVector<CallBase> ReachedUnknownParallelRegions;
-
-  /// State to track if we are in SPMD-mode, assumed or know, and why we decided
-  /// we cannot be. If it is assumed, then RequiresFullRuntime should also be
-  /// false.
-  BooleanStateWithPtrSetVector<Instruction, false> SPMDCompatibilityTracker;
-
-  /// The __kmpc_target_init call in this kernel, if any. If we find more than
-  /// one we abort as the kernel is malformed.
-  CallBase *KernelInitCB = nullptr;
-
-  /// The constant kernel environement as taken from and passed to
-  /// __kmpc_target_init.
-  ConstantStruct *KernelEnvC = nullptr;
-
-  /// The __kmpc_target_deinit call in this kernel, if any. If we find more than
-  /// one we abort as the kernel is malformed.
-  CallBase *KernelDeinitCB = nullptr;
-
-  /// Flag to indicate if the associated function is a kernel entry.
-  bool IsKernelEntry = false;
-
-  /// State to track what kernel entries can reach the associated function.
-  BooleanStateWithPtrSetVector<Function, false> ReachingKernelEntries;
-
-  /// State to indicate if we can track parallel level of the associated
-  /// function. We will give up tracking if we encounter unknown caller or the
-  /// caller is __kmpc_parallel_51.
-  BooleanStateWithSetVector<uint8_t> ParallelLevels;
-
-  /// Flag that indicates if the kernel has nested Parallelism
-  bool NestedParallelism = false;
-
-  /// Abstract State interface
-  ///{
-
-  KernelInfoState() = default;
-  KernelInfoState(bool BestState) {
-    if (!BestState)
-      indicatePessimisticFixpoint();
-  }
-
-  /// See AbstractState::isValidState(...)
-  bool isValidState() const override { return true; }
-
-  /// See AbstractState::isAtFixpoint(...)
-  bool isAtFixpoint() const override { return IsAtFixpoint; }
-
-  /// See AbstractState::indicatePessimisticFixpoint(...)
-  ChangeStatus indicatePessimisticFixpoint() override {
-    IsAtFixpoint = true;
-    ParallelLevels.indicatePessimisticFixpoint();
-    ReachingKernelEntries.indicatePessimisticFixpoint();
-    SPMDCompatibilityTracker.indicatePessimisticFixpoint();
-    ReachedKnownParallelRegions.indicatePessimisticFixpoint();
-    ReachedUnknownParallelRegions.indicatePessimisticFixpoint();
-    NestedParallelism = true;
-    return ChangeStatus::CHANGED;
-  }
-
-  /// See AbstractState::indicateOptimisticFixpoint(...)
-  ChangeStatus indicateOptimisticFixpoint() override {
-    IsAtFixpoint = true;
-    ParallelLevels.indicateOptimisticFixpoint();
-    ReachingKernelEntries.indicateOptimisticFixpoint();
-    SPMDCompatibilityTracker.indicateOptimisticFixpoint();
-    ReachedKnownParallelRegions.indicateOptimisticFixpoint();
-    ReachedUnknownParallelRegions.indicateOptimisticFixpoint();
-    return ChangeStatus::UNCHANGED;
-  }
-
-  /// Return the assumed state
-  KernelInfoState &getAssumed() { return *this; }
-  const KernelInfoState &getAssumed() const { return *this; }
-
-  bool operator==(const KernelInfoState &RHS) const {
-    if (SPMDCompatibilityTracker != RHS.SPMDCompatibilityTracker)
-      return false;
-    if (ReachedKnownParallelRegions != RHS.ReachedKnownParallelRegions)
-      return false;
-    if (ReachedUnknownParallelRegions != RHS.ReachedUnknownParallelRegions)
-      return false;
-    if (ReachingKernelEntries != RHS.ReachingKernelEntries)
-      return false;
-    if (ParallelLevels != RHS.ParallelLevels)
-      return false;
-    if (NestedParallelism != RHS.NestedParallelism)
-      return false;
-    return true;
-  }
-
-  /// Returns true if this kernel contains any OpenMP parallel regions.
-  bool mayContainParallelRegion() {
-    return !ReachedKnownParallelRegions.empty() ||
-           !ReachedUnknownParallelRegions.empty();
-  }
-
-  /// Return empty set as the best state of potential values.
-  static KernelInfoState getBestState() { return KernelInfoState(true); }
-
-  static KernelInfoState getBestState(KernelInfoState &KIS) {
-    return getBestState();
-  }
-
-  /// Return full set as the worst state of potential values.
-  static KernelInfoState getWorstState() { return KernelInfoState(false); }
-
-  /// "Clamp" this state with \p KIS.
-  KernelInfoState operator^=(const KernelInfoState &KIS) {
-    // Do not merge two different _init and _deinit call sites.
-    if (KIS.KernelInitCB) {
-      if (KernelInitCB && KernelInitCB != KIS.KernelInitCB)
-        llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt "
-                         "assumptions.");
-      KernelInitCB = KIS.KernelInitCB;
-    }
-    if (KIS.KernelDeinitCB) {
-      if (KernelDeinitCB && KernelDeinitCB != KIS.KernelDeinitCB)
-        llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt "
-                         "assumptions.");
-      KernelDeinitCB = KIS.KernelDeinitCB;
-    }
-    if (KIS.KernelEnvC) {
-      if (KernelEnvC && KernelEnvC != KIS.KernelEnvC)
-        llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt "
-                         "assumptions.");
-      KernelEnvC = KIS.KernelEnvC;
-    }
-    SPMDCompatibilityTracker ^= KIS.SPMDCompatibilityTracker;
-    ReachedKnownParallelRegions ^= KIS.ReachedKnownParallelRegions;
-    ReachedUnknownParallelRegions ^= KIS.ReachedUnknownParallelRegions;
-    NestedParallelism |= KIS.NestedParallelism;
-    return *this;
-  }
-
-  KernelInfoState operator&=(const KernelInfoState &KIS) {
-    return (*this ^= KIS);
-  }
-
-  ///}
-};
-
 /// Used to map the values physically (in the IR) stored in an offload
 /// array, to a vector in memory.
 struct OffloadArray {
@@ -3596,9 +3402,9 @@ struct AAHeapToSharedFunction : public AAHeapToShared {
   unsigned SharedMemoryUsed = 0;
 };
 
-struct AAKernelInfo : public StateWrapper<KernelInfoState, AbstractAttribute> {
-  using Base = StateWrapper<KernelInfoState, AbstractAttribute>;
-  AAKernelInfo(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
+struct AAKernelInfoImpl : AAKernelInfo {
+  AAKernelInfoImpl(const IRPosition &IRP, Attributor &A)
+      : AAKernelInfo(IRP, A) {}
 
   /// The callee value is tracked beyond a simple stripPointerCasts, so we allow
   /// unknown callees.
@@ -3635,27 +3441,34 @@ struct AAKernelInfo : public StateWrapper<KernelInfoState, AbstractAttribute> {
   }
 
   /// Create an abstract attribute biew for the position \p IRP.
-  static AAKernelInfo &createForPosition(const IRPosition &IRP, Attributor &A);
+  static AAKernelInfoImpl &createForPosition(const IRPosition &IRP,
+                                             Attributor &A);
 
   /// See AbstractAttribute::getName()
-  const std::string getName() const override { return "AAKernelInfo"; }
+  const std::string getName() const override { return "AAKernelInfoImpl"; }
 
   /// See AbstractAttribute::getIdAddr()
   const char *getIdAddr() const override { return &ID; }
 
-  /// This function should return true if the type of the \p AA is AAKernelInfo
+  /// This function should return true if the type of the \p AA is
+  /// AAKernelInfoImpl
   static bool classof(const AbstractAttribute *AA) {
     return (AA->getIdAddr() == &ID);
   }
 
   static const char ID;
+
+  /// Return the ReachingKernelEntries
+  BooleanStateWithPtrSetVector<Function, false> getReachingKernels() override {
+    return ReachingKernelEntries;
+  }
 };
 
 /// The function kernel info abstract attribute, basically, what can we say
 /// about a function with regards to the KernelInfoState.
-struct AAKernelInfoFunction : AAKernelInfo {
+struct AAKernelInfoFunction : AAKernelInfoImpl {
   AAKernelInfoFunction(const IRPosition &IRP, Attributor &A)
-      : AAKernelInfo(IRP, A) {}
+      : AAKernelInfoImpl(IRP, A) {}
 
   SmallPtrSet<Instruction *, 4> GuardedInstructions;
 
@@ -3815,7 +3628,7 @@ struct AAKernelInfoFunction : AAKernelInfo {
     };
 
     // Add a dependence to ensure updates if the state changes.
-    auto AddDependence = [](Attributor &A, const AAKernelInfo *KI,
+    auto AddDependence = [](Attributor &A, const AAKernelInfoImpl *KI,
                             const AbstractAttribute *QueryingAA) {
       if (QueryingAA) {
         A.recordDependence(*KI, *QueryingAA, DepClassTy::OPTIONAL);
@@ -4119,10 +3932,10 @@ struct AAKernelInfoFunction : AAKernelInfo {
 
     for (Instruction *GuardedI : SPMDCompatibilityTracker) {
       BasicBlock *BB = GuardedI->getParent();
-      auto *CalleeAA = A.lookupAAFor<AAKernelInfo>(
+      auto *CalleeAA = A.lookupAAFor<AAKernelInfoImpl>(
           IRPosition::function(*GuardedI->getFunction()), nullptr,
           DepClassTy::NONE);
-      assert(CalleeAA != nullptr && "Expected Callee AAKernelInfo");
+      assert(CalleeAA != nullptr && "Expected Callee AAKernelInfoImpl");
       auto &CalleeAAFunction = *cast<AAKernelInfoFunction>(CalleeAA);
       // Continue if instruction is already guarded.
       if (CalleeAAFunction.getGuardedInstructions().contains(GuardedI))
@@ -4724,7 +4537,7 @@ struct AAKernelInfoFunction : AAKernelInfo {
           // we cannot fix the internal spmd-zation state either.
           int SPMD = 0, Generic = 0;
           for (auto *Kernel : ReachingKernelEntries) {
-            auto *CBAA = A.getAAFor<AAKernelInfo>(
+            auto *CBAA = A.getAAFor<AAKernelInfoImpl>(
                 *this, IRPosition::function(*Kernel), DepClassTy::OPTIONAL);
             if (CBAA && CBAA->SPMDCompatibilityTracker.isValidState() &&
                 CBAA->SPMDCompatibilityTracker.isAssumed())
@@ -4745,7 +4558,7 @@ struct AAKernelInfoFunction : AAKernelInfo {
     bool AllSPMDStatesWereFixed = true;
     auto CheckCallInst = [&](Instruction &I) {
       auto &CB = cast<CallBase>(I);
-      auto *CBAA = A.getAAFor<AAKernelInfo>(
+      auto *CBAA = A.getAAFor<AAKernelInfoImpl>(
           *this, IRPosition::callsite_function(CB), DepClassTy::OPTIONAL);
       if (!CBAA)
         return false;
@@ -4794,7 +4607,7 @@ struct AAKernelInfoFunction : AAKernelInfo {
 
       assert(Caller && "Caller is nullptr");
 
-      auto *CAA = A.getOrCreateAAFor<AAKernelInfo>(
+      auto *CAA = A.getOrCreateAAFor<AAKernelInfoImpl>(
           IRPosition::function(*Caller), this, DepClassTy::REQUIRED);
       if (CAA && CAA->ReachingKernelEntries.isValidState()) {
         ReachingKernelEntries ^= CAA->ReachingKernelEntries;
@@ -4826,7 +4639,7 @@ struct AAKernelInfoFunction : AAKernelInfo {
       assert(Caller && "Caller is nullptr");
 
       auto *CAA =
-          A.getOrCreateAAFor<AAKernelInfo>(IRPosition::function(*Caller));
+          A.getOrCreateAAFor<AAKernelInfoImpl>(IRPosition::function(*Caller));
       if (CAA && CAA->ParallelLevels.isValidState()) {
         // Any function that is called by `__kmpc_parallel_51` will not be
         // folded as the parallel level in the function is updated. In order to
@@ -4861,13 +4674,13 @@ struct AAKernelInfoFunction : AAKernelInfo {
 /// The call site kernel info abstract attribute, basically, what can we say
 /// about a call site with regards to the KernelInfoState. For now this simply
 /// forwards the information from the callee.
-struct AAKernelInfoCallSite : AAKernelInfo {
+struct AAKernelInfoCallSite : AAKernelInfoImpl {
   AAKernelInfoCallSite(const IRPosition &IRP, Attributor &A)
-      : AAKernelInfo(IRP, A) {}
+      : AAKernelInfoImpl(IRP, A) {}
 
   /// See AbstractAttribute::initialize(...).
   void initialize(Attributor &A) override {
-    AAKernelInfo::initialize(A);
+    AAKernelInfoImpl::initialize(A);
 
     CallBase &CB = cast<CallBase>(getAssociatedValue());
     auto *AssumptionAA = A.getAAFor<AAAssumptionInfo>(
@@ -4889,8 +4702,9 @@ struct AAKernelInfoCallSite : AAKernelInfo {
 
     // Next we check if we know the callee. If it is a known OpenMP function
     // we will handle them explicitly in the switch below. If it is not, we
-    // will use an AAKernelInfo object on the callee to gather information and
-    // merge that into the current state. The latter happens in the updateImpl.
+    // will use an AAKernelInfoImpl object on the callee to gather information
+    // and merge that into the current state. The latter happens in the
+    // updateImpl.
     auto CheckCallee = [&](Function *Callee, unsigned NumCallees) {
       auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
       const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee);
@@ -5054,12 +4868,12 @@ struct AAKernelInfoCallSite : AAKernelInfo {
     auto CheckCallee = [&](Function *F, int NumCallees) {
       const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(F);
 
-      // If F is not a runtime function, propagate the AAKernelInfo of the
+      // If F is not a runtime function, propagate the AAKernelInfoImpl of the
       // callee.
       if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {
         const IRPosition &FnPos = IRPosition::function(*F);
         auto *FnAA =
-            A.getAAFor<AAKernelInfo>(*this, FnPos, DepClassTy::REQUIRED);
+            A.getAAFor<AAKernelInfoImpl>(*this, FnPos, DepClassTy::REQUIRED);
         if (!FnAA)
           return indicatePessimisticFixpoint();
         if (getState() == FnAA->getState())
@@ -5148,7 +4962,7 @@ struct AAKernelInfoCallSite : AAKernelInfo {
 
     ReachedKnownParallelRegions.insert(&CB);
     /// Check nested parallelism
-    auto *FnAA = A.getAAFor<AAKernelInfo>(
+    auto *FnAA = A.getAAFor<AAKernelInfoImpl>(
         *this, IRPosition::function(*ParallelRegion), DepClassTy::OPTIONAL);
     NestedParallelism |= !FnAA || !FnAA->getState().isValidState() ||
                          !FnAA->ReachedKnownParallelRegions.empty() ||
@@ -5305,7 +5119,7 @@ struct AAFoldRuntimeCallCallSiteReturned : AAFoldRuntimeCall {
 
     unsigned AssumedSPMDCount = 0, KnownSPMDCount = 0;
     unsigned AssumedNonSPMDCount = 0, KnownNonSPMDCount = 0;
-    auto *CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
+    auto *CallerKernelInfoAA = A.getAAFor<AAKernelInfoImpl>(
         *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
 
     if (!CallerKernelInfoAA ||
@@ -5313,8 +5127,8 @@ struct AAFoldRuntimeCallCallSiteReturned : AAFoldRuntimeCall {
       return indicatePessimisticFixpoint();
 
     for (Kernel K : CallerKernelInfoAA->ReachingKernelEntries) {
-      auto *AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K),
-                                          DepClassTy::REQUIRED);
+      auto *AA = A.getAAFor<AAKernelInfoImpl>(*this, IRPosition::function(*K),
+                                              DepClassTy::REQUIRED);
 
       if (!AA || !AA->isValidState()) {
         SimplifiedValue = nullptr;
@@ -5366,7 +5180,7 @@ struct AAFoldRuntimeCallCallSiteReturned : AAFoldRuntimeCall {
   ChangeStatus foldParallelLevel(Attributor &A) {
     std::optional<Value *> SimplifiedValueBefore = SimplifiedValue;
 
-    auto *CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
+    auto *CallerKernelInfoAA = A.getAAFor<AAKernelInfoImpl>(
         *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
 
     if (!CallerKernelInfoAA ||
@@ -5385,8 +5199,8 @@ struct AAFoldRuntimeCallCallSiteReturned : AAFoldRuntimeCall {
     unsigned AssumedSPMDCount = 0, KnownSPMDCount = 0;
     unsigned AssumedNonSPMDCount = 0, KnownNonSPMDCount = 0;
     for (Kernel K : CallerKernelInfoAA->ReachingKernelEntries) {
-      auto *AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K),
-                                          DepClassTy::REQUIRED);
+      auto *AA = A.getAAFor<AAKernelInfoImpl>(*this, IRPosition::function(*K),
+                                              DepClassTy::REQUIRED);
       if (!AA || !AA->SPMDCompatibilityTracker.isValidState())
         return indicatePessimisticFixpoint();
 
@@ -5429,7 +5243,7 @@ struct AAFoldRuntimeCallCallSiteReturned : AAFoldRuntimeCall {
     int32_t CurrentAttrValue = -1;
     std::optional<Value *> SimplifiedValueBefore = SimplifiedValue;
 
-    auto *CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
+    auto *CallerKernelInfoAA = A.getAAFor<AAKernelInfoImpl>(
         *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
 
     if (!CallerKernelInfoAA ||
@@ -5486,12 +5300,12 @@ void OpenMPOpt::registerAAs(bool IsModulePass) {
     return;
 
   if (IsModulePass) {
-    // Ensure we create the AAKernelInfo AAs first and without triggering an
+    // Ensure we create the AAKernelInfoImpl AAs first and without triggering an
     // update. This will make sure we register all value simplification
     // callbacks before any other AA has the chance to create an AAValueSimplify
     // or similar.
     auto CreateKernelInfoCB = [&](Use &, Function &Kernel) {
-      A.getOrCreateAAFor<AAKernelInfo>(
+      A.getOrCreateAAFor<AAKernelInfoImpl>(
           IRPosition::function(Kernel), /* QueryingAA */ nullptr,
           DepClassTy::NONE, /* ForceUpdate */ false,
           /* UpdateAfterInit */ false);
@@ -5594,7 +5408,7 @@ void OpenMPOpt::registerAAsForFunction(Attributor &A, const Function &F) {
 }
 
 const char AAICVTracker::ID = 0;
-const char AAKernelInfo::ID = 0;
+const char AAKernelInfoImpl::ID = 0;
 const char AAExecutionDomain::ID = 0;
 const char AAHeapToShared::ID = 0;
 const char AAFoldRuntimeCall::ID = 0;
@@ -5667,9 +5481,9 @@ AAHeapToShared &AAHeapToShared::createForPosition(const IRPosition &IRP,
   return *AA;
 }
 
-AAKernelInfo &AAKernelInfo::createForPosition(const IRPosition &IRP,
-                                              Attributor &A) {
-  AAKernelInfo *AA = nullptr;
+AAKernelInfoImpl &AAKernelInfoImpl::createForPosition(const IRPosition &IRP,
+                                                      Attributor &A) {
+  AAKernelInfoImpl *AA = nullptr;
   switch (IRP.getPositionKind()) {
   case IRPosition::IRP_INVALID:
   case IRPosition::IRP_FLOAT:



More information about the llvm-commits mailing list