[Openmp-commits] [openmp] e97e0a4 - [AbstractAttributor] Fold __kmpc_parallel_level if possible

Shilei Tian via Openmp-commits openmp-commits at lists.llvm.org
Mon Jul 26 19:46:26 PDT 2021


Author: Shilei Tian
Date: 2021-07-26T22:46:19-04:00
New Revision: e97e0a4fad091474131ad84d6f46009bf84c5b60

URL: https://github.com/llvm/llvm-project/commit/e97e0a4fad091474131ad84d6f46009bf84c5b60
DIFF: https://github.com/llvm/llvm-project/commit/e97e0a4fad091474131ad84d6f46009bf84c5b60.diff

LOG: [AbstractAttributor] Fold __kmpc_parallel_level if possible

Similar to D105787, this patch tries to fold `__kmpc_parallel_level` if possible.
Note that `__kmpc_parallel_level` doesn't take activeness into consideration:
based on the current `deviceRTLs`, it returns values such as 0, 1, 2, rather
than values such as 0, 129, 130 that would also encode activeness.

Reviewed By: jdoerfert

Differential Revision: https://reviews.llvm.org/D106154

Added: 
    llvm/test/Transforms/OpenMP/parallel_level_fold.ll

Modified: 
    llvm/lib/Transforms/IPO/OpenMPOpt.cpp
    openmp/libomptarget/deviceRTLs/common/src/parallel.cu
    openmp/libomptarget/deviceRTLs/interface.h

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
index 5d4f4f47fb9a5..9150d951daa6b 100644
--- a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
+++ b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
@@ -519,6 +519,11 @@ struct KernelInfoState : AbstractState {
   /// State to track what kernel entries can reach the associated function.
   BooleanStateWithPtrSetVector<Function, false> ReachingKernelEntries;
 
+  /// State to indicate if we can track parallel level of the associated
+  /// function. We will give up tracking if we encounter unknown caller or the
+  /// caller is __kmpc_parallel_51.
+  BooleanStateWithSetVector<uint8_t> ParallelLevels;
+
   /// Abstract State interface
   ///{
 
@@ -3329,8 +3334,10 @@ struct AAKernelInfoFunction : AAKernelInfo {
               CheckRWInst, *this, UsedAssumedInformationInCheckRWInst))
         SPMDCompatibilityTracker.indicatePessimisticFixpoint();
 
-    if (!IsKernelEntry)
+    if (!IsKernelEntry) {
       updateReachingKernelEntries(A);
+      updateParallelLevels(A);
+    }
 
     // Callback to check a call instruction.
     bool AllSPMDStatesWereFixed = true;
@@ -3386,6 +3393,49 @@ struct AAKernelInfoFunction : AAKernelInfo {
                                 AllCallSitesKnown))
       ReachingKernelEntries.indicatePessimisticFixpoint();
   }
+
+  /// Update info regarding parallel levels.
+  void updateParallelLevels(Attributor &A) {
+    auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
+    OMPInformationCache::RuntimeFunctionInfo &Parallel51RFI =
+        OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
+
+    auto PredCallSite = [&](AbstractCallSite ACS) {
+      Function *Caller = ACS.getInstruction()->getFunction();
+
+      assert(Caller && "Caller is nullptr");
+
+      auto &CAA =
+          A.getOrCreateAAFor<AAKernelInfo>(IRPosition::function(*Caller));
+      if (CAA.ParallelLevels.isValidState()) {
+        // Any function that is called by `__kmpc_parallel_51` will not be
+        // folded as the parallel level in the function is updated. In order to
+        // get it right, all the analysis would depend on the implementation.
+        // That said, if the implementation changes in the future, the analysis
+        // could be wrong. As a consequence, we are just conservative here.
+        if (Caller == Parallel51RFI.Declaration) {
+          ParallelLevels.indicatePessimisticFixpoint();
+          return true;
+        }
+
+        ParallelLevels ^= CAA.ParallelLevels;
+
+        return true;
+      }
+
+      // We lost track of the caller of the associated function, any kernel
+      // could reach now.
+      ParallelLevels.indicatePessimisticFixpoint();
+
+      return true;
+    };
+
+    bool AllCallSitesKnown = true;
+    if (!A.checkForAllCallSites(PredCallSite, *this,
+                                true /* RequireAllCallSites */,
+                                AllCallSitesKnown))
+      ParallelLevels.indicatePessimisticFixpoint();
+  }
 };
 
 /// The call site kernel info abstract attribute, basically, what can we say
