[llvm] 80852a4 - [SCEV] Prove implications of different type via truncation

Max Kazantsev via llvm-commits llvm-commits at lists.llvm.org
Tue Oct 20 22:53:46 PDT 2020


Author: Max Kazantsev
Date: 2020-10-21T12:53:22+07:00
New Revision: 80852a4f2fb154c6094bb9d9e3457757d5a60ad1

URL: https://github.com/llvm/llvm-project/commit/80852a4f2fb154c6094bb9d9e3457757d5a60ad1
DIFF: https://github.com/llvm/llvm-project/commit/80852a4f2fb154c6094bb9d9e3457757d5a60ad1.diff

LOG: [SCEV] Prove implications of different type via truncation

When we need to prove an implication between expressions of different type
widths, the default strategy is to widen everything to the wider type and
prove the implication there. This does not interact well with AddRecs that
have negative steps under unsigned predicates: such an AddRec will likely
not have the `nuw` flag, and its `zext` to the wider type will not be an
AddRec. By contrast, the `trunc` of an AddRec can in some cases easily be
proved to be an AddRec as well.
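For example (a minimal hypothetical IR sketch, not taken from the patch;
all names are made up), consider a decrementing induction variable that is
compared at a wider type:

    define i32 @sketch(i32 %start, i64 %n) {
    entry:
      br label %loop
    loop:
      ; %iv is {%start,+,-1}<%loop>; with a negative step SCEV will
      ; typically not infer <nuw>, so zext'ing %iv to i64 yields an
      ; opaque (zext {...}) rather than an i64 AddRec, and unsigned
      ; reasoning at the wide type stalls.
      %iv = phi i32 [ %start, %entry ], [ %iv.next, %loop ]
      %iv.next = add i32 %iv, -1
      %wide = zext i32 %iv to i64
      %cmp = icmp ult i64 %wide, %n
      br i1 %cmp, label %loop, label %exit
    exit:
      ret i32 %iv
    }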

This patch introduces an alternative way to handle implications between
types of different widths: if we can prove that the wider-type values
actually fit into the narrow type, we truncate them and prove the
implication in the narrow type.
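Concretely, the new code handles shapes like the following (again a
hypothetical sketch; in this straight-line form the old widening path could
also succeed, while the unit test added below exercises the loop variant
where only the truncation route works). Both wide operands are zexts of
i32 values, so they provably fit in the unsigned i32 range and can be
truncated before replaying the implication at i32:

    define i1 @sketch2(i32 %x, i32 %y) {
    entry:
      %wide.x = zext i32 %x to i64
      %wide.y = zext i32 %y to i64
      ; The found operands are i64, but both are u<= UINT32_MAX, so the
      ; guard can be truncated to i32 and the i32 fact %x u< %y follows.
      %guard = icmp ult i64 %wide.x, %wide.y
      br i1 %guard, label %taken, label %exit
    taken:
      %narrow = icmp ult i32 %x, %y   ; provably true under the guard
      ret i1 %narrow
    exit:
      ret i1 false
    }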

Differential Revision: https://reviews.llvm.org/D89548
Reviewed By: fhahn

Added: 
    

Modified: 
    llvm/lib/Analysis/ScalarEvolution.cpp
    llvm/test/Analysis/ScalarEvolution/srem.ll
    llvm/unittests/Analysis/ScalarEvolutionTest.cpp

Removed: 
    


################################################################################
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index efc4600e248f..6e351a53628f 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -9699,6 +9699,25 @@ bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS,
   // Balance the types.
   if (getTypeSizeInBits(LHS->getType()) <
       getTypeSizeInBits(FoundLHS->getType())) {
+    // For unsigned and equality predicates, try to prove that both found
+    // operands fit into narrow unsigned range. If so, try to prove facts in
+    // narrow types.
+    if (!CmpInst::isSigned(FoundPred)) {
+      auto *NarrowType = LHS->getType();
+      auto *WideType = FoundLHS->getType();
+      auto BitWidth = getTypeSizeInBits(NarrowType);
+      const SCEV *MaxValue = getZeroExtendExpr(
+          getConstant(APInt::getMaxValue(BitWidth)), WideType);
+      if (isKnownPredicate(ICmpInst::ICMP_ULE, FoundLHS, MaxValue) &&
+          isKnownPredicate(ICmpInst::ICMP_ULE, FoundRHS, MaxValue)) {
+        const SCEV *TruncFoundLHS = getTruncateExpr(FoundLHS, NarrowType);
+        const SCEV *TruncFoundRHS = getTruncateExpr(FoundRHS, NarrowType);
+        if (isImpliedCondBalancedTypes(Pred, LHS, RHS, FoundPred, TruncFoundLHS,
+                                       TruncFoundRHS, Context))
+          return true;
+      }
+    }
+
     if (CmpInst::isSigned(Pred)) {
       LHS = getSignExtendExpr(LHS, FoundLHS->getType());
       RHS = getSignExtendExpr(RHS, FoundLHS->getType());

diff --git a/llvm/test/Analysis/ScalarEvolution/srem.ll b/llvm/test/Analysis/ScalarEvolution/srem.ll
index 197437b51ca1..76e0d4a5ec5b 100644
--- a/llvm/test/Analysis/ScalarEvolution/srem.ll
+++ b/llvm/test/Analysis/ScalarEvolution/srem.ll
@@ -29,7 +29,7 @@ define dso_local void @_Z4loopi(i32 %width) local_unnamed_addr #0 {
 ; CHECK-NEXT:    %add = add nsw i32 %2, %call
 ; CHECK-NEXT:    --> (%2 + %call) U: full-set S: full-set Exits: <<Unknown>> LoopDispositions: { %for.cond: Variant }
 ; CHECK-NEXT:    %inc = add nsw i32 %i.0, 1
-; CHECK-NEXT:    --> {1,+,1}<nuw><%for.cond> U: [1,0) S: [1,0) Exits: (1 + %width) LoopDispositions: { %for.cond: Computable }
+; CHECK-NEXT:    --> {1,+,1}<nuw><%for.cond> U: full-set S: full-set Exits: (1 + %width) LoopDispositions: { %for.cond: Computable }
 ; CHECK-NEXT:  Determining loop execution counts for: @_Z4loopi
 ; CHECK-NEXT:  Loop %for.cond: backedge-taken count is %width
 ; CHECK-NEXT:  Loop %for.cond: max backedge-taken count is -1

diff --git a/llvm/unittests/Analysis/ScalarEvolutionTest.cpp b/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
index be8941838f71..ee70fe5e7ce5 100644
--- a/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
+++ b/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
@@ -1316,4 +1316,45 @@ TEST_F(ScalarEvolutionsTest, UnsignedIsImpliedViaOperations) {
   });
 }
 
+TEST_F(ScalarEvolutionsTest, ProveImplicationViaNarrowing) {
+  LLVMContext C;
+  SMDiagnostic Err;
+  std::unique_ptr<Module> M = parseAssemblyString(
+      "define i32 @foo(i32 %start, i32* %q) { "
+      "entry: "
+      "  %wide.start = zext i32 %start to i64 "
+      "  br label %loop "
+      "loop: "
+      "  %wide.iv = phi i64 [%wide.start, %entry], [%wide.iv.next, %backedge] "
+      "  %iv = phi i32 [%start, %entry], [%iv.next, %backedge] "
+      "  %cond = icmp eq i64 %wide.iv, 0 "
+      "  br i1 %cond, label %exit, label %backedge "
+      "backedge: "
+      "  %iv.next = add i32 %iv, -1 "
+      "  %index = zext i32 %iv.next to i64 "
+      "  %load.addr = getelementptr i32, i32* %q, i64 %index "
+      "  %stop = load i32, i32* %load.addr "
+      "  %loop.cond = icmp eq i32 %stop, 0 "
+      "  %wide.iv.next = add nsw i64 %wide.iv, -1 "
+      "  br i1 %loop.cond, label %loop, label %failure "
+      "exit: "
+      "  ret i32 0 "
+      "failure: "
+      "  unreachable "
+      "} ",
+      Err, C);
+
+  ASSERT_TRUE(M && "Could not parse module?");
+  ASSERT_TRUE(!verifyModule(*M) && "Must have been well formed!");
+
+  runWithSE(*M, "foo", [](Function &F, LoopInfo &LI, ScalarEvolution &SE) {
+    auto *IV = SE.getSCEV(getInstructionByName(F, "iv"));
+    auto *Zero = SE.getZero(IV->getType());
+    auto *Backedge = getInstructionByName(F, "iv.next")->getParent();
+    ASSERT_TRUE(Backedge);
+    EXPECT_TRUE(SE.isBasicBlockEntryGuardedByCond(Backedge, ICmpInst::ICMP_UGT,
+                                                  IV, Zero));
+  });
+}
+
 }  // end namespace llvm
