[llvm] [OpenMP] Move KernelInfoState and AAKernelInfo to OpenMPOpt.h (PR #71878)

via llvm-commits llvm-commits at lists.llvm.org
Tue Jan 23 10:10:06 PST 2024


https://github.com/nmustakin updated https://github.com/llvm/llvm-project/pull/71878

>From 96f80679140e96b58640c0a14e3d16bb57876804 Mon Sep 17 00:00:00 2001
From: Nafis Mustakin <nmust004 at ucr.edu>
Date: Thu, 9 Nov 2023 15:28:57 -0800
Subject: [PATCH] Move KernelInfoState and AAKernelInfo to
 llvm/include/llvm/Transforms/IPO/OpenMPOpt.h

Move functions to AAKernelInfo

Fix formatting issues

Update OpenMPOpt.cpp

Update OpenMPOpt.cpp
---
 llvm/include/llvm/Transforms/IPO/OpenMPOpt.h | 218 +++++++++++++++++
 llvm/lib/Transforms/IPO/OpenMPOpt.cpp        | 232 ++-----------------
 2 files changed, 232 insertions(+), 218 deletions(-)

diff --git a/llvm/include/llvm/Transforms/IPO/OpenMPOpt.h b/llvm/include/llvm/Transforms/IPO/OpenMPOpt.h
index 2499c2bbccf4554..5bef03c0a1fce35 100644
--- a/llvm/include/llvm/Transforms/IPO/OpenMPOpt.h
+++ b/llvm/include/llvm/Transforms/IPO/OpenMPOpt.h
@@ -12,6 +12,7 @@
 #include "llvm/Analysis/CGSCCPassManager.h"
 #include "llvm/Analysis/LazyCallGraph.h"
 #include "llvm/IR/PassManager.h"
+#include "llvm/Transforms/IPO/Attributor.h"
 
 namespace llvm {
 
@@ -62,6 +63,223 @@ class OpenMPOptCGSCCPass : public PassInfoMixin<OpenMPOptCGSCCPass> {
   const ThinOrFullLTOPhase LTOPhase = ThinOrFullLTOPhase::None;
 };
 
+template <typename Ty, bool InsertInvalidates = true>
+struct BooleanStateWithSetVector : public BooleanState {
+  bool contains(const Ty &Elem) const { return Set.contains(Elem); }
+  bool insert(const Ty &Elem) {
+    if (InsertInvalidates)
+      BooleanState::indicatePessimisticFixpoint();
+    return Set.insert(Elem);
+  }
+
+  const Ty &operator[](int Idx) const { return Set[Idx]; }
+  bool operator==(const BooleanStateWithSetVector &RHS) const {
+    return BooleanState::operator==(RHS) && Set == RHS.Set;
+  }
+  bool operator!=(const BooleanStateWithSetVector &RHS) const {
+    return !(*this == RHS);
+  }
+
+  bool empty() const { return Set.empty(); }
+  size_t size() const { return Set.size(); }
+
+  /// "Clamp" this state with \p RHS.
+  BooleanStateWithSetVector &operator^=(const BooleanStateWithSetVector &RHS) {
+    BooleanState::operator^=(RHS);
+    Set.insert(RHS.Set.begin(), RHS.Set.end());
+    return *this;
+  }
+
+private:
+  /// A set to keep track of elements.
+  SetVector<Ty> Set;
+
+public:
+  typename decltype(Set)::iterator begin() { return Set.begin(); }
+  typename decltype(Set)::iterator end() { return Set.end(); }
+  typename decltype(Set)::const_iterator begin() const { return Set.begin(); }
+  typename decltype(Set)::const_iterator end() const { return Set.end(); }
+};
+
+template <typename Ty, bool InsertInvalidates = true>
+using BooleanStateWithPtrSetVector =
+    BooleanStateWithSetVector<Ty *, InsertInvalidates>;
+
+struct KernelInfoState : AbstractState {
+  /// Flag to track if we reached a fixpoint.
+  bool IsAtFixpoint = false;
+
+  /// The parallel regions (identified by the outlined parallel functions) that
+  /// can be reached from the associated function.
+  BooleanStateWithPtrSetVector<CallBase, /* InsertInvalidates */ false>
+      ReachedKnownParallelRegions;
+
+  /// State to track what parallel region we might reach.
+  BooleanStateWithPtrSetVector<CallBase> ReachedUnknownParallelRegions;
+
+  /// State to track if we are in SPMD-mode, assumed or known, and why we
+  /// decided we cannot be. If it is assumed, then RequiresFullRuntime should
+  /// also be false.
+  BooleanStateWithPtrSetVector<Instruction, false> SPMDCompatibilityTracker;
+
+  /// The __kmpc_target_init call in this kernel, if any. If we find more than
+  /// one we abort as the kernel is malformed.
+  CallBase *KernelInitCB = nullptr;
+
+  /// The constant kernel environment as taken from and passed to
+  /// __kmpc_target_init.
+  ConstantStruct *KernelEnvC = nullptr;
+
+  /// The __kmpc_target_deinit call in this kernel, if any. If we find more than
+  /// one we abort as the kernel is malformed.
+  CallBase *KernelDeinitCB = nullptr;
+
+  /// Flag to indicate if the associated function is a kernel entry.
+  bool IsKernelEntry = false;
+
+  /// State to track what kernel entries can reach the associated function.
+  BooleanStateWithPtrSetVector<Function, false> ReachingKernelEntries;
+
+  /// State to indicate if we can track the parallel level of the associated
+  /// function. We will give up tracking if we encounter an unknown caller or
+  /// the caller is __kmpc_parallel_51.
+  BooleanStateWithSetVector<uint8_t> ParallelLevels;
+
+  /// Flag that indicates if the kernel has nested Parallelism
+  bool NestedParallelism = false;
+
+  /// Abstract State interface
+  ///{
+
+  KernelInfoState() = default;
+  KernelInfoState(bool BestState) {
+    if (!BestState)
+      indicatePessimisticFixpoint();
+  }
+
+  /// See AbstractState::isValidState(...)
+  bool isValidState() const override { return true; }
+
+  /// See AbstractState::isAtFixpoint(...)
+  bool isAtFixpoint() const override { return IsAtFixpoint; }
+
+  /// See AbstractState::indicatePessimisticFixpoint(...)
+  ChangeStatus indicatePessimisticFixpoint() override {
+    IsAtFixpoint = true;
+    ParallelLevels.indicatePessimisticFixpoint();
+    ReachingKernelEntries.indicatePessimisticFixpoint();
+    SPMDCompatibilityTracker.indicatePessimisticFixpoint();
+    ReachedKnownParallelRegions.indicatePessimisticFixpoint();
+    ReachedUnknownParallelRegions.indicatePessimisticFixpoint();
+    NestedParallelism = true;
+    return ChangeStatus::CHANGED;
+  }
+
+  /// See AbstractState::indicateOptimisticFixpoint(...)
+  ChangeStatus indicateOptimisticFixpoint() override {
+    IsAtFixpoint = true;
+    ParallelLevels.indicateOptimisticFixpoint();
+    ReachingKernelEntries.indicateOptimisticFixpoint();
+    SPMDCompatibilityTracker.indicateOptimisticFixpoint();
+    ReachedKnownParallelRegions.indicateOptimisticFixpoint();
+    ReachedUnknownParallelRegions.indicateOptimisticFixpoint();
+    return ChangeStatus::UNCHANGED;
+  }
+
+  /// Return the assumed state
+  KernelInfoState &getAssumed() { return *this; }
+  const KernelInfoState &getAssumed() const { return *this; }
+
+  bool operator==(const KernelInfoState &RHS) const {
+    if (SPMDCompatibilityTracker != RHS.SPMDCompatibilityTracker)
+      return false;
+    if (ReachedKnownParallelRegions != RHS.ReachedKnownParallelRegions)
+      return false;
+    if (ReachedUnknownParallelRegions != RHS.ReachedUnknownParallelRegions)
+      return false;
+    if (ReachingKernelEntries != RHS.ReachingKernelEntries)
+      return false;
+    if (ParallelLevels != RHS.ParallelLevels)
+      return false;
+    if (NestedParallelism != RHS.NestedParallelism)
+      return false;
+    return true;
+  }
+
+  /// Returns true if this kernel contains any OpenMP parallel regions.
+  bool mayContainParallelRegion() {
+    return !ReachedKnownParallelRegions.empty() ||
+           !ReachedUnknownParallelRegions.empty();
+  }
+
+  /// Return empty set as the best state of potential values.
+  static KernelInfoState getBestState() { return KernelInfoState(true); }
+
+  static KernelInfoState getBestState(KernelInfoState &KIS) {
+    return getBestState();
+  }
+
+  /// Return full set as the worst state of potential values.
+  static KernelInfoState getWorstState() { return KernelInfoState(false); }
+
+  /// "Clamp" this state with \p KIS.
+  KernelInfoState operator^=(const KernelInfoState &KIS) {
+    // Do not merge two different _init and _deinit call sites.
+    if (KIS.KernelInitCB) {
+      if (KernelInitCB && KernelInitCB != KIS.KernelInitCB)
+        llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt "
+                         "assumptions.");
+      KernelInitCB = KIS.KernelInitCB;
+    }
+    if (KIS.KernelDeinitCB) {
+      if (KernelDeinitCB && KernelDeinitCB != KIS.KernelDeinitCB)
+        llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt "
+                         "assumptions.");
+      KernelDeinitCB = KIS.KernelDeinitCB;
+    }
+    if (KIS.KernelEnvC) {
+      if (KernelEnvC && KernelEnvC != KIS.KernelEnvC)
+        llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt "
+                         "assumptions.");
+      KernelEnvC = KIS.KernelEnvC;
+    }
+    SPMDCompatibilityTracker ^= KIS.SPMDCompatibilityTracker;
+    ReachedKnownParallelRegions ^= KIS.ReachedKnownParallelRegions;
+    ReachedUnknownParallelRegions ^= KIS.ReachedUnknownParallelRegions;
+    NestedParallelism |= KIS.NestedParallelism;
+    return *this;
+  }
+
+  KernelInfoState operator&=(const KernelInfoState &KIS) {
+    return (*this ^= KIS);
+  }
+};
+
+struct AAKernelInfo : public StateWrapper<KernelInfoState, AbstractAttribute> {
+  using Base = StateWrapper<KernelInfoState, AbstractAttribute>;
+  AAKernelInfo(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
+
+  /// Public getter for ReachingKernelEntries
+  virtual BooleanStateWithPtrSetVector<Function, false>
+  getReachingKernels() = 0;
+
+  /// Create an abstract attribute view for the position \p IRP.
+  static AAKernelInfo &createForPosition(const IRPosition &IRP, Attributor &A);
+
+  /// This function should return true if the type of the \p AA is AAKernelInfo
+  static bool classof(const AbstractAttribute *AA) {
+    return (AA->getIdAddr() == &ID);
+  }
+
+  /// See AbstractAttribute::getName()
+  const std::string getName() const override { return "AAKernelInfo"; }
+
+  /// See AbstractAttribute::getIdAddr()
+  const char *getIdAddr() const override { return &ID; }
+
+  static const char ID;
+};
+
 } // end namespace llvm
 
 #endif // LLVM_TRANSFORMS_IPO_OPENMPOPT_H
