[llvm] 6ffb3ad - [SCEV] Use constant ranges when determining reachable blocks (PR54434)

Nikita Popov via llvm-commits llvm-commits at lists.llvm.org
Fri Mar 18 04:04:43 PDT 2022


Author: Nikita Popov
Date: 2022-03-18T12:04:35+01:00
New Revision: 6ffb3ad631c5071ce82c8b6c73dd1c88e0452944

URL: https://github.com/llvm/llvm-project/commit/6ffb3ad631c5071ce82c8b6c73dd1c88e0452944
DIFF: https://github.com/llvm/llvm-project/commit/6ffb3ad631c5071ce82c8b6c73dd1c88e0452944.diff

LOG: [SCEV] Use constant ranges when determining reachable blocks (PR54434)

This avoids false-positive verification failures when a branch condition
is not literally true/false, but SCEV can still prove (through more
complex reasoning) that a loop is unreachable and makes use of that fact.

Fixes https://github.com/llvm/llvm-project/issues/54434.

Added: 
    llvm/test/Transforms/IndVarSimplify/pr54434.ll

Modified: 
    llvm/include/llvm/Analysis/ScalarEvolution.h
    llvm/lib/Analysis/ScalarEvolution.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h
index cea8f1a756a79..7dbe59513f241 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolution.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolution.h
@@ -2083,6 +2083,11 @@ class ScalarEvolution {
   /// `UniqueSCEVs`.  Return if found, else nullptr.
   SCEV *findExistingSCEVInCache(SCEVTypes SCEVType, ArrayRef<const SCEV *> Ops);
 
+  /// Get reachable blocks in this function, making limited use of SCEV
+  /// reasoning about conditions.
+  void getReachableBlocks(SmallPtrSetImpl<BasicBlock *> &Reachable,
+                          Function &F);
+
   FoldingSet<SCEV> UniqueSCEVs;
   FoldingSet<SCEVPredicate> UniquePreds;
   BumpPtrAllocator SCEVAllocator;

diff  --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 89d615dbe24b5..4b716cbc08fe3 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -13309,8 +13309,8 @@ ScalarEvolution::getUsedLoops(const SCEV *S,
   SCEVTraversal<FindUsedLoops>(F).visitAll(S);
 }
 
-static void getReachableBlocks(SmallPtrSetImpl<BasicBlock *> &Reachable,
-                               Function &F) {
+void ScalarEvolution::getReachableBlocks(
+    SmallPtrSetImpl<BasicBlock *> &Reachable, Function &F) {
   SmallVector<BasicBlock *> Worklist;
   Worklist.push_back(&F.getEntryBlock());
   while (!Worklist.empty()) {
@@ -13318,13 +13318,31 @@ static void getReachableBlocks(SmallPtrSetImpl<BasicBlock *> &Reachable,
     if (!Reachable.insert(BB).second)
       continue;
 
-    const APInt *Cond;
+    Value *Cond;
     BasicBlock *TrueBB, *FalseBB;
-    if (match(BB->getTerminator(),
-              m_Br(m_APInt(Cond), m_BasicBlock(TrueBB), m_BasicBlock(FalseBB))))
-      Worklist.push_back(Cond->isOne() ? TrueBB : FalseBB);
-    else
-      append_range(Worklist, successors(BB));
+    if (match(BB->getTerminator(), m_Br(m_Value(Cond), m_BasicBlock(TrueBB),
+                                        m_BasicBlock(FalseBB)))) {
+      if (auto *C = dyn_cast<ConstantInt>(Cond)) {
+        Worklist.push_back(C->isOne() ? TrueBB : FalseBB);
+        continue;
+      }
+
+      if (auto *Cmp = dyn_cast<ICmpInst>(Cond)) {
+        const SCEV *L = getSCEV(Cmp->getOperand(0));
+        const SCEV *R = getSCEV(Cmp->getOperand(1));
+        if (isKnownPredicateViaConstantRanges(Cmp->getPredicate(), L, R)) {
+          Worklist.push_back(TrueBB);
+          continue;
+        }
+        if (isKnownPredicateViaConstantRanges(Cmp->getInversePredicate(), L,
+                                              R)) {
+          Worklist.push_back(FalseBB);
+          continue;
+        }
+      }
+    }
+
+    append_range(Worklist, successors(BB));
   }
 }
 
@@ -13353,7 +13371,7 @@ void ScalarEvolution::verify() const {
 
   SCEVMapper SCM(SE2);
   SmallPtrSet<BasicBlock *, 16> ReachableBlocks;
-  getReachableBlocks(ReachableBlocks, F);
+  SE2.getReachableBlocks(ReachableBlocks, F);
 
   auto GetDelta = [&](const SCEV *Old, const SCEV *New) -> const SCEV * {
     if (containsUndefs(Old) || containsUndefs(New)) {

diff  --git a/llvm/test/Transforms/IndVarSimplify/pr54434.ll b/llvm/test/Transforms/IndVarSimplify/pr54434.ll
new file mode 100644
index 0000000000000..7f25c6da1b138
--- /dev/null
+++ b/llvm/test/Transforms/IndVarSimplify/pr54434.ll
@@ -0,0 +1,45 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -S -indvars -verify-scev < %s | FileCheck %s
+
+define void @test() {
+; CHECK-LABEL: @test(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    br label [[FOR_COND:%.*]]
+; CHECK:       for.cond:
+; CHECK-NEXT:    br i1 false, label [[FOR_COND92_PREHEADER:%.*]], label [[FOR_END106:%.*]]
+; CHECK:       for.cond92.preheader:
+; CHECK-NEXT:    br label [[FOR_COND92:%.*]]
+; CHECK:       for.cond92:
+; CHECK-NEXT:    br i1 false, label [[FOR_BODY94:%.*]], label [[FOR_END:%.*]]
+; CHECK:       for.body94:
+; CHECK-NEXT:    br label [[FOR_COND92]]
+; CHECK:       for.end:
+; CHECK-NEXT:    br label [[FOR_COND]]
+; CHECK:       for.end106:
+; CHECK-NEXT:    ret void
+;
+entry:
+  br label %for.cond
+
+for.cond:                                         ; preds = %for.end, %entry
+  %0 = phi i32 [ %inc105, %for.end ], [ 0, %entry ]
+  %cmp = icmp sge i32 %0, 1
+  br i1 %cmp, label %for.cond92, label %for.end106
+
+for.cond92:                                       ; preds = %for.body94, %for.cond
+  %1 = phi i16 [ %inc, %for.body94 ], [ 0, %for.cond ]
+  %cmp93 = icmp slt i16 %1, 1
+  br i1 %cmp93, label %for.body94, label %for.end
+
+for.body94:                                       ; preds = %for.cond92
+  %inc = add nsw i16 %1, 1
+  br label %for.cond92
+
+for.end:                                          ; preds = %for.cond92
+  %inc105 = add nsw i32 %0, 1
+  br label %for.cond
+
+for.end106:                                       ; preds = %for.cond
+  ret void
+}
+


        


More information about the llvm-commits mailing list