[llvm] 5890b30 - [LAA] Initial support for runtime checks with pointer selects.

Florian Hahn via llvm-commits llvm-commits at lists.llvm.org
Thu May 12 11:34:09 PDT 2022


Author: Florian Hahn
Date: 2022-05-12T19:33:48+01:00
New Revision: 5890b30105999a137e72e42f3760bebfd77001ca

URL: https://github.com/llvm/llvm-project/commit/5890b30105999a137e72e42f3760bebfd77001ca
DIFF: https://github.com/llvm/llvm-project/commit/5890b30105999a137e72e42f3760bebfd77001ca.diff

LOG: [LAA] Initial support for runtime checks with pointer selects.

Add scaffolding for generating runtime checks for multiple SCEV expressions
per pointer. This initial version only adds support for looking through
a single pointer select.

The more sophisticated logic for analyzing forks is in review D108699.

Reviewed By: huntergr

Differential Revision: https://reviews.llvm.org/D114487

Added: 
    

Modified: 
    llvm/include/llvm/Analysis/LoopAccessAnalysis.h
    llvm/lib/Analysis/LoopAccessAnalysis.cpp
    llvm/test/Analysis/LoopAccessAnalysis/forked-pointers.ll

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Analysis/LoopAccessAnalysis.h b/llvm/include/llvm/Analysis/LoopAccessAnalysis.h
index 9a9a984d2c41d..70d40bc2a8081 100644
--- a/llvm/include/llvm/Analysis/LoopAccessAnalysis.h
+++ b/llvm/include/llvm/Analysis/LoopAccessAnalysis.h
@@ -406,8 +406,8 @@ class RuntimePointerChecking {
   /// according to the assumptions that we've made during the analysis.
   /// The method might also version the pointer stride according to \p Strides,
   /// and add new predicates to \p PSE.
-  void insert(Loop *Lp, Value *Ptr, Type *AccessTy, bool WritePtr,
-              unsigned DepSetId, unsigned ASId, const ValueToValueMap &Strides,
+  void insert(Loop *Lp, Value *Ptr, const SCEV *PtrExpr, Type *AccessTy,
+              bool WritePtr, unsigned DepSetId, unsigned ASId,
               PredicatedScalarEvolution &PSE);
 
   /// No run-time memory checking is necessary.

diff  --git a/llvm/lib/Analysis/LoopAccessAnalysis.cpp b/llvm/lib/Analysis/LoopAccessAnalysis.cpp
index d0276df533ffc..aff272c05a937 100644
--- a/llvm/lib/Analysis/LoopAccessAnalysis.cpp
+++ b/llvm/lib/Analysis/LoopAccessAnalysis.cpp
@@ -47,6 +47,7 @@
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/Operator.h"
 #include "llvm/IR/PassManager.h"
+#include "llvm/IR/PatternMatch.h"
 #include "llvm/IR/Type.h"
 #include "llvm/IR/Value.h"
 #include "llvm/IR/ValueHandle.h"
@@ -65,6 +66,7 @@
 #include <vector>
 
 using namespace llvm;
+using namespace llvm::PatternMatch;
 
 #define DEBUG_TYPE "loop-accesses"
 
@@ -188,22 +190,19 @@ RuntimeCheckingPtrGroup::RuntimeCheckingPtrGroup(
 ///
 /// There is no conflict when the intervals are disjoint:
 /// NoConflict = (P2.Start >= P1.End) || (P1.Start >= P2.End)
-void RuntimePointerChecking::insert(Loop *Lp, Value *Ptr, Type *AccessTy,
-                                    bool WritePtr, unsigned DepSetId,
-                                    unsigned ASId,
-                                    const ValueToValueMap &Strides,
+void RuntimePointerChecking::insert(Loop *Lp, Value *Ptr, const SCEV *PtrExpr,
+                                    Type *AccessTy, bool WritePtr,
+                                    unsigned DepSetId, unsigned ASId,
                                     PredicatedScalarEvolution &PSE) {
-  // Get the stride replaced scev.
-  const SCEV *Sc = replaceSymbolicStrideSCEV(PSE, Strides, Ptr);
   ScalarEvolution *SE = PSE.getSE();
 
   const SCEV *ScStart;
   const SCEV *ScEnd;
 
-  if (SE->isLoopInvariant(Sc, Lp)) {
-    ScStart = ScEnd = Sc;
+  if (SE->isLoopInvariant(PtrExpr, Lp)) {
+    ScStart = ScEnd = PtrExpr;
   } else {
-    const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Sc);
+    const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(PtrExpr);
     assert(AR && "Invalid addrec expression");
     const SCEV *Ex = PSE.getBackedgeTakenCount();
 
@@ -230,7 +229,7 @@ void RuntimePointerChecking::insert(Loop *Lp, Value *Ptr, Type *AccessTy,
   const SCEV *EltSizeSCEV = SE->getStoreSizeOfExpr(IdxTy, AccessTy);
   ScEnd = SE->getAddExpr(ScEnd, EltSizeSCEV);
 
-  Pointers.emplace_back(Ptr, ScStart, ScEnd, WritePtr, DepSetId, ASId, Sc);
+  Pointers.emplace_back(Ptr, ScStart, ScEnd, WritePtr, DepSetId, ASId, PtrExpr);
 }
 
 SmallVector<RuntimePointerCheck, 4>
@@ -370,9 +369,11 @@ void RuntimePointerChecking::groupChecks(
 
   unsigned TotalComparisons = 0;
 
-  DenseMap<Value *, unsigned> PositionMap;
-  for (unsigned Index = 0; Index < Pointers.size(); ++Index)
-    PositionMap[Pointers[Index].PointerValue] = Index;
+  DenseMap<Value *, SmallVector<unsigned>> PositionMap;
+  for (unsigned Index = 0; Index < Pointers.size(); ++Index) {
+    auto Iter = PositionMap.insert({Pointers[Index].PointerValue, {}});
+    Iter.first->second.push_back(Index);
+  }
 
   // We need to keep track of what pointers we've already seen so we
   // don't process them twice.
@@ -403,34 +404,35 @@ void RuntimePointerChecking::groupChecks(
       auto PointerI = PositionMap.find(MI->getPointer());
       assert(PointerI != PositionMap.end() &&
              "pointer in equivalence class not found in PositionMap");
-      unsigned Pointer = PointerI->second;
-      bool Merged = false;
-      // Mark this pointer as seen.
-      Seen.insert(Pointer);
-
-      // Go through all the existing sets and see if we can find one
-      // which can include this pointer.
-      for (RuntimeCheckingPtrGroup &Group : Groups) {
-        // Don't perform more than a certain amount of comparisons.
-        // This should limit the cost of grouping the pointers to something
-        // reasonable.  If we do end up hitting this threshold, the algorithm
-        // will create separate groups for all remaining pointers.
-        if (TotalComparisons > MemoryCheckMergeThreshold)
-          break;
-
-        TotalComparisons++;
-
-        if (Group.addPointer(Pointer, *this)) {
-          Merged = true;
-          break;
+      for (unsigned Pointer : PointerI->second) {
+        bool Merged = false;
+        // Mark this pointer as seen.
+        Seen.insert(Pointer);
+
+        // Go through all the existing sets and see if we can find one
+        // which can include this pointer.
+        for (RuntimeCheckingPtrGroup &Group : Groups) {
+          // Don't perform more than a certain amount of comparisons.
+          // This should limit the cost of grouping the pointers to something
+          // reasonable.  If we do end up hitting this threshold, the algorithm
+          // will create separate groups for all remaining pointers.
+          if (TotalComparisons > MemoryCheckMergeThreshold)
+            break;
+
+          TotalComparisons++;
+
+          if (Group.addPointer(Pointer, *this)) {
+            Merged = true;
+            break;
+          }
         }
-      }
 
-      if (!Merged)
-        // We couldn't add this pointer to any existing set or the threshold
-        // for the number of comparisons has been reached. Create a new group
-        // to hold the current pointer.
-        Groups.push_back(RuntimeCheckingPtrGroup(Pointer, *this));
+        if (!Merged)
+          // We couldn't add this pointer to any existing set or the threshold
+          // for the number of comparisons has been reached. Create a new group
+          // to hold the current pointer.
+          Groups.push_back(RuntimeCheckingPtrGroup(Pointer, *this));
+      }
     }
 
     // We've computed the grouped checks for this partition.
@@ -629,11 +631,8 @@ class AccessAnalysis {
 /// Check whether a pointer can participate in a runtime bounds check.
 /// If \p Assume, try harder to prove that we can compute the bounds of \p Ptr
 /// by adding run-time checks (overflow checks) if necessary.
-static bool hasComputableBounds(PredicatedScalarEvolution &PSE,
-                                const ValueToValueMap &Strides, Value *Ptr,
-                                Loop *L, bool Assume) {
-  const SCEV *PtrScev = replaceSymbolicStrideSCEV(PSE, Strides, Ptr);
-
+static bool hasComputableBounds(PredicatedScalarEvolution &PSE, Value *Ptr,
+                                const SCEV *PtrScev, Loop *L, bool Assume) {
   // The bounds for loop-invariant pointer is trivial.
   if (PSE.getSE()->isLoopInvariant(PtrScev, L))
     return true;
@@ -696,34 +695,56 @@ bool AccessAnalysis::createCheckForAccess(RuntimePointerChecking &RtCheck,
                                           bool Assume) {
   Value *Ptr = Access.getPointer();
 
-  if (!hasComputableBounds(PSE, StridesMap, Ptr, TheLoop, Assume))
-    return false;
+  ScalarEvolution &SE = *PSE.getSE();
+  SmallVector<const SCEV *> TranslatedPtrs;
+  if (auto *SI = dyn_cast<SelectInst>(Ptr))
+    TranslatedPtrs = {SE.getSCEV(SI->getOperand(1)),
+                      SE.getSCEV(SI->getOperand(2))};
+  else
+    TranslatedPtrs = {replaceSymbolicStrideSCEV(PSE, StridesMap, Ptr)};
 
-  // When we run after a failing dependency check we have to make sure
-  // we don't have wrapping pointers.
-  if (ShouldCheckWrap && !isNoWrap(PSE, StridesMap, Ptr, AccessTy, TheLoop)) {
-    auto *Expr = PSE.getSCEV(Ptr);
-    if (!Assume || !isa<SCEVAddRecExpr>(Expr))
+  for (const SCEV *PtrExpr : TranslatedPtrs) {
+    if (!hasComputableBounds(PSE, Ptr, PtrExpr, TheLoop, Assume))
       return false;
-    PSE.setNoOverflow(Ptr, SCEVWrapPredicate::IncrementNUSW);
+
+    // When we run after a failing dependency check we have to make sure
+    // we don't have wrapping pointers.
+    if (ShouldCheckWrap) {
+      // Skip wrap checking when translating pointers.
+      if (TranslatedPtrs.size() > 1)
+        return false;
+
+      if (!isNoWrap(PSE, StridesMap, Ptr, AccessTy, TheLoop)) {
+        auto *Expr = PSE.getSCEV(Ptr);
+        if (!Assume || !isa<SCEVAddRecExpr>(Expr))
+          return false;
+        PSE.setNoOverflow(Ptr, SCEVWrapPredicate::IncrementNUSW);
+      }
+    }
+    // If there's only one option for Ptr, look it up after bounds and wrap
+    // checking, because assumptions might have been added to PSE.
+    if (TranslatedPtrs.size() == 1)
+      TranslatedPtrs[0] = replaceSymbolicStrideSCEV(PSE, StridesMap, Ptr);
   }
 
-  // The id of the dependence set.
-  unsigned DepId;
+  for (const SCEV *PtrExpr : TranslatedPtrs) {
+    // The id of the dependence set.
+    unsigned DepId;
 
-  if (isDependencyCheckNeeded()) {
-    Value *Leader = DepCands.getLeaderValue(Access).getPointer();
-    unsigned &LeaderId = DepSetId[Leader];
-    if (!LeaderId)
-      LeaderId = RunningDepId++;
-    DepId = LeaderId;
-  } else
-    // Each access has its own dependence set.
-    DepId = RunningDepId++;
+    if (isDependencyCheckNeeded()) {
+      Value *Leader = DepCands.getLeaderValue(Access).getPointer();
+      unsigned &LeaderId = DepSetId[Leader];
+      if (!LeaderId)
+        LeaderId = RunningDepId++;
+      DepId = LeaderId;
+    } else
+      // Each access has its own dependence set.
+      DepId = RunningDepId++;
 
-  bool IsWrite = Access.getInt();
-  RtCheck.insert(TheLoop, Ptr, AccessTy, IsWrite, DepId, ASId, StridesMap, PSE);
-  LLVM_DEBUG(dbgs() << "LAA: Found a runtime check ptr:" << *Ptr << '\n');
+    bool IsWrite = Access.getInt();
+    RtCheck.insert(TheLoop, Ptr, PtrExpr, AccessTy, IsWrite, DepId, ASId, PSE);
+    LLVM_DEBUG(dbgs() << "LAA: Found a runtime check ptr:" << *Ptr << '\n');
+  }
 
   return true;
 }

diff  --git a/llvm/test/Analysis/LoopAccessAnalysis/forked-pointers.ll b/llvm/test/Analysis/LoopAccessAnalysis/forked-pointers.ll
index f0dbc35588896..b63c2f0cc556d 100644
--- a/llvm/test/Analysis/LoopAccessAnalysis/forked-pointers.ll
+++ b/llvm/test/Analysis/LoopAccessAnalysis/forked-pointers.ll
@@ -4,10 +4,32 @@ target datalayout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128"
 
 ; CHECK-LABEL: function 'forked_ptrs_simple':
 ; CHECK-NEXT:  loop:
-; CHECK-NEXT:    Report: cannot identify array bounds
+; CHECK-NEXT:    Memory dependences are safe with run-time checks
 ; CHECK-NEXT:    Dependences:
 ; CHECK-NEXT:    Run-time memory checks:
+; CHECK-NEXT:    Check 0:
+; CHECK-NEXT:      Comparing group ([[G1:.+]]):
+; CHECK-NEXT:        %gep.Dest = getelementptr inbounds float, float* %Dest, i64 %iv
+; CHECK-NEXT:        %gep.Dest = getelementptr inbounds float, float* %Dest, i64 %iv
+; CHECK-NEXT:      Against group ([[G2:.+]]):
+; CHECK-NEXT:        %select = select i1 %cmp, float* %gep.1, float* %gep.2
+; CHECK-NEXT:    Check 1:
+; CHECK-NEXT:      Comparing group ([[G1]]):
+; CHECK-NEXT:        %gep.Dest = getelementptr inbounds float, float* %Dest, i64 %iv
+; CHECK-NEXT:        %gep.Dest = getelementptr inbounds float, float* %Dest, i64 %iv
+; CHECK-NEXT:      Against group ([[G3:.+]]):
+; CHECK-NEXT:        %select = select i1 %cmp, float* %gep.1, float* %gep.2
 ; CHECK-NEXT:    Grouped accesses:
+; CHECK-NEXT:      Group [[G1]]
+; CHECK-NEXT:        (Low: %Dest High: (400 + %Dest))
+; CHECK-NEXT:          Member: {%Dest,+,4}<nuw><%loop>
+; CHECK-NEXT:          Member: {%Dest,+,4}<nuw><%loop>
+; CHECK-NEXT:      Group [[G2]]:
+; CHECK-NEXT:        (Low: %Base1 High: (400 + %Base1))
+; CHECK-NEXT:          Member: {%Base1,+,4}<nw><%loop>
+; CHECK-NEXT:      Group [[G3]]:
+; CHECK-NEXT:        (Low: %Base2 High: (400 + %Base2))
+; CHECK-NEXT:          Member: {%Base2,+,4}<nw><%loop>
 ; CHECK-EMPTY:
 ; CHECK-NEXT:    Non vectorizable stores to invariant address were not found in loop.
 ; CHECK-NEXT:    SCEV assumptions:


        


More information about the llvm-commits mailing list