[llvm] d2c37fc - [Attributor][FIX] Avoid dangling stack references in map

Johannes Doerfert via llvm-commits llvm-commits at lists.llvm.org
Thu Aug 24 16:28:29 PDT 2023


Author: Johannes Doerfert
Date: 2023-08-24T16:28:10-07:00
New Revision: d2c37fc4f72caca4b82018a73b3d109c0fa191b9

URL: https://github.com/llvm/llvm-project/commit/d2c37fc4f72caca4b82018a73b3d109c0fa191b9
DIFF: https://github.com/llvm/llvm-project/commit/d2c37fc4f72caca4b82018a73b3d109c0fa191b9.diff

LOG: [Attributor][FIX] Avoid dangling stack references in map

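The old code did not account for new queries being issued during an update,
which caused us to leave stack-allocated RQIs (reachability queries) in the
query cache. Instead of relying on the `InUpdate` member, callers now pass an
explicit flag that says whether an RQI is temporary or not.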

Fixes: https://github.com/llvm/llvm-project/issues/64959

Added: 
    openmp/libomptarget/test/offloading/bug64959_compile_only.c

Modified: 
    llvm/lib/Transforms/IPO/AttributorAttributes.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/IPO/AttributorAttributes.cpp b/llvm/lib/Transforms/IPO/AttributorAttributes.cpp
index 67167e4b13c025..66ecd49b1a6fdf 100644
--- a/llvm/lib/Transforms/IPO/AttributorAttributes.cpp
+++ b/llvm/lib/Transforms/IPO/AttributorAttributes.cpp
@@ -3541,24 +3541,24 @@ struct CachedReachabilityAA : public BaseTy {
   /// See AbstractAttribute::updateImpl(...).
   ChangeStatus updateImpl(Attributor &A) override {
     ChangeStatus Changed = ChangeStatus::UNCHANGED;
-    InUpdate = true;
     for (unsigned u = 0, e = QueryVector.size(); u < e; ++u) {
       RQITy *RQI = QueryVector[u];
-      if (RQI->Result == RQITy::Reachable::No && isReachableImpl(A, *RQI))
+      if (RQI->Result == RQITy::Reachable::No &&
+          isReachableImpl(A, *RQI, /*IsTemporaryRQI=*/false))
         Changed = ChangeStatus::CHANGED;
     }
-    InUpdate = false;
     return Changed;
   }
 
-  virtual bool isReachableImpl(Attributor &A, RQITy &RQI) = 0;
+  virtual bool isReachableImpl(Attributor &A, RQITy &RQI,
+                               bool IsTemporaryRQI) = 0;
 
   bool rememberResult(Attributor &A, typename RQITy::Reachable Result,
-                      RQITy &RQI, bool UsedExclusionSet) {
+                      RQITy &RQI, bool UsedExclusionSet, bool IsTemporaryRQI) {
     RQI.Result = Result;
 
     // Remove the temporary RQI from the cache.
-    if (!InUpdate)
+    if (IsTemporaryRQI)
       QueryCache.erase(&RQI);
 
     // Insert a plain RQI (w/o exclusion set) if that makes sense. Two options:
@@ -3576,7 +3576,7 @@ struct CachedReachabilityAA : public BaseTy {
     }
 
     // Check if we need to insert a new permanent RQI with the exclusion set.
-    if (!InUpdate && Result != RQITy::Reachable::Yes && UsedExclusionSet) {
+    if (IsTemporaryRQI && Result != RQITy::Reachable::Yes && UsedExclusionSet) {
       assert((!RQI.ExclusionSet || !RQI.ExclusionSet->empty()) &&
              "Did not expect empty set!");
       RQITy *RQIPtr = new (A.Allocator)
@@ -3588,7 +3588,7 @@ struct CachedReachabilityAA : public BaseTy {
       QueryCache.insert(RQIPtr);
     }
 
-    if (Result == RQITy::Reachable::No && !InUpdate)
+    if (Result == RQITy::Reachable::No && IsTemporaryRQI)
       A.registerForUpdate(*this);
     return Result == RQITy::Reachable::Yes;
   }
@@ -3629,7 +3629,6 @@ struct CachedReachabilityAA : public BaseTy {
   }
 
 private:
-  bool InUpdate = false;
   SmallVector<RQITy *> QueryVector;
   DenseSet<RQITy *> QueryCache;
 };
@@ -3653,7 +3652,8 @@ struct AAIntraFnReachabilityFunction final
     RQITy StackRQI(A, From, To, ExclusionSet, false);
     typename RQITy::Reachable Result;
     if (!NonConstThis->checkQueryCache(A, StackRQI, Result))
-      return NonConstThis->isReachableImpl(A, StackRQI);
+      return NonConstThis->isReachableImpl(A, StackRQI,
+                                           /*IsTemporaryRQI=*/true);
     return Result == RQITy::Reachable::Yes;
   }
 
@@ -3678,7 +3678,8 @@ struct AAIntraFnReachabilityFunction final
     return Base::updateImpl(A);
   }
 
-  bool isReachableImpl(Attributor &A, RQITy &RQI) override {
+  bool isReachableImpl(Attributor &A, RQITy &RQI,
+                       bool IsTemporaryRQI) override {
     const Instruction *Origin = RQI.From;
     bool UsedExclusionSet = false;
 
@@ -3704,12 +3705,14 @@ struct AAIntraFnReachabilityFunction final
     // possible.
     if (FromBB == ToBB &&
         WillReachInBlock(*RQI.From, *RQI.To, RQI.ExclusionSet))
-      return rememberResult(A, RQITy::Reachable::Yes, RQI, UsedExclusionSet);
+      return rememberResult(A, RQITy::Reachable::Yes, RQI, UsedExclusionSet,
+                            IsTemporaryRQI);
 
     // Check if reaching the ToBB block is sufficient or if even that would not
     // ensure reaching the target. In the latter case we are done.
     if (!WillReachInBlock(ToBB->front(), *RQI.To, RQI.ExclusionSet))
-      return rememberResult(A, RQITy::Reachable::No, RQI, UsedExclusionSet);
+      return rememberResult(A, RQITy::Reachable::No, RQI, UsedExclusionSet,
+                            IsTemporaryRQI);
 
     const Function *Fn = FromBB->getParent();
     SmallPtrSet<const BasicBlock *, 16> ExclusionBlocks;
@@ -3722,13 +3725,14 @@ struct AAIntraFnReachabilityFunction final
     if (ExclusionBlocks.count(FromBB) &&
         !WillReachInBlock(*RQI.From, *FromBB->getTerminator(),
                           RQI.ExclusionSet))
-      return rememberResult(A, RQITy::Reachable::No, RQI, true);
+      return rememberResult(A, RQITy::Reachable::No, RQI, true, IsTemporaryRQI);
 
     auto *LivenessAA =
         A.getAAFor<AAIsDead>(*this, getIRPosition(), DepClassTy::OPTIONAL);
     if (LivenessAA && LivenessAA->isAssumedDead(ToBB)) {
       DeadBlocks.insert(ToBB);
-      return rememberResult(A, RQITy::Reachable::No, RQI, UsedExclusionSet);
+      return rememberResult(A, RQITy::Reachable::No, RQI, UsedExclusionSet,
+                            IsTemporaryRQI);
     }
 
     SmallPtrSet<const BasicBlock *, 16> Visited;
@@ -3747,11 +3751,11 @@ struct AAIntraFnReachabilityFunction final
         }
         // We checked before if we just need to reach the ToBB block.
         if (SuccBB == ToBB)
-          return rememberResult(A, RQITy::Reachable::Yes, RQI,
-                                UsedExclusionSet);
+          return rememberResult(A, RQITy::Reachable::Yes, RQI, UsedExclusionSet,
+                                IsTemporaryRQI);
         if (DT && ExclusionBlocks.empty() && DT->dominates(BB, ToBB))
-          return rememberResult(A, RQITy::Reachable::Yes, RQI,
-                                UsedExclusionSet);
+          return rememberResult(A, RQITy::Reachable::Yes, RQI, UsedExclusionSet,
+                                IsTemporaryRQI);
 
         if (ExclusionBlocks.count(SuccBB)) {
           UsedExclusionSet = true;
@@ -3762,7 +3766,8 @@ struct AAIntraFnReachabilityFunction final
     }
 
     DeadEdges.insert(LocalDeadEdges.begin(), LocalDeadEdges.end());
-    return rememberResult(A, RQITy::Reachable::No, RQI, UsedExclusionSet);
+    return rememberResult(A, RQITy::Reachable::No, RQI, UsedExclusionSet,
+                          IsTemporaryRQI);
   }
 
   /// See AbstractAttribute::trackStatistics()
@@ -10646,22 +10651,19 @@ struct AAInterFnReachabilityFunction
     RQITy StackRQI(A, From, To, ExclusionSet, false);
     typename RQITy::Reachable Result;
     if (!NonConstThis->checkQueryCache(A, StackRQI, Result))
-      return NonConstThis->isReachableImpl(A, StackRQI);
+      return NonConstThis->isReachableImpl(A, StackRQI,
+                                           /*IsTemporaryRQI=*/true);
     return Result == RQITy::Reachable::Yes;
   }
 
-  bool isReachableImpl(Attributor &A, RQITy &RQI) override {
-    return isReachableImpl(A, RQI, nullptr);
-  }
-
   bool isReachableImpl(Attributor &A, RQITy &RQI,
-                       SmallPtrSet<const Function *, 16> *Visited) {
-
+                       bool IsTemporaryRQI) override {
     const Instruction *EntryI =
         &RQI.From->getFunction()->getEntryBlock().front();
     if (EntryI != RQI.From &&
         !instructionCanReach(A, *EntryI, *RQI.To, nullptr))
-      return rememberResult(A, RQITy::Reachable::No, RQI, false);
+      return rememberResult(A, RQITy::Reachable::No, RQI, false,
+                            IsTemporaryRQI);
 
     auto CheckReachableCallBase = [&](CallBase *CB) {
       auto *CBEdges = A.getAAFor<AACallEdges>(
@@ -10721,9 +10723,11 @@ struct AAInterFnReachabilityFunction
     if (!A.checkForAllCallLikeInstructions(CheckCallBase, *this,
                                            UsedAssumedInformation,
                                            /* CheckBBLivenessOnly */ true))
-      return rememberResult(A, RQITy::Reachable::Yes, RQI, UsedExclusionSet);
+      return rememberResult(A, RQITy::Reachable::Yes, RQI, UsedExclusionSet,
+                            IsTemporaryRQI);
 
-    return rememberResult(A, RQITy::Reachable::No, RQI, UsedExclusionSet);
+    return rememberResult(A, RQITy::Reachable::No, RQI, UsedExclusionSet,
+                          IsTemporaryRQI);
   }
 
   void trackStatistics() const override {}

diff  --git a/openmp/libomptarget/test/offloading/bug64959_compile_only.c b/openmp/libomptarget/test/offloading/bug64959_compile_only.c
new file mode 100644
index 00000000000000..310942f25c0078
--- /dev/null
+++ b/openmp/libomptarget/test/offloading/bug64959_compile_only.c
@@ -0,0 +1,40 @@
+// RUN: %libomptarget-compile-generic
+// RUN: %libomptarget-compileopt-generic
+
+#include <stdio.h>
+#define N 10
+
+int main(void) {
+  long int aa = 0;
+  int res = 0;
+
+  int ng = 12;
+  int cmom = 14;
+  int nxyz = 5000;
+
+#pragma omp target teams distribute num_teams(nxyz)                            \
+    thread_limit(ng *(cmom - 1)) map(tofrom : aa)
+  for (int gid = 0; gid < nxyz; gid++) {
+#pragma omp parallel for collapse(2)
+    for (unsigned int g = 0; g < ng; g++) {
+      for (unsigned int l = 0; l < cmom - 1; l++) {
+        int a = 0;
+#pragma omp parallel for reduction(+ : a)
+        for (int i = 0; i < N; i++) {
+          a += i;
+        }
+#pragma omp atomic
+        aa += a;
+      }
+    }
+  }
+  long exp = (long)ng * (cmom - 1) * nxyz * (N * (N - 1) / 2);
+  printf("The result is = %ld exp:%ld!\n", aa, exp);
+  if (aa != exp) {
+    printf("Failed %ld\n", aa);
+    return 1;
+  }
+  // CHECK: Success
+  printf("Success\n");
+  return 0;
+}


        


More information about the llvm-commits mailing list