@@ -3668,6 +3718,9 @@ struct AAFoldRuntimeCallCallSiteReturned : AAFoldRuntimeCall {
     case OMPRTL___kmpc_is_generic_main_thread_id:
       Changed |= foldIsGenericMainThread(A);
       break;
+    case OMPRTL___kmpc_parallel_level:
+      Changed |= foldParallelLevel(A);
+      break;
     default:
       llvm_unreachable("Unhandled OpenMP runtime function!");
     }
@@ -3782,6 +3835,68 @@ struct AAFoldRuntimeCallCallSiteReturned : AAFoldRuntimeCall {
                                                     : ChangeStatus::CHANGED;
   }
 
+  /// Fold __kmpc_parallel_level into a constant if possible.
+  ChangeStatus foldParallelLevel(Attributor &A) {
+    Optional<Value *> SimplifiedValueBefore = SimplifiedValue;
+
+    auto &CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
+        *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
+
+    if (!CallerKernelInfoAA.ParallelLevels.isValidState())
+      return indicatePessimisticFixpoint();
+
+    if (!CallerKernelInfoAA.ReachingKernelEntries.isValidState())
+      return indicatePessimisticFixpoint();
+
+    if (CallerKernelInfoAA.ReachingKernelEntries.empty()) {
+      assert(!SimplifiedValue.hasValue() &&
+             "SimplifiedValue should keep none at this point");
+      return ChangeStatus::UNCHANGED;
+    }
+
+    unsigned AssumedSPMDCount = 0, KnownSPMDCount = 0;
+    unsigned AssumedNonSPMDCount = 0, KnownNonSPMDCount = 0;
+    for (Kernel K : CallerKernelInfoAA.ReachingKernelEntries) {
+      auto &AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K),
+                                          DepClassTy::REQUIRED);
+      if (!AA.SPMDCompatibilityTracker.isValidState())
+        return indicatePessimisticFixpoint();
+
+      if (AA.SPMDCompatibilityTracker.isAssumed()) {
+        if (AA.SPMDCompatibilityTracker.isAtFixpoint())
+          ++KnownSPMDCount;
+        else
+          ++AssumedSPMDCount;
+      } else {
+        if (AA.SPMDCompatibilityTracker.isAtFixpoint())
+          ++KnownNonSPMDCount;
+        else
+          ++AssumedNonSPMDCount;
+      }
+    }
+
+    if ((AssumedSPMDCount + KnownSPMDCount) &&
+        (AssumedNonSPMDCount + KnownNonSPMDCount))
+      return indicatePessimisticFixpoint();
+
+    auto &Ctx = getAnchorValue().getContext();
+    // If the caller can only be reached by SPMD kernel entries, the parallel
+    // level is 1. Similarly, if the caller can only be reached by non-SPMD
+    // kernel entries, it is 0.
+    if (AssumedSPMDCount || KnownSPMDCount) {
+      assert(KnownNonSPMDCount == 0 && AssumedNonSPMDCount == 0 &&
+             "Expected only SPMD kernels!");
+      SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), 1);
+    } else {
+      assert(KnownSPMDCount == 0 && AssumedSPMDCount == 0 &&
+             "Expected only non-SPMD kernels!");
+      SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), 0);
+    }
+
+    return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
+                                                    : ChangeStatus::CHANGED;
+  }
+
   /// An optional value the associated value is assumed to fold to. That is, we
   /// assume the associated value (which is a call) can be replaced by this
   /// simplified value.
@@ -3832,6 +3947,19 @@ void OpenMPOpt::registerAAs(bool IsModulePass) {
           /* UpdateAfterInit */ false);
       return false;
     });
+
+    auto &ParallelLevelRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_level];
+    ParallelLevelRFI.foreachUse(SCC, [&](Use &U, Function &) {
+      CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &ParallelLevelRFI);
+      if (!CI)
+        return false;
+      A.getOrCreateAAFor<AAFoldRuntimeCall>(
+          IRPosition::callsite_returned(*CI), /* QueryingAA */ nullptr,
+          DepClassTy::NONE, /* ForceUpdate */ false,
+          /* UpdateAfterInit */ false);
+
+      return false;
+    });
   }
 
   // Create CallSite AA for all Getters.

