[llvm] cefa5ce - [OpenMP] Replace ExternalizationRAII with virtual uses

Johannes Doerfert via llvm-commits llvm-commits at lists.llvm.org
Thu Jan 12 00:21:09 PST 2023


Author: Johannes Doerfert
Date: 2023-01-12T00:14:06-08:00
New Revision: cefa5cefdce2d5090002c3116403f7e5ca5700b9

URL: https://github.com/llvm/llvm-project/commit/cefa5cefdce2d5090002c3116403f7e5ca5700b9
DIFF: https://github.com/llvm/llvm-project/commit/cefa5cefdce2d5090002c3116403f7e5ca5700b9.diff

LOG: [OpenMP] Replace ExternalizationRAII with virtual uses

The externalization was always a stopgap solution. One of its drawbacks
is that it is very conservative, regardless of whether we actually
require the functions at the end of the pass. The new concept is more
generic and properly integrates into the dependence graph. Whenever we
might need a function, it has a "virtual use" that cannot be analyzed.
If we do not need the function, because of some AA state, a dependence
is recorded to ensure state changes trigger revisits of the uses,
including a potentially new virtual use.
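
To make the mechanism concrete, the following is a minimal, self-contained
C++ sketch of the idea. It is a toy model, not the LLVM API: ToyAttributor,
Value, AbstractAttribute, and the SPMDPossible flag are invented stand-ins;
the real hooks are Attributor::registerVirtualUseCallback and the lookups
added to checkForAllUses/checkForAllCallSites in the diff below.

// Toy model of the "virtual use" mechanism (invented types, NOT the LLVM
// Attributor API). A value carries callbacks; use enumeration consults them
// first. A callback returns false while the function might still be needed
// (an unanalyzable "virtual use" exists), and true once it is known not to
// be needed, after recording a dependence so a later state change revisits
// the uses.
#include <cstdio>
#include <functional>
#include <map>
#include <string>
#include <utility>
#include <vector>

struct Value { std::string Name; };
struct AbstractAttribute { std::string Name; };

struct ToyAttributor {
  using VirtualUseCallbackTy =
      std::function<bool(ToyAttributor &, const AbstractAttribute *)>;

  void registerVirtualUseCallback(const Value &V, VirtualUseCallbackTy CB) {
    VirtualUseCallbacks[&V].push_back(std::move(CB));
  }

  // Mirrors the early exit added to checkForAllUses/checkForAllCallSites.
  bool checkForAllUses(const Value &V, const AbstractAttribute &QueryingAA) {
    for (auto &CB : VirtualUseCallbacks[&V])
      if (!CB(*this, &QueryingAA))
        return false; // A virtual use we cannot analyze; give up on V.
    // ... the real implementation would now visit the IR uses of V ...
    return true;
  }

  void recordDependence(const AbstractAttribute &From,
                        const AbstractAttribute &To) {
    std::printf("%s now depends on %s\n", To.Name.c_str(), From.Name.c_str());
  }

  std::map<const Value *, std::vector<VirtualUseCallbackTy>> VirtualUseCallbacks;
};

int main() {
  ToyAttributor TA;
  Value Barrier{"__kmpc_barrier_simple_spmd"};
  AbstractAttribute KernelInfo{"AAKernelInfo"};
  AbstractAttribute IsDead{"AAIsDead"};
  bool SPMDPossible = true; // Stand-in for the SPMDCompatibilityTracker state.

  TA.registerVirtualUseCallback(
      Barrier, [&](ToyAttributor &A, const AbstractAttribute *QueryingAA) {
        if (SPMDPossible)
          return false; // We may insert barrier calls later: keep it alive.
        if (QueryingAA) // No virtual use now; revisit if our state changes.
          A.recordDependence(KernelInfo, *QueryingAA);
        return true;
      });

  std::printf("all uses known: %s\n",
              TA.checkForAllUses(Barrier, IsDead) ? "yes" : "no");
}

The callbacks thus replace the blanket externalization: the function stays
alive only while some AA state says it might still be needed, and otherwise
the recorded dependence guarantees the uses are revisited if that state
changes.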

Added: 
    

Modified: 
    llvm/include/llvm/Transforms/IPO/Attributor.h
    llvm/lib/Transforms/IPO/Attributor.cpp
    llvm/lib/Transforms/IPO/OpenMPOpt.cpp
    llvm/test/Transforms/OpenMP/get_hardware_num_threads_in_block_fold.ll

Removed: 
    


################################################################################
diff --git a/llvm/include/llvm/Transforms/IPO/Attributor.h b/llvm/include/llvm/Transforms/IPO/Attributor.h
index 41b555b8e5b7a..98af41d8b2ac3 100644
--- a/llvm/include/llvm/Transforms/IPO/Attributor.h
+++ b/llvm/include/llvm/Transforms/IPO/Attributor.h
@@ -1885,11 +1885,21 @@ struct Attributor {
     return SimplificationCallbacks.count(IRP);
   }
 
+  using VirtualUseCallbackTy =
+      std::function<bool(Attributor &, const AbstractAttribute *)>;
+  void registerVirtualUseCallback(const Value &V,
+                                  const VirtualUseCallbackTy &CB) {
+    VirtualUseCallbacks[&V].emplace_back(CB);
+  }
+
 private:
   /// The vector with all simplification callbacks registered by outside AAs.
   DenseMap<IRPosition, SmallVector<SimplifictionCallbackTy, 1>>
       SimplificationCallbacks;
 
+  DenseMap<const Value *, SmallVector<VirtualUseCallbackTy, 1>>
+      VirtualUseCallbacks;
+
 public:
   /// Translate \p V from the callee context into the call site context.
   std::optional<Value *>

diff --git a/llvm/lib/Transforms/IPO/Attributor.cpp b/llvm/lib/Transforms/IPO/Attributor.cpp
index c3cb22dabfc81..30450c350c003 100644
--- a/llvm/lib/Transforms/IPO/Attributor.cpp
+++ b/llvm/lib/Transforms/IPO/Attributor.cpp
@@ -1472,6 +1472,11 @@ bool Attributor::checkForAllUses(
     bool IgnoreDroppableUses,
     function_ref<bool(const Use &OldU, const Use &NewU)> EquivalentUseCB) {
 
+  // Check virtual uses first.
+  for (VirtualUseCallbackTy &CB : VirtualUseCallbacks.lookup(&V))
+    if (!CB(*this, &QueryingAA))
+      return false;
+
   // Check the trivial case first as it catches void values.
   if (V.use_empty())
     return true;
@@ -1611,6 +1616,10 @@ bool Attributor::checkForAllCallSites(function_ref<bool(AbstractCallSite)> Pred,
         << " has no internal linkage, hence not all call sites are known\n");
     return false;
   }
+  // Check virtual uses first.
+  for (VirtualUseCallbackTy &CB : VirtualUseCallbacks.lookup(&Fn))
+    if (!CB(*this, QueryingAA))
+      return false;
 
   SmallVector<const Use *, 8> Uses(make_pointer_range(Fn.uses()));
   for (unsigned u = 0; u < Uses.size(); ++u) {

diff --git a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
index 5ea3d2b5acb4b..89723a3d6c0df 100644
--- a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
+++ b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
@@ -2087,30 +2087,6 @@ struct OpenMPOpt {
           [&]() { return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, F)); });
   }
 
-  /// RAII struct to temporarily change an RTL function's linkage to external.
-  /// This prevents it from being mistakenly removed by other optimizations.
-  struct ExternalizationRAII {
-    ExternalizationRAII(OMPInformationCache &OMPInfoCache,
-                        RuntimeFunction RFKind)
-        : Declaration(OMPInfoCache.RFIs[RFKind].Declaration) {
-      if (!Declaration)
-        return;
-
-      LinkageType = Declaration->getLinkage();
-      Declaration->setLinkage(GlobalValue::ExternalLinkage);
-    }
-
-    ~ExternalizationRAII() {
-      if (!Declaration)
-        return;
-
-      Declaration->setLinkage(LinkageType);
-    }
-
-    Function *Declaration;
-    GlobalValue::LinkageTypes LinkageType;
-  };
-
   /// The underlying module.
   Module &M;
 
@@ -2135,21 +2111,6 @@ struct OpenMPOpt {
     if (SCC.empty())
       return false;
 
-    // Temporarily make these function have external linkage so the Attributor
-    // doesn't remove them when we try to look them up later.
-    ExternalizationRAII Parallel(OMPInfoCache, OMPRTL___kmpc_kernel_parallel);
-    ExternalizationRAII EndParallel(OMPInfoCache,
-                                    OMPRTL___kmpc_kernel_end_parallel);
-    ExternalizationRAII BarrierSPMD(OMPInfoCache,
-                                    OMPRTL___kmpc_barrier_simple_spmd);
-    ExternalizationRAII BarrierGeneric(OMPInfoCache,
-                                       OMPRTL___kmpc_barrier_simple_generic);
-    ExternalizationRAII ThreadId(OMPInfoCache,
-                                 OMPRTL___kmpc_get_hardware_thread_id_in_block);
-    ExternalizationRAII NumThreads(
-        OMPInfoCache, OMPRTL___kmpc_get_hardware_num_threads_in_block);
-    ExternalizationRAII WarpSize(OMPInfoCache, OMPRTL___kmpc_get_warp_size);
-
     registerAAs(IsModulePass);
 
     ChangeStatus Changed = A.run();
@@ -3296,27 +3257,6 @@ struct AAKernelInfoFunction : AAKernelInfo {
       return Val;
     };
 
-    Attributor::SimplifictionCallbackTy IsGenericModeSimplifyCB =
-        [&](const IRPosition &IRP, const AbstractAttribute *AA,
-            bool &UsedAssumedInformation) -> std::optional<Value *> {
-      // IRP represents the "RequiresFullRuntime" argument of an
-      // __kmpc_target_init or __kmpc_target_deinit call. We will answer this
-      // one with the internal state of the SPMDCompatibilityTracker, so if
-      // generic then true, if SPMD then false.
-      if (!SPMDCompatibilityTracker.isValidState())
-        return nullptr;
-      if (!SPMDCompatibilityTracker.isAtFixpoint()) {
-        if (AA)
-          A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
-        UsedAssumedInformation = true;
-      } else {
-        UsedAssumedInformation = false;
-      }
-      auto *Val = ConstantInt::getBool(IRP.getAnchorValue().getContext(),
-                                       !SPMDCompatibilityTracker.isAssumed());
-      return Val;
-    };
-
     constexpr const int InitModeArgNo = 1;
     constexpr const int DeinitModeArgNo = 1;
     constexpr const int InitUseStateMachineArgNo = 2;
@@ -3338,6 +3278,84 @@ struct AAKernelInfoFunction : AAKernelInfo {
     // This is a generic region but SPMDization is disabled so stop tracking.
     else if (DisableOpenMPOptSPMDization)
       SPMDCompatibilityTracker.indicatePessimisticFixpoint();
+
+    // Register virtual uses of functions we might need to preserve.
+    auto RegisterVirtualUse = [&](RuntimeFunction RFKind,
+                                  Attributor::VirtualUseCallbackTy &CB) {
+      if (!OMPInfoCache.RFIs[RFKind].Declaration)
+        return;
+      A.registerVirtualUseCallback(*OMPInfoCache.RFIs[RFKind].Declaration, CB);
+    };
+
+    // Add a dependence to ensure updates if the state changes.
+    auto AddDependence = [](Attributor &A, const AAKernelInfo *KI,
+                            const AbstractAttribute *QueryingAA) {
+      if (QueryingAA) {
+        A.recordDependence(*KI, *QueryingAA, DepClassTy::OPTIONAL);
+      }
+      return true;
+    };
+
+    Attributor::VirtualUseCallbackTy CustomStateMachineUseCB =
+        [&](Attributor &A, const AbstractAttribute *QueryingAA) {
+          // Whenever we create a custom state machine we will insert calls to
+          // __kmpc_get_hardware_num_threads_in_block,
+          // __kmpc_get_warp_size,
+          // __kmpc_barrier_simple_generic,
+          // __kmpc_kernel_parallel, and
+          // __kmpc_kernel_end_parallel.
+          // Not needed if we are on track for SPMDzation.
+          if (SPMDCompatibilityTracker.isValidState())
+            return AddDependence(A, this, QueryingAA);
+          // Not needed if we can't rewrite due to an invalid state.
+          if (!ReachedKnownParallelRegions.isValidState())
+            return AddDependence(A, this, QueryingAA);
+          return false;
+        };
+
+    // Not needed if we are pre-runtime merge.
+    if (!KernelInitCB->getCalledFunction()->isDeclaration()) {
+      RegisterVirtualUse(OMPRTL___kmpc_get_hardware_num_threads_in_block,
+                         CustomStateMachineUseCB);
+      RegisterVirtualUse(OMPRTL___kmpc_get_warp_size, CustomStateMachineUseCB);
+      RegisterVirtualUse(OMPRTL___kmpc_barrier_simple_generic,
+                         CustomStateMachineUseCB);
+      RegisterVirtualUse(OMPRTL___kmpc_kernel_parallel,
+                         CustomStateMachineUseCB);
+      RegisterVirtualUse(OMPRTL___kmpc_kernel_end_parallel,
+                         CustomStateMachineUseCB);
+    }
+
+    // If we do not perform SPMDzation we do not need the virtual uses below.
+    if (SPMDCompatibilityTracker.isAtFixpoint())
+      return;
+
+    Attributor::VirtualUseCallbackTy HWThreadIdUseCB =
+        [&](Attributor &A, const AbstractAttribute *QueryingAA) {
+          // Whenever we perform SPMDzation we will insert
+          // __kmpc_get_hardware_thread_id_in_block calls.
+          if (!SPMDCompatibilityTracker.isValidState())
+            return AddDependence(A, this, QueryingAA);
+          return false;
+        };
+    RegisterVirtualUse(OMPRTL___kmpc_get_hardware_thread_id_in_block,
+                       HWThreadIdUseCB);
+
+    Attributor::VirtualUseCallbackTy SPMDBarrierUseCB =
+        [&](Attributor &A, const AbstractAttribute *QueryingAA) {
+          // Whenever we perform SPMDzation with guarding we will insert
+          // __kmpc_barrier_simple_spmd calls. If SPMDzation failed, there is
+          // nothing to guard, or there are no parallel regions, we don't need
+          // the calls.
+          if (!SPMDCompatibilityTracker.isValidState())
+            return AddDependence(A, this, QueryingAA);
+          if (SPMDCompatibilityTracker.empty())
+            return AddDependence(A, this, QueryingAA);
+          if (!mayContainParallelRegion())
+            return AddDependence(A, this, QueryingAA);
+          return false;
+        };
+    RegisterVirtualUse(OMPRTL___kmpc_barrier_simple_spmd, SPMDBarrierUseCB);
   }
 
   /// Sanitize the string \p S such that it is a suitable global symbol name.

diff --git a/llvm/test/Transforms/OpenMP/get_hardware_num_threads_in_block_fold.ll b/llvm/test/Transforms/OpenMP/get_hardware_num_threads_in_block_fold.ll
index 5376a88d55473..82546309e7608 100644
--- a/llvm/test/Transforms/OpenMP/get_hardware_num_threads_in_block_fold.ll
+++ b/llvm/test/Transforms/OpenMP/get_hardware_num_threads_in_block_fold.ll
@@ -12,6 +12,9 @@ target triple = "nvptx64"
 ; CHECK: @[[G:[a-zA-Z0-9_$"\\.-]+]] = external global i32
 ; CHECK: @[[KERNEL1_EXEC_MODE:[a-zA-Z0-9_$"\\.-]+]] = weak constant i8 3
 ; CHECK: @[[KERNEL2_EXEC_MODE:[a-zA-Z0-9_$"\\.-]+]] = weak constant i8 3
+; CHECK: @[[KERNEL0_NESTED_PARALLELISM:[a-zA-Z0-9_$"\\.-]+]] = weak constant i8 0
+; CHECK: @[[KERNEL1_NESTED_PARALLELISM:[a-zA-Z0-9_$"\\.-]+]] = weak constant i8 0
+; CHECK: @[[KERNEL2_NESTED_PARALLELISM:[a-zA-Z0-9_$"\\.-]+]] = weak constant i8 0
 ; CHECK: @[[GLOB0:[0-9]+]] = private unnamed_addr constant [23 x i8] c"
 ; CHECK: @[[GLOB1:[0-9]+]] = private unnamed_addr constant [[STRUCT_IDENT_T:%.*]] { i32 0, i32 2, i32 0, i32 22, ptr @[[GLOB0]] }, align 8
 ;.
@@ -191,11 +194,6 @@ entry:
 }
 
 define internal i32 @__kmpc_get_hardware_num_threads_in_block() {
-; CHECK-LABEL: define {{[^@]+}}@__kmpc_get_hardware_num_threads_in_block
-; CHECK-SAME: () #[[ATTR1]] {
-; CHECK-NEXT:    [[RET:%.*]] = call i32 @__kmpc_get_hardware_num_threads_in_block_dummy()
-; CHECK-NEXT:    ret i32 [[RET]]
-;
   %ret = call i32 @__kmpc_get_hardware_num_threads_in_block_dummy()
   ret i32 %ret
 }


        

