[llvm] [OpenMP] Move KernelInfoState and AAKernelInfo to OpenMPOpt.h (PR #71878)

Johannes Doerfert via llvm-commits llvm-commits at lists.llvm.org
Mon Dec 11 14:19:30 PST 2023


https://github.com/jdoerfert updated https://github.com/llvm/llvm-project/pull/71878

>From ce96bdb47d3274f3c11a20b6572c89bf4f467c48 Mon Sep 17 00:00:00 2001
From: Nafis Mustakin <nmust004 at ucr.edu>
Date: Thu, 9 Nov 2023 15:28:57 -0800
Subject: [PATCH 1/5] Move KernelInfoState and AAKernelInfo to
 llvm/include/llvm/Transforms/IPO/OpenMPOpt.h

---
 llvm/include/llvm/Transforms/IPO/OpenMPOpt.h | 202 ++++++++++++++
 llvm/lib/Transforms/IPO/OpenMPOpt.cpp        | 274 +++----------------
 2 files changed, 246 insertions(+), 230 deletions(-)

diff --git a/llvm/include/llvm/Transforms/IPO/OpenMPOpt.h b/llvm/include/llvm/Transforms/IPO/OpenMPOpt.h
index 2499c2bbccf455..049c1458bb9d63 100644
--- a/llvm/include/llvm/Transforms/IPO/OpenMPOpt.h
+++ b/llvm/include/llvm/Transforms/IPO/OpenMPOpt.h
@@ -12,6 +12,7 @@
 #include "llvm/Analysis/CGSCCPassManager.h"
 #include "llvm/Analysis/LazyCallGraph.h"
 #include "llvm/IR/PassManager.h"
+#include "llvm/Transforms/IPO/Attributor.h"
 
 namespace llvm {
 
@@ -62,6 +63,207 @@ class OpenMPOptCGSCCPass : public PassInfoMixin<OpenMPOptCGSCCPass> {
   const ThinOrFullLTOPhase LTOPhase = ThinOrFullLTOPhase::None;
 };
 
+template <typename Ty, bool InsertInvalidates = true>
+struct BooleanStateWithSetVector : public BooleanState {
+  bool contains(const Ty &Elem) const { return Set.contains(Elem); }
+  bool insert(const Ty &Elem) {
+    if (InsertInvalidates)
+      BooleanState::indicatePessimisticFixpoint();
+    return Set.insert(Elem);
+  }
+
+  const Ty &operator[](int Idx) const { return Set[Idx]; }
+  bool operator==(const BooleanStateWithSetVector &RHS) const {
+    return BooleanState::operator==(RHS) && Set == RHS.Set;
+  }
+  bool operator!=(const BooleanStateWithSetVector &RHS) const {
+    return !(*this == RHS);
+  }
+
+  bool empty() const { return Set.empty(); }
+  size_t size() const { return Set.size(); }
+
+  /// "Clamp" this state with \p RHS.
+  BooleanStateWithSetVector &operator^=(const BooleanStateWithSetVector &RHS) {
+    BooleanState::operator^=(RHS);
+    Set.insert(RHS.Set.begin(), RHS.Set.end());
+    return *this;
+  }
+
+private:
+  /// A set to keep track of elements.
+  SetVector<Ty> Set;
+
+public:
+  typename decltype(Set)::iterator begin() { return Set.begin(); }
+  typename decltype(Set)::iterator end() { return Set.end(); }
+  typename decltype(Set)::const_iterator begin() const { return Set.begin(); }
+  typename decltype(Set)::const_iterator end() const { return Set.end(); }
+};
+
+template <typename Ty, bool InsertInvalidates = true>
+using BooleanStateWithPtrSetVector =
+    BooleanStateWithSetVector<Ty *, InsertInvalidates>;
+
+struct KernelInfoState : AbstractState {
+  /// Flag to track if we reached a fixpoint.
+  bool IsAtFixpoint = false;
+
+  /// The parallel regions (identified by the outlined parallel functions) that
+  /// can be reached from the associated function.
+  BooleanStateWithPtrSetVector<CallBase, /* InsertInvalidates */ false>
+      ReachedKnownParallelRegions;
+
+  /// State to track what parallel region we might reach.
+  BooleanStateWithPtrSetVector<CallBase> ReachedUnknownParallelRegions;
+
+  /// State to track if we are in SPMD-mode, assumed or known, and why we
+  /// decided we cannot be. If it is assumed, then RequiresFullRuntime should
+  /// also be false.
+  BooleanStateWithPtrSetVector<Instruction, false> SPMDCompatibilityTracker;
+
+  /// The __kmpc_target_init call in this kernel, if any. If we find more than
+  /// one we abort as the kernel is malformed.
+  CallBase *KernelInitCB = nullptr;
+
+  /// The constant kernel environment as taken from and passed to
+  /// __kmpc_target_init.
+  ConstantStruct *KernelEnvC = nullptr;
+
+  /// The __kmpc_target_deinit call in this kernel, if any. If we find more than
+  /// one we abort as the kernel is malformed.
+  CallBase *KernelDeinitCB = nullptr;
+
+  /// Flag to indicate if the associated function is a kernel entry.
+  bool IsKernelEntry = false;
+
+  /// State to track what kernel entries can reach the associated function.
+  BooleanStateWithPtrSetVector<Function, false> ReachingKernelEntries;
+
+  /// State to indicate if we can track parallel level of the associated
+  /// function. We will give up tracking if we encounter unknown caller or the
+  /// caller is __kmpc_parallel_51.
+  BooleanStateWithSetVector<uint8_t> ParallelLevels;
+
+  /// Flag that indicates if the kernel has nested Parallelism
+  bool NestedParallelism = false;
+
+  /// Abstract State interface
+  ///{
+
+  KernelInfoState() = default;
+  KernelInfoState(bool BestState) {
+    if (!BestState)
+      indicatePessimisticFixpoint();
+  }
+
+  /// See AbstractState::isValidState(...)
+  bool isValidState() const override { return true; }
+
+  /// See AbstractState::isAtFixpoint(...)
+  bool isAtFixpoint() const override { return IsAtFixpoint; }
+
+  /// See AbstractState::indicatePessimisticFixpoint(...)
+  ChangeStatus indicatePessimisticFixpoint() override {
+    IsAtFixpoint = true;
+    ParallelLevels.indicatePessimisticFixpoint();
+    ReachingKernelEntries.indicatePessimisticFixpoint();
+    SPMDCompatibilityTracker.indicatePessimisticFixpoint();
+    ReachedKnownParallelRegions.indicatePessimisticFixpoint();
+    ReachedUnknownParallelRegions.indicatePessimisticFixpoint();
+    NestedParallelism = true;
+    return ChangeStatus::CHANGED;
+  }
+
+  /// See AbstractState::indicateOptimisticFixpoint(...)
+  ChangeStatus indicateOptimisticFixpoint() override {
+    IsAtFixpoint = true;
+    ParallelLevels.indicateOptimisticFixpoint();
+    ReachingKernelEntries.indicateOptimisticFixpoint();
+    SPMDCompatibilityTracker.indicateOptimisticFixpoint();
+    ReachedKnownParallelRegions.indicateOptimisticFixpoint();
+    ReachedUnknownParallelRegions.indicateOptimisticFixpoint();
+    return ChangeStatus::UNCHANGED;
+  }
+
+  /// Return the assumed state
+  KernelInfoState &getAssumed() { return *this; }
+  const KernelInfoState &getAssumed() const { return *this; }
+
+  bool operator==(const KernelInfoState &RHS) const {
+    if (SPMDCompatibilityTracker != RHS.SPMDCompatibilityTracker)
+      return false;
+    if (ReachedKnownParallelRegions != RHS.ReachedKnownParallelRegions)
+      return false;
+    if (ReachedUnknownParallelRegions != RHS.ReachedUnknownParallelRegions)
+      return false;
+    if (ReachingKernelEntries != RHS.ReachingKernelEntries)
+      return false;
+    if (ParallelLevels != RHS.ParallelLevels)
+      return false;
+    if (NestedParallelism != RHS.NestedParallelism)
+      return false;
+    return true;
+  }
+
+  /// Returns true if this kernel contains any OpenMP parallel regions.
+  bool mayContainParallelRegion() {
+    return !ReachedKnownParallelRegions.empty() ||
+           !ReachedUnknownParallelRegions.empty();
+  }
+
+  /// Return empty set as the best state of potential values.
+  static KernelInfoState getBestState() { return KernelInfoState(true); }
+
+  static KernelInfoState getBestState(KernelInfoState &KIS) {
+    return getBestState();
+  }
+
+  /// Return full set as the worst state of potential values.
+  static KernelInfoState getWorstState() { return KernelInfoState(false); }
+
+  /// "Clamp" this state with \p KIS.
+  KernelInfoState operator^=(const KernelInfoState &KIS) {
+    // Do not merge two different _init and _deinit call sites.
+    if (KIS.KernelInitCB) {
+      if (KernelInitCB && KernelInitCB != KIS.KernelInitCB)
+        llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt "
+                         "assumptions.");
+      KernelInitCB = KIS.KernelInitCB;
+    }
+    if (KIS.KernelDeinitCB) {
+      if (KernelDeinitCB && KernelDeinitCB != KIS.KernelDeinitCB)
+        llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt "
+                         "assumptions.");
+      KernelDeinitCB = KIS.KernelDeinitCB;
+    }
+    if (KIS.KernelEnvC) {
+      if (KernelEnvC && KernelEnvC != KIS.KernelEnvC)
+        llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt "
+                         "assumptions.");
+      KernelEnvC = KIS.KernelEnvC;
+    }
+    SPMDCompatibilityTracker ^= KIS.SPMDCompatibilityTracker;
+    ReachedKnownParallelRegions ^= KIS.ReachedKnownParallelRegions;
+    ReachedUnknownParallelRegions ^= KIS.ReachedUnknownParallelRegions;
+    NestedParallelism |= KIS.NestedParallelism;
+    return *this;
+  }
+
+  KernelInfoState operator&=(const KernelInfoState &KIS) {
+    return (*this ^= KIS);
+  }
+};
+
+struct AAKernelInfo : public StateWrapper<KernelInfoState, AbstractAttribute> {
+  using Base = StateWrapper<KernelInfoState, AbstractAttribute>;
+  AAKernelInfo(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
+
+  /// Public getter for ReachingKernelEntries
+  virtual BooleanStateWithPtrSetVector<Function, false>
+  getReachingKernels() = 0;
+};
+
 } // end namespace llvm
 
 #endif // LLVM_TRANSFORMS_IPO_OPENMPOPT_H