diff --git a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
index 4176d561363fbd9..fe3d664d6fdb0bf 100644
--- a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
+++ b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
@@ -638,200 +638,6 @@ struct OMPInformationCache : public InformationCache {
   bool OpenMPPostLink = false;
 };
 
-template <typename Ty, bool InsertInvalidates = true>
-struct BooleanStateWithSetVector : public BooleanState {
-  bool contains(const Ty &Elem) const { return Set.contains(Elem); }
-  bool insert(const Ty &Elem) {
-    if (InsertInvalidates)
-      BooleanState::indicatePessimisticFixpoint();
-    return Set.insert(Elem);
-  }
-
-  const Ty &operator[](int Idx) const { return Set[Idx]; }
-  bool operator==(const BooleanStateWithSetVector &RHS) const {
-    return BooleanState::operator==(RHS) && Set == RHS.Set;
-  }
-  bool operator!=(const BooleanStateWithSetVector &RHS) const {
-    return !(*this == RHS);
-  }
-
-  bool empty() const { return Set.empty(); }
-  size_t size() const { return Set.size(); }
-
-  /// "Clamp" this state with \p RHS.
-  BooleanStateWithSetVector &operator^=(const BooleanStateWithSetVector &RHS) {
-    BooleanState::operator^=(RHS);
-    Set.insert(RHS.Set.begin(), RHS.Set.end());
-    return *this;
-  }
-
-private:
-  /// A set to keep track of elements.
-  SetVector<Ty> Set;
-
-public:
-  typename decltype(Set)::iterator begin() { return Set.begin(); }
-  typename decltype(Set)::iterator end() { return Set.end(); }
-  typename decltype(Set)::const_iterator begin() const { return Set.begin(); }
-  typename decltype(Set)::const_iterator end() const { return Set.end(); }
-};
-
-template <typename Ty, bool InsertInvalidates = true>
-using BooleanStateWithPtrSetVector =
-    BooleanStateWithSetVector<Ty *, InsertInvalidates>;
-
-struct KernelInfoState : AbstractState {
-  /// Flag to track if we reached a fixpoint.
-  bool IsAtFixpoint = false;
-
-  /// The parallel regions (identified by the outlined parallel functions) that
-  /// can be reached from the associated function.
-  BooleanStateWithPtrSetVector<CallBase, /* InsertInvalidates */ false>
-      ReachedKnownParallelRegions;
-
-  /// State to track what parallel region we might reach.
-  BooleanStateWithPtrSetVector<CallBase> ReachedUnknownParallelRegions;
-
-  /// State to track if we are in SPMD-mode, assumed or know, and why we decided
-  /// we cannot be. If it is assumed, then RequiresFullRuntime should also be
-  /// false.
-  BooleanStateWithPtrSetVector<Instruction, false> SPMDCompatibilityTracker;
-
-  /// The __kmpc_target_init call in this kernel, if any. If we find more than
-  /// one we abort as the kernel is malformed.
-  CallBase *KernelInitCB = nullptr;
-
-  /// The constant kernel environement as taken from and passed to
-  /// __kmpc_target_init.
-  ConstantStruct *KernelEnvC = nullptr;
-
-  /// The __kmpc_target_deinit call in this kernel, if any. If we find more than
-  /// one we abort as the kernel is malformed.
-  CallBase *KernelDeinitCB = nullptr;
-
-  /// Flag to indicate if the associated function is a kernel entry.
-  bool IsKernelEntry = false;
-
-  /// State to track what kernel entries can reach the associated function.
-  BooleanStateWithPtrSetVector<Function, false> ReachingKernelEntries;
-
-  /// State to indicate if we can track parallel level of the associated
-  /// function. We will give up tracking if we encounter unknown caller or the
-  /// caller is __kmpc_parallel_51.
-  BooleanStateWithSetVector<uint8_t> ParallelLevels;
-
-  /// Flag that indicates if the kernel has nested Parallelism
-  bool NestedParallelism = false;
-
-  /// Abstract State interface
-  ///{
-
-  KernelInfoState() = default;
-  KernelInfoState(bool BestState) {
-    if (!BestState)
-      indicatePessimisticFixpoint();
-  }
-
-  /// See AbstractState::isValidState(...)
-  bool isValidState() const override { return true; }
-
-  /// See AbstractState::isAtFixpoint(...)
-  bool isAtFixpoint() const override { return IsAtFixpoint; }
-
-  /// See AbstractState::indicatePessimisticFixpoint(...)
-  ChangeStatus indicatePessimisticFixpoint() override {
-    IsAtFixpoint = true;
-    ParallelLevels.indicatePessimisticFixpoint();
-    ReachingKernelEntries.indicatePessimisticFixpoint();
-    SPMDCompatibilityTracker.indicatePessimisticFixpoint();
-    ReachedKnownParallelRegions.indicatePessimisticFixpoint();
-    ReachedUnknownParallelRegions.indicatePessimisticFixpoint();
-    NestedParallelism = true;
-    return ChangeStatus::CHANGED;
-  }
-
-  /// See AbstractState::indicateOptimisticFixpoint(...)
-  ChangeStatus indicateOptimisticFixpoint() override {
-    IsAtFixpoint = true;
-    ParallelLevels.indicateOptimisticFixpoint();
-    ReachingKernelEntries.indicateOptimisticFixpoint();
-    SPMDCompatibilityTracker.indicateOptimisticFixpoint();
-    ReachedKnownParallelRegions.indicateOptimisticFixpoint();
-    ReachedUnknownParallelRegions.indicateOptimisticFixpoint();
-    return ChangeStatus::UNCHANGED;
-  }
-
-  /// Return the assumed state
-  KernelInfoState &getAssumed() { return *this; }
-  const KernelInfoState &getAssumed() const { return *this; }
-
-  bool operator==(const KernelInfoState &RHS) const {
-    if (SPMDCompatibilityTracker != RHS.SPMDCompatibilityTracker)
-      return false;
-    if (ReachedKnownParallelRegions != RHS.ReachedKnownParallelRegions)
-      return false;
-    if (ReachedUnknownParallelRegions != RHS.ReachedUnknownParallelRegions)
-      return false;
-    if (ReachingKernelEntries != RHS.ReachingKernelEntries)
-      return false;
-    if (ParallelLevels != RHS.ParallelLevels)
-      return false;
-    if (NestedParallelism != RHS.NestedParallelism)
-      return false;
-    return true;
-  }
-
-  /// Returns true if this kernel contains any OpenMP parallel regions.
-  bool mayContainParallelRegion() {
-    return !ReachedKnownParallelRegions.empty() ||
-           !ReachedUnknownParallelRegions.empty();
-  }
-
-  /// Return empty set as the best state of potential values.
-  static KernelInfoState getBestState() { return KernelInfoState(true); }
-
-  static KernelInfoState getBestState(KernelInfoState &KIS) {
-    return getBestState();
-  }
-
-  /// Return full set as the worst state of potential values.
-  static KernelInfoState getWorstState() { return KernelInfoState(false); }
-
-  /// "Clamp" this state with \p KIS.
-  KernelInfoState operator^=(const KernelInfoState &KIS) {
-    // Do not merge two different _init and _deinit call sites.
-    if (KIS.KernelInitCB) {
-      if (KernelInitCB && KernelInitCB != KIS.KernelInitCB)
-        llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt "
-                         "assumptions.");
-      KernelInitCB = KIS.KernelInitCB;
-    }
-    if (KIS.KernelDeinitCB) {
-      if (KernelDeinitCB && KernelDeinitCB != KIS.KernelDeinitCB)
-        llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt "
-                         "assumptions.");
-      KernelDeinitCB = KIS.KernelDeinitCB;
-    }
-    if (KIS.KernelEnvC) {
-      if (KernelEnvC && KernelEnvC != KIS.KernelEnvC)
-        llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt "
-                         "assumptions.");
-      KernelEnvC = KIS.KernelEnvC;
-    }
-    SPMDCompatibilityTracker ^= KIS.SPMDCompatibilityTracker;
-    ReachedKnownParallelRegions ^= KIS.ReachedKnownParallelRegions;
-    ReachedUnknownParallelRegions ^= KIS.ReachedUnknownParallelRegions;
-    NestedParallelism |= KIS.NestedParallelism;
-    return *this;
-  }
-
-  KernelInfoState operator&=(const KernelInfoState &KIS) {
-    return (*this ^= KIS);
-  }
-
-  ///}
-};
-
 /// Used to map the values physically (in the IR) stored in an offload
 /// array, to a vector in memory.
 struct OffloadArray {
@@ -3599,9 +3405,9 @@ struct AAHeapToSharedFunction : public AAHeapToShared {
   unsigned SharedMemoryUsed = 0;
 };
 
-struct AAKernelInfo : public StateWrapper<KernelInfoState, AbstractAttribute> {
-  using Base = StateWrapper<KernelInfoState, AbstractAttribute>;
-  AAKernelInfo(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
+struct AAKernelInfoImpl : AAKernelInfo {
+  AAKernelInfoImpl(const IRPosition &IRP, Attributor &A)
+      : AAKernelInfo(IRP, A) {}
 
   /// The callee value is tracked beyond a simple stripPointerCasts, so we allow
   /// unknown callees.
@@ -3637,28 +3443,17 @@ struct AAKernelInfo : public StateWrapper<KernelInfoState, AbstractAttribute> {
            ", NestedPar: " + (NestedParallelism ? "yes" : "no");
   }
 
-  /// Create an abstract attribute biew for the position \p IRP.
-  static AAKernelInfo &createForPosition(const IRPosition &IRP, Attributor &A);
-
-  /// See AbstractAttribute::getName()
-  const std::string getName() const override { return "AAKernelInfo"; }
-
-  /// See AbstractAttribute::getIdAddr()
-  const char *getIdAddr() const override { return &ID; }
-
-  /// This function should return true if the type of the \p AA is AAKernelInfo
-  static bool classof(const AbstractAttribute *AA) {
-    return (AA->getIdAddr() == &ID);
+  /// Return the ReachingKernelEntries
+  BooleanStateWithPtrSetVector<Function, false> getReachingKernels() override {
+    return ReachingKernelEntries;
   }
-
-  static const char ID;
 };
 
 /// The function kernel info abstract attribute, basically, what can we say
 /// about a function with regards to the KernelInfoState.
-struct AAKernelInfoFunction : AAKernelInfo {
+struct AAKernelInfoFunction : AAKernelInfoImpl {
   AAKernelInfoFunction(const IRPosition &IRP, Attributor &A)
-      : AAKernelInfo(IRP, A) {}
+      : AAKernelInfoImpl(IRP, A) {}
 
   SmallPtrSet<Instruction *, 4> GuardedInstructions;
 
@@ -4862,13 +4657,13 @@ struct AAKernelInfoFunction : AAKernelInfo {
 /// The call site kernel info abstract attribute, basically, what can we say
 /// about a call site with regards to the KernelInfoState. For now this simply
 /// forwards the information from the callee.
-struct AAKernelInfoCallSite : AAKernelInfo {
+struct AAKernelInfoCallSite : AAKernelInfoImpl {
   AAKernelInfoCallSite(const IRPosition &IRP, Attributor &A)
-      : AAKernelInfo(IRP, A) {}
+      : AAKernelInfoImpl(IRP, A) {}
 
   /// See AbstractAttribute::initialize(...).
   void initialize(Attributor &A) override {
-    AAKernelInfo::initialize(A);
+    AAKernelInfoImpl::initialize(A);
 
     CallBase &CB = cast<CallBase>(getAssociatedValue());
     auto *AssumptionAA = A.getAAFor<AAAssumptionInfo>(
@@ -4890,8 +4685,9 @@ struct AAKernelInfoCallSite : AAKernelInfo {
 
     // Next we check if we know the callee. If it is a known OpenMP function
     // we will handle them explicitly in the switch below. If it is not, we
-    // will use an AAKernelInfo object on the callee to gather information and
-    // merge that into the current state. The latter happens in the updateImpl.
+    // will use an AAKernelInfo object on the callee to gather information
+    // and merge that into the current state. The latter happens in the
+    // updateImpl.
     auto CheckCallee = [&](Function *Callee, unsigned NumCallees) {
       auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
       const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee);



More information about the llvm-commits mailing list