diff  --git a/llvm/test/Transforms/OpenMP/parallel_level_fold.ll b/llvm/test/Transforms/OpenMP/parallel_level_fold.ll
new file mode 100644
index 0000000000000..29afba97ab56e
--- /dev/null
+++ b/llvm/test/Transforms/OpenMP/parallel_level_fold.ll
@@ -0,0 +1,150 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --function-signature --check-globals
+; RUN: opt -S -passes=openmp-opt < %s | FileCheck %s
+target triple = "nvptx64"
+
+%struct.ident_t = type { i32, i32, i32, i32, i8* }
+
+@no_spmd_exec_mode = weak constant i8 1
+@spmd_exec_mode = weak constant i8 0
+@parallel_exec_mode = weak constant i8 0
+@G = external global i8
+@llvm.compiler.used = appending global [3 x i8*] [i8* @no_spmd_exec_mode, i8* @spmd_exec_mode, i8* @parallel_exec_mode], section "llvm.metadata"
+
+;.
+; CHECK: @[[NO_SPMD_EXEC_MODE:[a-zA-Z0-9_$"\\.-]+]] = weak constant i8 1
+; CHECK: @[[SPMD_EXEC_MODE:[a-zA-Z0-9_$"\\.-]+]] = weak constant i8 0
+; CHECK: @[[PARALLEL_EXEC_MODE:[a-zA-Z0-9_$"\\.-]+]] = weak constant i8 0
+; CHECK: @[[G:[a-zA-Z0-9_$"\\.-]+]] = external global i8
+; CHECK: @[[LLVM_COMPILER_USED:[a-zA-Z0-9_$"\\.-]+]] = appending global [3 x i8*] [i8* @no_spmd_exec_mode, i8* @spmd_exec_mode, i8* @parallel_exec_mode], section "llvm.metadata"
+;.
+define weak void @none_spmd() {
+; CHECK-LABEL: define {{[^@]+}}@none_spmd() {
+; CHECK-NEXT:    [[I:%.*]] = call i32 @__kmpc_target_init(%struct.ident_t* null, i1 false, i1 false, i1 false)
+; CHECK-NEXT:    call void @none_spmd_helper()
+; CHECK-NEXT:    call void @mixed_helper()
+; CHECK-NEXT:    call void @__kmpc_target_deinit(%struct.ident_t* null, i1 false, i1 false)
+; CHECK-NEXT:    ret void
+;
+  %i = call i32 @__kmpc_target_init(%struct.ident_t* null, i1 false, i1 false, i1 false)
+  call void @none_spmd_helper()
+  call void @mixed_helper()
+  call void @__kmpc_target_deinit(%struct.ident_t* null, i1 false, i1 false)
+  ret void
+}
+
+define weak void @spmd() {
+; CHECK-LABEL: define {{[^@]+}}@spmd() {
+; CHECK-NEXT:    [[I:%.*]] = call i32 @__kmpc_target_init(%struct.ident_t* null, i1 true, i1 false, i1 false)
+; CHECK-NEXT:    call void @spmd_helper()
+; CHECK-NEXT:    call void @mixed_helper()
+; CHECK-NEXT:    call void @__kmpc_target_deinit(%struct.ident_t* null, i1 true, i1 false)
+; CHECK-NEXT:    ret void
+;
+  %i = call i32 @__kmpc_target_init(%struct.ident_t* null, i1 true, i1 false, i1 false)
+  call void @spmd_helper()
+  call void @mixed_helper()
+  call void @__kmpc_target_deinit(%struct.ident_t* null, i1 true, i1 false)
+  ret void
+}
+
+define weak void @parallel() {
+; CHECK-LABEL: define {{[^@]+}}@parallel() {
+; CHECK-NEXT:    [[I:%.*]] = call i32 @__kmpc_target_init(%struct.ident_t* align 536870912 null, i1 true, i1 false, i1 false)
+; CHECK-NEXT:    call void @spmd_helper()
+; CHECK-NEXT:    call void @__kmpc_parallel_51(%struct.ident_t* noalias noundef align 536870912 null, i32 noundef 0, i32 noundef 0, i32 noundef 0, i32 noundef 0, i8* noalias noundef align 536870912 null, i8* noalias noundef align 536870912 null, i8** noalias noundef align 536870912 null, i64 noundef 0)
+; CHECK-NEXT:    call void @__kmpc_target_deinit(%struct.ident_t* null, i1 true, i1 false)
+; CHECK-NEXT:    ret void
+;
+  %i = call i32 @__kmpc_target_init(%struct.ident_t* null, i1 true, i1 false, i1 false)
+  call void @spmd_helper()
+  call void @__kmpc_parallel_51(%struct.ident_t* null, i32 0, i32 0, i32 0, i32 0, i8* null, i8* null, i8** null, i64 0)
+  call void @__kmpc_target_deinit(%struct.ident_t* null, i1 true, i1 false)
+  ret void
+}
+
+define internal void @mixed_helper() {
+; CHECK-LABEL: define {{[^@]+}}@mixed_helper() {
+; CHECK-NEXT:    [[LEVEL:%.*]] = call i8 @__kmpc_parallel_level()
+; CHECK-NEXT:    store i8 [[LEVEL]], i8* @G, align 1
+; CHECK-NEXT:    ret void
+;
+  %level = call i8 @__kmpc_parallel_level()
+  store i8 %level, i8* @G
+  ret void
+}
+
+define internal void @none_spmd_helper() {
+; CHECK-LABEL: define {{[^@]+}}@none_spmd_helper() {
+; CHECK-NEXT:    [[LEVEL12:%.*]] = call i8 @__kmpc_parallel_level()
+; CHECK-NEXT:    [[C:%.*]] = icmp eq i8 [[LEVEL12]], 0
+; CHECK-NEXT:    br i1 [[C]], label [[T:%.*]], label [[F:%.*]]
+; CHECK:       t:
+; CHECK-NEXT:    call void @foo()
+; CHECK-NEXT:    ret void
+; CHECK:       f:
+; CHECK-NEXT:    call void @bar()
+; CHECK-NEXT:    ret void
+;
+  %level12 = call i8 @__kmpc_parallel_level()
+  %c = icmp eq i8 %level12, 0
+  br i1 %c, label %t, label %f
+t:
+  call void @foo()
+  ret void
+f:
+  call void @bar()
+  ret void
+}
+
+define internal void @spmd_helper() {
+; CHECK-LABEL: define {{[^@]+}}@spmd_helper() {
+; CHECK-NEXT:    store i8 1, i8* @G, align 1
+; CHECK-NEXT:    ret void
+;
+  %level = call i8 @__kmpc_parallel_level()
+  store i8 %level, i8* @G
+  ret void
+}
+
+define internal void @__kmpc_parallel_51(%struct.ident_t*, i32, i32, i32, i32, i8*, i8*, i8**, i64) {
+; CHECK-LABEL: define {{[^@]+}}@__kmpc_parallel_51
+; CHECK-SAME: (%struct.ident_t* noalias nocapture nofree readnone align 536870912 [[TMP0:%.*]], i32 [[TMP1:%.*]], i32 [[TMP2:%.*]], i32 [[TMP3:%.*]], i32 [[TMP4:%.*]], i8* noalias nocapture nofree readnone align 536870912 [[TMP5:%.*]], i8* noalias nocapture nofree readnone align 536870912 [[TMP6:%.*]], i8** noalias nocapture nofree readnone align 536870912 [[TMP7:%.*]], i64 [[TMP8:%.*]]) {
+; CHECK-NEXT:    call void @parallel_helper()
+; CHECK-NEXT:    ret void
+;
+  call void @parallel_helper()
+  ret void
+}
+
+define internal void @parallel_helper() {
+; CHECK-LABEL: define {{[^@]+}}@parallel_helper() {
+; CHECK-NEXT:    [[LEVEL:%.*]] = call i8 @__kmpc_parallel_level()
+; CHECK-NEXT:    store i8 [[LEVEL]], i8* @G, align 1
+; CHECK-NEXT:    ret void
+;
+  %level = call i8 @__kmpc_parallel_level()
+  store i8 %level, i8* @G
+  ret void
+}
+
+declare void @foo()
+declare void @bar()
+declare i8 @__kmpc_parallel_level()
+declare i32 @__kmpc_target_init(%struct.ident_t*, i1 zeroext, i1 zeroext, i1 zeroext) #1
+declare void @__kmpc_target_deinit(%struct.ident_t* nocapture readnone, i1 zeroext, i1 zeroext) #1
+
+!llvm.module.flags = !{!0, !1}
+!nvvm.annotations = !{!2, !3, !4}
+
+!0 = !{i32 7, !"openmp", i32 50}
+!1 = !{i32 7, !"openmp-device", i32 50}
+!2 = !{void ()* @none_spmd, !"kernel", i32 1}
+!3 = !{void ()* @spmd, !"kernel", i32 1}
+!4 = !{void ()* @parallel, !"kernel", i32 1}
+;.
+; CHECK: [[META0:![0-9]+]] = !{i32 7, !"openmp", i32 50}
+; CHECK: [[META1:![0-9]+]] = !{i32 7, !"openmp-device", i32 50}
+; CHECK: [[META2:![0-9]+]] = !{void ()* @none_spmd, !"kernel", i32 1}
+; CHECK: [[META3:![0-9]+]] = !{void ()* @spmd, !"kernel", i32 1}
+; CHECK: [[META4:![0-9]+]] = !{void ()* @parallel, !"kernel", i32 1}
+;.

diff  --git a/openmp/libomptarget/deviceRTLs/common/src/parallel.cu b/openmp/libomptarget/deviceRTLs/common/src/parallel.cu
index 2656f3e48ce3f..b12d5ccb5a17e 100644
--- a/openmp/libomptarget/deviceRTLs/common/src/parallel.cu
+++ b/openmp/libomptarget/deviceRTLs/common/src/parallel.cu
@@ -239,7 +239,7 @@ EXTERN void __kmpc_end_serialized_parallel(kmp_Ident *loc,
   currTaskDescr->RestoreLoopData();
 }
 
-EXTERN uint8_t __kmpc_parallel_level() {
+NOINLINE EXTERN uint8_t __kmpc_parallel_level() {
   return parallelLevel[GetWarpId()] & (OMP_ACTIVE_PARALLEL_LEVEL - 1);
 }
 
@@ -282,11 +282,11 @@ EXTERN void __kmpc_push_proc_bind(kmp_Ident *loc, uint32_t tid, int proc_bind) {
 // parallel interface
 ////////////////////////////////////////////////////////////////////////////////
 
-EXTERN void __kmpc_parallel_51(kmp_Ident *ident, kmp_int32 global_tid,
-                               kmp_int32 if_expr, kmp_int32 num_threads,
-                               int proc_bind, void *fn, void *wrapper_fn,
-                               void **args, size_t nargs) {
-
+NOINLINE EXTERN void __kmpc_parallel_51(kmp_Ident *ident, kmp_int32 global_tid,
+                                        kmp_int32 if_expr,
+                                        kmp_int32 num_threads, int proc_bind,
+                                        void *fn, void *wrapper_fn, void **args,
+                                        size_t nargs) {
   // Handle the serialized case first, same for SPMD/non-SPMD except that in
   // SPMD mode we already incremented the parallel level counter, account for
   // that.

diff  --git a/openmp/libomptarget/deviceRTLs/interface.h b/openmp/libomptarget/deviceRTLs/interface.h
index ade1bfe2222eb..2e80dc3a82ac9 100644
--- a/openmp/libomptarget/deviceRTLs/interface.h
+++ b/openmp/libomptarget/deviceRTLs/interface.h
@@ -222,7 +222,7 @@ EXTERN void __kmpc_push_num_threads(kmp_Ident *loc, int32_t global_tid,
                                     int32_t num_threads);
 EXTERN void __kmpc_serialized_parallel(kmp_Ident *loc, uint32_t global_tid);
 EXTERN void __kmpc_end_serialized_parallel(kmp_Ident *loc, uint32_t global_tid);
-EXTERN uint8_t __kmpc_parallel_level();
+NOINLINE EXTERN uint8_t __kmpc_parallel_level();
 
 // proc bind
 EXTERN void __kmpc_push_proc_bind(kmp_Ident *loc, uint32_t global_tid,
@@ -441,10 +441,11 @@ EXTERN void __kmpc_get_shared_variables(void ***GlobalArgs);
 /// \param wrapper_fn  The worker wrapper function of fn.
 /// \param args        The pointer array of arguments to fn.
 /// \param nargs       The number of arguments to fn.
-EXTERN void __kmpc_parallel_51(ident_t *ident, kmp_int32 global_tid,
-                               kmp_int32 if_expr, kmp_int32 num_threads,
-                               int proc_bind, void *fn, void *wrapper_fn,
-                               void **args, size_t nargs);
+NOINLINE EXTERN void __kmpc_parallel_51(ident_t *ident, kmp_int32 global_tid,
+                                        kmp_int32 if_expr,
+                                        kmp_int32 num_threads, int proc_bind,
+                                        void *fn, void *wrapper_fn, void **args,
+                                        size_t nargs);
 
 // SPMD execution mode interrogation function.
 EXTERN int8_t __kmpc_is_spmd_exec_mode();


        


More information about the Openmp-commits mailing list