diff --git a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
index 5b42f215fb40ca..4dd32f236c03f3 100644
--- a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
+++ b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
@@ -638,200 +638,6 @@ struct OMPInformationCache : public InformationCache {
   bool OpenMPPostLink = false;
 };
 
-template <typename Ty, bool InsertInvalidates = true>
-struct BooleanStateWithSetVector : public BooleanState {
-  bool contains(const Ty &Elem) const { return Set.contains(Elem); }
-  bool insert(const Ty &Elem) {
-    if (InsertInvalidates)
-      BooleanState::indicatePessimisticFixpoint();
-    return Set.insert(Elem);
-  }
-
-  const Ty &operator[](int Idx) const { return Set[Idx]; }
-  bool operator==(const BooleanStateWithSetVector &RHS) const {
-    return BooleanState::operator==(RHS) && Set == RHS.Set;
-  }
-  bool operator!=(const BooleanStateWithSetVector &RHS) const {
-    return !(*this == RHS);
-  }
-
-  bool empty() const { return Set.empty(); }
-  size_t size() const { return Set.size(); }
-
-  /// "Clamp" this state with \p RHS.
-  BooleanStateWithSetVector &operator^=(const BooleanStateWithSetVector &RHS) {
-    BooleanState::operator^=(RHS);
-    Set.insert(RHS.Set.begin(), RHS.Set.end());
-    return *this;
-  }
-
-private:
-  /// A set to keep track of elements.
-  SetVector<Ty> Set;
-
-public:
-  typename decltype(Set)::iterator begin() { return Set.begin(); }
-  typename decltype(Set)::iterator end() { return Set.end(); }
-  typename decltype(Set)::const_iterator begin() const { return Set.begin(); }
-  typename decltype(Set)::const_iterator end() const { return Set.end(); }
-};
-
-template <typename Ty, bool InsertInvalidates = true>
-using BooleanStateWithPtrSetVector =
-    BooleanStateWithSetVector<Ty *, InsertInvalidates>;
-
-struct KernelInfoState : AbstractState {
-  /// Flag to track if we reached a fixpoint.
-  bool IsAtFixpoint = false;
-
-  /// The parallel regions (identified by the outlined parallel functions) that
-  /// can be reached from the associated function.
-  BooleanStateWithPtrSetVector<CallBase, /* InsertInvalidates */ false>
-      ReachedKnownParallelRegions;
-
-  /// State to track what parallel region we might reach.
-  BooleanStateWithPtrSetVector<CallBase> ReachedUnknownParallelRegions;
-
-  /// State to track if we are in SPMD-mode, assumed or know, and why we decided
-  /// we cannot be. If it is assumed, then RequiresFullRuntime should also be
-  /// false.
-  BooleanStateWithPtrSetVector<Instruction, false> SPMDCompatibilityTracker;
-
-  /// The __kmpc_target_init call in this kernel, if any. If we find more than
-  /// one we abort as the kernel is malformed.
-  CallBase *KernelInitCB = nullptr;
-
-  /// The constant kernel environement as taken from and passed to
-  /// __kmpc_target_init.
-  ConstantStruct *KernelEnvC = nullptr;
-
-  /// The __kmpc_target_deinit call in this kernel, if any. If we find more than
-  /// one we abort as the kernel is malformed.
-  CallBase *KernelDeinitCB = nullptr;
-
-  /// Flag to indicate if the associated function is a kernel entry.
-  bool IsKernelEntry = false;
-
-  /// State to track what kernel entries can reach the associated function.
-  BooleanStateWithPtrSetVector<Function, false> ReachingKernelEntries;
-
-  /// State to indicate if we can track parallel level of the associated
-  /// function. We will give up tracking if we encounter unknown caller or the
-  /// caller is __kmpc_parallel_51.
-  BooleanStateWithSetVector<uint8_t> ParallelLevels;
-
-  /// Flag that indicates if the kernel has nested Parallelism
-  bool NestedParallelism = false;
-
-  /// Abstract State interface
-  ///{
-
-  KernelInfoState() = default;
-  KernelInfoState(bool BestState) {
-    if (!BestState)
-      indicatePessimisticFixpoint();
-  }
-
-  /// See AbstractState::isValidState(...)
-  bool isValidState() const override { return true; }
-
-  /// See AbstractState::isAtFixpoint(...)
-  bool isAtFixpoint() const override { return IsAtFixpoint; }
-
-  /// See AbstractState::indicatePessimisticFixpoint(...)
-  ChangeStatus indicatePessimisticFixpoint() override {
-    IsAtFixpoint = true;
-    ParallelLevels.indicatePessimisticFixpoint();
-    ReachingKernelEntries.indicatePessimisticFixpoint();
-    SPMDCompatibilityTracker.indicatePessimisticFixpoint();
-    ReachedKnownParallelRegions.indicatePessimisticFixpoint();
-    ReachedUnknownParallelRegions.indicatePessimisticFixpoint();
-    NestedParallelism = true;
-    return ChangeStatus::CHANGED;
-  }
-
-  /// See AbstractState::indicateOptimisticFixpoint(...)
-  ChangeStatus indicateOptimisticFixpoint() override {
-    IsAtFixpoint = true;
-    ParallelLevels.indicateOptimisticFixpoint();
-    ReachingKernelEntries.indicateOptimisticFixpoint();
-    SPMDCompatibilityTracker.indicateOptimisticFixpoint();
-    ReachedKnownParallelRegions.indicateOptimisticFixpoint();
-    ReachedUnknownParallelRegions.indicateOptimisticFixpoint();
-    return ChangeStatus::UNCHANGED;
-  }
-
-  /// Return the assumed state
-  KernelInfoState &getAssumed() { return *this; }
-  const KernelInfoState &getAssumed() const { return *this; }
-
-  bool operator==(const KernelInfoState &RHS) const {
-    if (SPMDCompatibilityTracker != RHS.SPMDCompatibilityTracker)
-      return false;
-    if (ReachedKnownParallelRegions != RHS.ReachedKnownParallelRegions)
-      return false;
-    if (ReachedUnknownParallelRegions != RHS.ReachedUnknownParallelRegions)
-      return false;
-    if (ReachingKernelEntries != RHS.ReachingKernelEntries)
-      return false;
-    if (ParallelLevels != RHS.ParallelLevels)
-      return false;
-    if (NestedParallelism != RHS.NestedParallelism)
-      return false;
-    return true;
-  }
-
-  /// Returns true if this kernel contains any OpenMP parallel regions.
-  bool mayContainParallelRegion() {
-    return !ReachedKnownParallelRegions.empty() ||
-           !ReachedUnknownParallelRegions.empty();
-  }
-
-  /// Return empty set as the best state of potential values.
-  static KernelInfoState getBestState() { return KernelInfoState(true); }
-
-  static KernelInfoState getBestState(KernelInfoState &KIS) {
-    return getBestState();
-  }
-
-  /// Return full set as the worst state of potential values.
-  static KernelInfoState getWorstState() { return KernelInfoState(false); }
-
-  /// "Clamp" this state with \p KIS.
-  KernelInfoState operator^=(const KernelInfoState &KIS) {
-    // Do not merge two different _init and _deinit call sites.
-    if (KIS.KernelInitCB) {
-      if (KernelInitCB && KernelInitCB != KIS.KernelInitCB)
-        llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt "
-                         "assumptions.");
-      KernelInitCB = KIS.KernelInitCB;
-    }
-    if (KIS.KernelDeinitCB) {
-      if (KernelDeinitCB && KernelDeinitCB != KIS.KernelDeinitCB)
-        llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt "
-                         "assumptions.");
-      KernelDeinitCB = KIS.KernelDeinitCB;
-    }
-    if (KIS.KernelEnvC) {
-      if (KernelEnvC && KernelEnvC != KIS.KernelEnvC)
-        llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt "
-                         "assumptions.");
-      KernelEnvC = KIS.KernelEnvC;
-    }
-    SPMDCompatibilityTracker ^= KIS.SPMDCompatibilityTracker;
-    ReachedKnownParallelRegions ^= KIS.ReachedKnownParallelRegions;
-    ReachedUnknownParallelRegions ^= KIS.ReachedUnknownParallelRegions;
-    NestedParallelism |= KIS.NestedParallelism;
-    return *this;
-  }
-
-  KernelInfoState operator&=(const KernelInfoState &KIS) {
-    return (*this ^= KIS);
-  }
-
-  ///}
-};
-
 /// Used to map the values physically (in the IR) stored in an offload
 /// array, to a vector in memory.
 struct OffloadArray {
@@ -3596,9 +3402,9 @@ struct AAHeapToSharedFunction : public AAHeapToShared {
   unsigned SharedMemoryUsed = 0;
 };
 
-struct AAKernelInfo : public StateWrapper<KernelInfoState, AbstractAttribute> {
-  using Base = StateWrapper<KernelInfoState, AbstractAttribute>;
-  AAKernelInfo(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
+struct AAKernelInfoImpl : AAKernelInfo {
+  AAKernelInfoImpl(const IRPosition &IRP, Attributor &A)
+      : AAKernelInfo(IRP, A) {}
 
   /// The callee value is tracked beyond a simple stripPointerCasts, so we allow
   /// unknown callees.
@@ -3635,27 +3441,34 @@ struct AAKernelInfo : public StateWrapper<KernelInfoState, AbstractAttribute> {
   }
 
   /// Create an abstract attribute biew for the position \p IRP.
-  static AAKernelInfo &createForPosition(const IRPosition &IRP, Attributor &A);
+  static AAKernelInfoImpl &createForPosition(const IRPosition &IRP,
+                                             Attributor &A);
 
   /// See AbstractAttribute::getName()
-  const std::string getName() const override { return "AAKernelInfo"; }
+  const std::string getName() const override { return "AAKernelInfoImpl"; }
 
   /// See AbstractAttribute::getIdAddr()
   const char *getIdAddr() const override { return &ID; }
 
-  /// This function should return true if the type of the \p AA is AAKernelInfo
+  /// This function should return true if the type of the \p AA is
+  /// AAKernelInfoImpl
   static bool classof(const AbstractAttribute *AA) {
     return (AA->getIdAddr() == &ID);
   }
 
   static const char ID;
+
+  /// Return the ReachingKernelEntries
+  BooleanStateWithPtrSetVector<Function, false> getReachingKernels() override {
+    return ReachingKernelEntries;
+  }
 };
 
 /// The function kernel info abstract attribute, basically, what can we say
 /// about a function with regards to the KernelInfoState.
-struct AAKernelInfoFunction : AAKernelInfo {
+struct AAKernelInfoFunction : AAKernelInfoImpl {
   AAKernelInfoFunction(const IRPosition &IRP, Attributor &A)
-      : AAKernelInfo(IRP, A) {}
+      : AAKernelInfoImpl(IRP, A) {}
 
   SmallPtrSet<Instruction *, 4> GuardedInstructions;
 
@@ -3815,7 +3628,7 @@ struct AAKernelInfoFunction : AAKernelInfo {
     };
 
     // Add a dependence to ensure updates if the state changes.
-    auto AddDependence = [](Attributor &A, const AAKernelInfo *KI,
+    auto AddDependence = [](Attributor &A, const AAKernelInfoImpl *KI,
                             const AbstractAttribute *QueryingAA) {
       if (QueryingAA) {
         A.recordDependence(*KI, *QueryingAA, DepClassTy::OPTIONAL);
@@ -4119,10 +3932,10 @@ struct AAKernelInfoFunction : AAKernelInfo {
 
     for (Instruction *GuardedI : SPMDCompatibilityTracker) {
       BasicBlock *BB = GuardedI->getParent();
-      auto *CalleeAA = A.lookupAAFor<AAKernelInfo>(
+      auto *CalleeAA = A.lookupAAFor<AAKernelInfoImpl>(
           IRPosition::function(*GuardedI->getFunction()), nullptr,
           DepClassTy::NONE);
-      assert(CalleeAA != nullptr && "Expected Callee AAKernelInfo");
+      assert(CalleeAA != nullptr && "Expected Callee AAKernelInfoImpl");
       auto &CalleeAAFunction = *cast<AAKernelInfoFunction>(CalleeAA);
       // Continue if instruction is already guarded.
       if (CalleeAAFunction.getGuardedInstructions().contains(GuardedI))
@@ -4724,7 +4537,7 @@ struct AAKernelInfoFunction : AAKernelInfo {
           // we cannot fix the internal spmd-zation state either.
           int SPMD = 0, Generic = 0;
           for (auto *Kernel : ReachingKernelEntries) {
-            auto *CBAA = A.getAAFor<AAKernelInfo>(
+            auto *CBAA = A.getAAFor<AAKernelInfoImpl>(
                 *this, IRPosition::function(*Kernel), DepClassTy::OPTIONAL);
             if (CBAA && CBAA->SPMDCompatibilityTracker.isValidState() &&
                 CBAA->SPMDCompatibilityTracker.isAssumed())
@@ -4745,7 +4558,7 @@ struct AAKernelInfoFunction : AAKernelInfo {
     bool AllSPMDStatesWereFixed = true;
     auto CheckCallInst = [&](Instruction &I) {
       auto &CB = cast<CallBase>(I);
-      auto *CBAA = A.getAAFor<AAKernelInfo>(
+      auto *CBAA = A.getAAFor<AAKernelInfoImpl>(
           *this, IRPosition::callsite_function(CB), DepClassTy::OPTIONAL);
       if (!CBAA)
         return false;
@@ -4794,7 +4607,7 @@ struct AAKernelInfoFunction : AAKernelInfo {
 
       assert(Caller && "Caller is nullptr");
 
-      auto *CAA = A.getOrCreateAAFor<AAKernelInfo>(
+      auto *CAA = A.getOrCreateAAFor<AAKernelInfoImpl>(
           IRPosition::function(*Caller), this, DepClassTy::REQUIRED);
       if (CAA && CAA->ReachingKernelEntries.isValidState()) {
         ReachingKernelEntries ^= CAA->ReachingKernelEntries;
@@ -4826,7 +4639,7 @@ struct AAKernelInfoFunction : AAKernelInfo {
       assert(Caller && "Caller is nullptr");
 
       auto *CAA =
-          A.getOrCreateAAFor<AAKernelInfo>(IRPosition::function(*Caller));
+          A.getOrCreateAAFor<AAKernelInfoImpl>(IRPosition::function(*Caller));
       if (CAA && CAA->ParallelLevels.isValidState()) {
         // Any function that is called by `__kmpc_parallel_51` will not be
         // folded as the parallel level in the function is updated. In order to
@@ -4861,13 +4674,13 @@ struct AAKernelInfoFunction : AAKernelInfo {
 /// The call site kernel info abstract attribute, basically, what can we say
 /// about a call site with regards to the KernelInfoState. For now this simply
 /// forwards the information from the callee.
-struct AAKernelInfoCallSite : AAKernelInfo {
+struct AAKernelInfoCallSite : AAKernelInfoImpl {
   AAKernelInfoCallSite(const IRPosition &IRP, Attributor &A)
-      : AAKernelInfo(IRP, A) {}
+      : AAKernelInfoImpl(IRP, A) {}
 
   /// See AbstractAttribute::initialize(...).
   void initialize(Attributor &A) override {
-    AAKernelInfo::initialize(A);
+    AAKernelInfoImpl::initialize(A);
 
     CallBase &CB = cast<CallBase>(getAssociatedValue());
     auto *AssumptionAA = A.getAAFor<AAAssumptionInfo>(
@@ -4889,8 +4702,9 @@ struct AAKernelInfoCallSite : AAKernelInfo {
 
     // Next we check if we know the callee. If it is a known OpenMP function
     // we will handle them explicitly in the switch below. If it is not, we
-    // will use an AAKernelInfo object on the callee to gather information and
-    // merge that into the current state. The latter happens in the updateImpl.
+    // will use an AAKernelInfoImpl object on the callee to gather information
+    // and merge that into the current state. The latter happens in the
+    // updateImpl.
     auto CheckCallee = [&](Function *Callee, unsigned NumCallees) {
       auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
       const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee);
@@ -5054,12 +4868,12 @@ struct AAKernelInfoCallSite : AAKernelInfo {
     auto CheckCallee = [&](Function *F, int NumCallees) {
       const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(F);
 
-      // If F is not a runtime function, propagate the AAKernelInfo of the
+      // If F is not a runtime function, propagate the AAKernelInfoImpl of the
       // callee.
       if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {
         const IRPosition &FnPos = IRPosition::function(*F);
         auto *FnAA =
-            A.getAAFor<AAKernelInfo>(*this, FnPos, DepClassTy::REQUIRED);
+            A.getAAFor<AAKernelInfoImpl>(*this, FnPos, DepClassTy::REQUIRED);
         if (!FnAA)
           return indicatePessimisticFixpoint();
         if (getState() == FnAA->getState())
@@ -5148,7 +4962,7 @@ struct AAKernelInfoCallSite : AAKernelInfo {
 
     ReachedKnownParallelRegions.insert(&CB);
     /// Check nested parallelism
-    auto *FnAA = A.getAAFor<AAKernelInfo>(
+    auto *FnAA = A.getAAFor<AAKernelInfoImpl>(
         *this, IRPosition::function(*ParallelRegion), DepClassTy::OPTIONAL);
     NestedParallelism |= !FnAA || !FnAA->getState().isValidState() ||
                          !FnAA->ReachedKnownParallelRegions.empty() ||
@@ -5305,7 +5119,7 @@ struct AAFoldRuntimeCallCallSiteReturned : AAFoldRuntimeCall {
 
     unsigned AssumedSPMDCount = 0, KnownSPMDCount = 0;
     unsigned AssumedNonSPMDCount = 0, KnownNonSPMDCount = 0;
-    auto *CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
+    auto *CallerKernelInfoAA = A.getAAFor<AAKernelInfoImpl>(
         *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
 
     if (!CallerKernelInfoAA ||
@@ -5313,8 +5127,8 @@ struct AAFoldRuntimeCallCallSiteReturned : AAFoldRuntimeCall {
       return indicatePessimisticFixpoint();
 
     for (Kernel K : CallerKernelInfoAA->ReachingKernelEntries) {
-      auto *AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K),
-                                          DepClassTy::REQUIRED);
+      auto *AA = A.getAAFor<AAKernelInfoImpl>(*this, IRPosition::function(*K),
+                                              DepClassTy::REQUIRED);
 
       if (!AA || !AA->isValidState()) {
         SimplifiedValue = nullptr;
@@ -5366,7 +5180,7 @@ struct AAFoldRuntimeCallCallSiteReturned : AAFoldRuntimeCall {
   ChangeStatus foldParallelLevel(Attributor &A) {
     std::optional<Value *> SimplifiedValueBefore = SimplifiedValue;
 
-    auto *CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
+    auto *CallerKernelInfoAA = A.getAAFor<AAKernelInfoImpl>(
         *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
 
     if (!CallerKernelInfoAA ||
@@ -5385,8 +5199,8 @@ struct AAFoldRuntimeCallCallSiteReturned : AAFoldRuntimeCall {
     unsigned AssumedSPMDCount = 0, KnownSPMDCount = 0;
     unsigned AssumedNonSPMDCount = 0, KnownNonSPMDCount = 0;
     for (Kernel K : CallerKernelInfoAA->ReachingKernelEntries) {
-      auto *AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K),
-                                          DepClassTy::REQUIRED);
+      auto *AA = A.getAAFor<AAKernelInfoImpl>(*this, IRPosition::function(*K),
+                                              DepClassTy::REQUIRED);
       if (!AA || !AA->SPMDCompatibilityTracker.isValidState())
         return indicatePessimisticFixpoint();
 
@@ -5429,7 +5243,7 @@ struct AAFoldRuntimeCallCallSiteReturned : AAFoldRuntimeCall {
     int32_t CurrentAttrValue = -1;
     std::optional<Value *> SimplifiedValueBefore = SimplifiedValue;
 
-    auto *CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
+    auto *CallerKernelInfoAA = A.getAAFor<AAKernelInfoImpl>(
         *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
 
     if (!CallerKernelInfoAA ||
@@ -5486,12 +5300,12 @@ void OpenMPOpt::registerAAs(bool IsModulePass) {
     return;
 
   if (IsModulePass) {
-    // Ensure we create the AAKernelInfo AAs first and without triggering an
+    // Ensure we create the AAKernelInfoImpl AAs first and without triggering an
     // update. This will make sure we register all value simplification
     // callbacks before any other AA has the chance to create an AAValueSimplify
     // or similar.
     auto CreateKernelInfoCB = [&](Use &, Function &Kernel) {
-      A.getOrCreateAAFor<AAKernelInfo>(
+      A.getOrCreateAAFor<AAKernelInfoImpl>(
           IRPosition::function(Kernel), /* QueryingAA */ nullptr,
           DepClassTy::NONE, /* ForceUpdate */ false,
           /* UpdateAfterInit */ false);
@@ -5594,7 +5408,7 @@ void OpenMPOpt::registerAAsForFunction(Attributor &A, const Function &F) {
 }
 
 const char AAICVTracker::ID = 0;
-const char AAKernelInfo::ID = 0;
+const char AAKernelInfoImpl::ID = 0;
 const char AAExecutionDomain::ID = 0;
 const char AAHeapToShared::ID = 0;
 const char AAFoldRuntimeCall::ID = 0;
@@ -5667,9 +5481,9 @@ AAHeapToShared &AAHeapToShared::createForPosition(const IRPosition &IRP,
   return *AA;
 }
 
-AAKernelInfo &AAKernelInfo::createForPosition(const IRPosition &IRP,
-                                              Attributor &A) {
-  AAKernelInfo *AA = nullptr;
+AAKernelInfoImpl &AAKernelInfoImpl::createForPosition(const IRPosition &IRP,
+                                                      Attributor &A) {
+  AAKernelInfoImpl *AA = nullptr;
   switch (IRP.getPositionKind()) {
   case IRPosition::IRP_INVALID:
   case IRPosition::IRP_FLOAT:

>From f8ae5184867db8cf2776165c0697bd4fd33c1c31 Mon Sep 17 00:00:00 2001
From: Nafis Mustakin <nmust004 at ucr.edu>
Date: Mon, 13 Nov 2023 09:49:17 -0800
Subject: [PATCH 2/5] Move functions to AAKernelInfo

---
 llvm/include/llvm/Transforms/IPO/OpenMPOpt.h | 18 +++++++
 llvm/lib/Transforms/IPO/OpenMPOpt.cpp        | 56 +++++++-------------
 2 files changed, 37 insertions(+), 37 deletions(-)

diff --git a/llvm/include/llvm/Transforms/IPO/OpenMPOpt.h b/llvm/include/llvm/Transforms/IPO/OpenMPOpt.h
index 049c1458bb9d63..04d40ab00439e1 100644
--- a/llvm/include/llvm/Transforms/IPO/OpenMPOpt.h
+++ b/llvm/include/llvm/Transforms/IPO/OpenMPOpt.h
@@ -262,6 +262,24 @@ struct AAKernelInfo : public StateWrapper<KernelInfoState, AbstractAttribute> {
   /// Public getter for ReachingKernelEntries
   virtual BooleanStateWithPtrSetVector<Function, false>
   getReachingKernels() = 0;
+  
+  /// Create an abstract attribute view for the position \p IRP.
+  static AAKernelInfo &createForPosition(const IRPosition &IRP,
+                                             Attributor &A);
+
+  /// This function should return true if the type of the \p AA is AAKernelInfo
+  static bool classof(const AbstractAttribute *AA) {
+    return (AA->getIdAddr() == &ID);
+  }
+
+  /// See AbstractAttribute::getName()
+  const std::string getName() const override { return "AAKernelInfo"; }
+
+  /// See AbstractAttribute::getIdAddr()
+  const char *getIdAddr() const override { return &ID; }
+
+  static const char ID;
+
 };
 
 } // end namespace llvm
diff --git a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
index 4dd32f236c03f3..1cf89d63c4e998 100644
--- a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
+++ b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
@@ -3440,24 +3440,6 @@ struct AAKernelInfoImpl : AAKernelInfo {
            ", NestedPar: " + (NestedParallelism ? "yes" : "no");
   }
 
-  /// Create an abstract attribute biew for the position \p IRP.
-  static AAKernelInfoImpl &createForPosition(const IRPosition &IRP,
-                                             Attributor &A);
-
-  /// See AbstractAttribute::getName()
-  const std::string getName() const override { return "AAKernelInfoImpl"; }
-
-  /// See AbstractAttribute::getIdAddr()
-  const char *getIdAddr() const override { return &ID; }
-
-  /// This function should return true if the type of the \p AA is
-  /// AAKernelInfoImpl
-  static bool classof(const AbstractAttribute *AA) {
-    return (AA->getIdAddr() == &ID);
-  }
-
-  static const char ID;
-
   /// Return the ReachingKernelEntries
   BooleanStateWithPtrSetVector<Function, false> getReachingKernels() override {
     return ReachingKernelEntries;
@@ -3932,10 +3914,10 @@ struct AAKernelInfoFunction : AAKernelInfoImpl {
 
     for (Instruction *GuardedI : SPMDCompatibilityTracker) {
       BasicBlock *BB = GuardedI->getParent();
-      auto *CalleeAA = A.lookupAAFor<AAKernelInfoImpl>(
+      auto *CalleeAA = A.lookupAAFor<AAKernelInfo>(
           IRPosition::function(*GuardedI->getFunction()), nullptr,
           DepClassTy::NONE);
-      assert(CalleeAA != nullptr && "Expected Callee AAKernelInfoImpl");
+      assert(CalleeAA != nullptr && "Expected Callee AAKernelInfo");
       auto &CalleeAAFunction = *cast<AAKernelInfoFunction>(CalleeAA);
       // Continue if instruction is already guarded.
       if (CalleeAAFunction.getGuardedInstructions().contains(GuardedI))
@@ -4537,7 +4519,7 @@ struct AAKernelInfoFunction : AAKernelInfoImpl {
           // we cannot fix the internal spmd-zation state either.
           int SPMD = 0, Generic = 0;
           for (auto *Kernel : ReachingKernelEntries) {
-            auto *CBAA = A.getAAFor<AAKernelInfoImpl>(
+            auto *CBAA = A.getAAFor<AAKernelInfo>(
                 *this, IRPosition::function(*Kernel), DepClassTy::OPTIONAL);
             if (CBAA && CBAA->SPMDCompatibilityTracker.isValidState() &&
                 CBAA->SPMDCompatibilityTracker.isAssumed())
@@ -4558,7 +4540,7 @@ struct AAKernelInfoFunction : AAKernelInfoImpl {
     bool AllSPMDStatesWereFixed = true;
     auto CheckCallInst = [&](Instruction &I) {
       auto &CB = cast<CallBase>(I);
-      auto *CBAA = A.getAAFor<AAKernelInfoImpl>(
+      auto *CBAA = A.getAAFor<AAKernelInfo>(
           *this, IRPosition::callsite_function(CB), DepClassTy::OPTIONAL);
       if (!CBAA)
         return false;
@@ -4607,7 +4589,7 @@ struct AAKernelInfoFunction : AAKernelInfoImpl {
 
       assert(Caller && "Caller is nullptr");
 
-      auto *CAA = A.getOrCreateAAFor<AAKernelInfoImpl>(
+      auto *CAA = A.getOrCreateAAFor<AAKernelInfo>(
           IRPosition::function(*Caller), this, DepClassTy::REQUIRED);
       if (CAA && CAA->ReachingKernelEntries.isValidState()) {
         ReachingKernelEntries ^= CAA->ReachingKernelEntries;
@@ -4639,7 +4621,7 @@ struct AAKernelInfoFunction : AAKernelInfoImpl {
       assert(Caller && "Caller is nullptr");
 
       auto *CAA =
-          A.getOrCreateAAFor<AAKernelInfoImpl>(IRPosition::function(*Caller));
+          A.getOrCreateAAFor<AAKernelInfo>(IRPosition::function(*Caller));
       if (CAA && CAA->ParallelLevels.isValidState()) {
         // Any function that is called by `__kmpc_parallel_51` will not be
         // folded as the parallel level in the function is updated. In order to
@@ -4868,12 +4850,12 @@ struct AAKernelInfoCallSite : AAKernelInfoImpl {
     auto CheckCallee = [&](Function *F, int NumCallees) {
       const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(F);
 
-      // If F is not a runtime function, propagate the AAKernelInfoImpl of the
+      // If F is not a runtime function, propagate the AAKernelInfo of the
       // callee.
       if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {
         const IRPosition &FnPos = IRPosition::function(*F);
         auto *FnAA =
-            A.getAAFor<AAKernelInfoImpl>(*this, FnPos, DepClassTy::REQUIRED);
+            A.getAAFor<AAKernelInfo>(*this, FnPos, DepClassTy::REQUIRED);
         if (!FnAA)
           return indicatePessimisticFixpoint();
         if (getState() == FnAA->getState())
@@ -4962,7 +4944,7 @@ struct AAKernelInfoCallSite : AAKernelInfoImpl {
 
     ReachedKnownParallelRegions.insert(&CB);
     /// Check nested parallelism
-    auto *FnAA = A.getAAFor<AAKernelInfoImpl>(
+    auto *FnAA = A.getAAFor<AAKernelInfo>(
         *this, IRPosition::function(*ParallelRegion), DepClassTy::OPTIONAL);
     NestedParallelism |= !FnAA || !FnAA->getState().isValidState() ||
                          !FnAA->ReachedKnownParallelRegions.empty() ||
@@ -5119,7 +5101,7 @@ struct AAFoldRuntimeCallCallSiteReturned : AAFoldRuntimeCall {
 
     unsigned AssumedSPMDCount = 0, KnownSPMDCount = 0;
     unsigned AssumedNonSPMDCount = 0, KnownNonSPMDCount = 0;
-    auto *CallerKernelInfoAA = A.getAAFor<AAKernelInfoImpl>(
+    auto *CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
         *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
 
     if (!CallerKernelInfoAA ||
@@ -5127,7 +5109,7 @@ struct AAFoldRuntimeCallCallSiteReturned : AAFoldRuntimeCall {
       return indicatePessimisticFixpoint();
 
     for (Kernel K : CallerKernelInfoAA->ReachingKernelEntries) {
-      auto *AA = A.getAAFor<AAKernelInfoImpl>(*this, IRPosition::function(*K),
+      auto *AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K),
                                               DepClassTy::REQUIRED);
 
       if (!AA || !AA->isValidState()) {
@@ -5180,7 +5162,7 @@ struct AAFoldRuntimeCallCallSiteReturned : AAFoldRuntimeCall {
   ChangeStatus foldParallelLevel(Attributor &A) {
     std::optional<Value *> SimplifiedValueBefore = SimplifiedValue;
 
-    auto *CallerKernelInfoAA = A.getAAFor<AAKernelInfoImpl>(
+    auto *CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
         *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
 
     if (!CallerKernelInfoAA ||
@@ -5199,7 +5181,7 @@ struct AAFoldRuntimeCallCallSiteReturned : AAFoldRuntimeCall {
     unsigned AssumedSPMDCount = 0, KnownSPMDCount = 0;
     unsigned AssumedNonSPMDCount = 0, KnownNonSPMDCount = 0;
     for (Kernel K : CallerKernelInfoAA->ReachingKernelEntries) {
-      auto *AA = A.getAAFor<AAKernelInfoImpl>(*this, IRPosition::function(*K),
+      auto *AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K),
                                               DepClassTy::REQUIRED);
       if (!AA || !AA->SPMDCompatibilityTracker.isValidState())
         return indicatePessimisticFixpoint();
@@ -5243,7 +5225,7 @@ struct AAFoldRuntimeCallCallSiteReturned : AAFoldRuntimeCall {
     int32_t CurrentAttrValue = -1;
     std::optional<Value *> SimplifiedValueBefore = SimplifiedValue;
 
-    auto *CallerKernelInfoAA = A.getAAFor<AAKernelInfoImpl>(
+    auto *CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
         *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
 
     if (!CallerKernelInfoAA ||
@@ -5300,12 +5282,12 @@ void OpenMPOpt::registerAAs(bool IsModulePass) {
     return;
 
   if (IsModulePass) {
-    // Ensure we create the AAKernelInfoImpl AAs first and without triggering an
+    // Ensure we create the AAKernelInfo AAs first and without triggering an
     // update. This will make sure we register all value simplification
     // callbacks before any other AA has the chance to create an AAValueSimplify
     // or similar.
     auto CreateKernelInfoCB = [&](Use &, Function &Kernel) {
-      A.getOrCreateAAFor<AAKernelInfoImpl>(
+      A.getOrCreateAAFor<AAKernelInfo>(
           IRPosition::function(Kernel), /* QueryingAA */ nullptr,
           DepClassTy::NONE, /* ForceUpdate */ false,
           /* UpdateAfterInit */ false);
@@ -5408,7 +5390,7 @@ void OpenMPOpt::registerAAsForFunction(Attributor &A, const Function &F) {
 }
 
 const char AAICVTracker::ID = 0;
-const char AAKernelInfoImpl::ID = 0;
+const char AAKernelInfo::ID = 0;
 const char AAExecutionDomain::ID = 0;
 const char AAHeapToShared::ID = 0;
 const char AAFoldRuntimeCall::ID = 0;
@@ -5481,9 +5463,9 @@ AAHeapToShared &AAHeapToShared::createForPosition(const IRPosition &IRP,
   return *AA;
 }
 
-AAKernelInfoImpl &AAKernelInfoImpl::createForPosition(const IRPosition &IRP,
+AAKernelInfo &AAKernelInfo::createForPosition(const IRPosition &IRP,
                                                       Attributor &A) {
-  AAKernelInfoImpl *AA = nullptr;
+  AAKernelInfo *AA = nullptr;
   switch (IRP.getPositionKind()) {
   case IRPosition::IRP_INVALID:
   case IRPosition::IRP_FLOAT:

>From e805aaf48fddaed8ce7832a5960b6f00e394faa9 Mon Sep 17 00:00:00 2001
From: Nafis Mustakin <nmust004 at ucr.edu>
Date: Mon, 13 Nov 2023 09:54:49 -0800
Subject: [PATCH 3/5] Fix formatting issues

---
 llvm/include/llvm/Transforms/IPO/OpenMPOpt.h | 6 ++----
 llvm/lib/Transforms/IPO/OpenMPOpt.cpp        | 6 +++---
 2 files changed, 5 insertions(+), 7 deletions(-)

diff --git a/llvm/include/llvm/Transforms/IPO/OpenMPOpt.h b/llvm/include/llvm/Transforms/IPO/OpenMPOpt.h
index 04d40ab00439e1..5bef03c0a1fce3 100644
--- a/llvm/include/llvm/Transforms/IPO/OpenMPOpt.h
+++ b/llvm/include/llvm/Transforms/IPO/OpenMPOpt.h
@@ -262,10 +262,9 @@ struct AAKernelInfo : public StateWrapper<KernelInfoState, AbstractAttribute> {
   /// Public getter for ReachingKernelEntries
   virtual BooleanStateWithPtrSetVector<Function, false>
   getReachingKernels() = 0;
-  
+
   /// Create an abstract attribute biew for the position \p IRP.
-  static AAKernelInfo &createForPosition(const IRPosition &IRP,
-                                             Attributor &A);
+  static AAKernelInfo &createForPosition(const IRPosition &IRP, Attributor &A);
 
   /// This function should return true if the type of the \p AA is AAKernelInfo
   static bool classof(const AbstractAttribute *AA) {
@@ -279,7 +278,6 @@ struct AAKernelInfo : public StateWrapper<KernelInfoState, AbstractAttribute> {
   const char *getIdAddr() const override { return &ID; }
 
   static const char ID;
-
 };
 
 } // end namespace llvm
diff --git a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
index 1cf89d63c4e998..0494b2d717711c 100644
--- a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
+++ b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
@@ -5110,7 +5110,7 @@ struct AAFoldRuntimeCallCallSiteReturned : AAFoldRuntimeCall {
 
     for (Kernel K : CallerKernelInfoAA->ReachingKernelEntries) {
       auto *AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K),
-                                              DepClassTy::REQUIRED);
+                                          DepClassTy::REQUIRED);
 
       if (!AA || !AA->isValidState()) {
         SimplifiedValue = nullptr;
@@ -5182,7 +5182,7 @@ struct AAFoldRuntimeCallCallSiteReturned : AAFoldRuntimeCall {
     unsigned AssumedNonSPMDCount = 0, KnownNonSPMDCount = 0;
     for (Kernel K : CallerKernelInfoAA->ReachingKernelEntries) {
       auto *AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K),
-                                              DepClassTy::REQUIRED);
+                                          DepClassTy::REQUIRED);
       if (!AA || !AA->SPMDCompatibilityTracker.isValidState())
         return indicatePessimisticFixpoint();
 
@@ -5464,7 +5464,7 @@ AAHeapToShared &AAHeapToShared::createForPosition(const IRPosition &IRP,
 }
 
 AAKernelInfo &AAKernelInfo::createForPosition(const IRPosition &IRP,
-                                                      Attributor &A) {
+                                              Attributor &A) {
   AAKernelInfo *AA = nullptr;
   switch (IRP.getPositionKind()) {
   case IRPosition::IRP_INVALID:

>From 7611103ea9ad54baa217d1fa64e6e7f8ced80e91 Mon Sep 17 00:00:00 2001
From: Johannes Doerfert <johannesdoerfert at gmail.com>
Date: Mon, 11 Dec 2023 14:15:59 -0800
Subject: [PATCH 4/5] Update OpenMPOpt.cpp

---
 llvm/lib/Transforms/IPO/OpenMPOpt.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
index 0494b2d717711c..1e404e34942cd8 100644
--- a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
+++ b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
@@ -3610,7 +3610,7 @@ struct AAKernelInfoFunction : AAKernelInfoImpl {
     };
 
     // Add a dependence to ensure updates if the state changes.
-    auto AddDependence = [](Attributor &A, const AAKernelInfoImpl *KI,
+    auto AddDependence = [](Attributor &A, const AAKernelInfo *KI,
                             const AbstractAttribute *QueryingAA) {
       if (QueryingAA) {
         A.recordDependence(*KI, *QueryingAA, DepClassTy::OPTIONAL);

>From 0ac679cd321c4607cad158cfae6f9461386345a8 Mon Sep 17 00:00:00 2001
From: Johannes Doerfert <johannesdoerfert at gmail.com>
Date: Mon, 11 Dec 2023 14:17:49 -0800
Subject: [PATCH 5/5] Update OpenMPOpt.cpp

---
 llvm/lib/Transforms/IPO/OpenMPOpt.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
index 1e404e34942cd8..543830f8d81aca 100644
--- a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
+++ b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
@@ -4684,7 +4684,7 @@ struct AAKernelInfoCallSite : AAKernelInfoImpl {
 
     // Next we check if we know the callee. If it is a known OpenMP function
     // we will handle them explicitly in the switch below. If it is not, we
-    // will use an AAKernelInfoImpl object on the callee to gather information
+    // will use an AAKernelInfo object on the callee to gather information
     // and merge that into the current state. The latter happens in the
     // updateImpl.
     auto CheckCallee = [&](Function *Callee, unsigned NumCallees) {



More information about the llvm-commits mailing list