[llvm] 69acdfe - [SCEV] Prove implications via AddRec start

Benjamin Kramer via llvm-commits llvm-commits at lists.llvm.org
Mon Oct 5 08:56:34 PDT 2020


I'm seeing a bunch of miscompiles after this change. Sadly it's
extremely subtle, so the executable test case is somewhat big. It
boils down to an extra `nuw` flag being added after this change in a
place where it's not safe.

$ diff -u before-no-unroll.ll after-no-unroll.ll
--- before-no-unroll.ll 2020-10-05 17:34:40.028478035 +0200
+++ after-no-unroll.ll  2020-10-05 17:33:52.900457475 +0200
@@ -153,7 +153,7 @@
   %.not69 = phi i1 [ false, %convolution-window-dilated.inner.loop_exit.iz.us.us ], [ true, %convolution-window-dilated.inner.loop_header.k2.preheader.us ]
   %convolution-window-dilated.inner.invar_address.k2.06.us.us = phi i64 [ 1, %convolution-window-dilated.inner.loop_exit.iz.us.us ], [ 0, %convolution-window-dilated.inner.loop_header.k2.preheader.us ]
   %44 = shl nuw nsw i64 %convolution-window-dilated.inner.invar_address.k2.06.us.us, 1
-  %45 = add nsw i64 %36, %44
+  %45 = add nuw nsw i64 %36, %44
   %46 = icmp ult i64 %45, 2
   br i1 %46, label %in-bounds-true.us.us.us, label %convolution-window-dilated.inner.loop_exit.iz.us.us

The nuw is invalid as %36 can be -1. There's also an executable
version attached.
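For reference, the arithmetic behind that: in the hunk above, %44 is a `shl` by 1 of a phi whose incoming values are 0 and 1, so %44 is either 0 or 2. When %36 is -1 (all ones as an i64), %36 + %44 does not overflow for %44 = 0, but for %44 = 2 the unsigned sum wraps around to 1, which is exactly the unsigned overflow that `nuw` promises cannot happen. The signed result, 1, does not overflow, so `nsw` is still fine. A minimal C++ sketch of this arithmetic follows; the helper names are made up for illustration and are not LLVM APIs.

```cpp
#include <cstdint>
#include <limits>

// True if a + b overflows when both operands are read as unsigned 64-bit
// integers: exactly the case in which `add nuw i64` would produce poison.
constexpr bool addOverflowsUnsigned(std::uint64_t a, std::uint64_t b) {
  return a > std::numeric_limits<std::uint64_t>::max() - b;
}

// True if a + b overflows when both operands are read as signed 64-bit
// integers: the case in which `add nsw i64` would produce poison.
constexpr bool addOverflowsSigned(std::int64_t a, std::int64_t b) {
  return (b > 0 && a > std::numeric_limits<std::int64_t>::max() - b) ||
         (b < 0 && a < std::numeric_limits<std::int64_t>::min() - b);
}

// Models `%44 = shl nuw nsw i64 %k2, 1`, where %k2 (the long phi name in the
// hunk) is either 0 or 1.
constexpr std::uint64_t shlByOne(std::uint64_t k2) { return k2 << 1; }

// %36 = -1 has the i64 bit pattern 0xFFFFFFFFFFFFFFFF.
constexpr std::int64_t kV36 = -1;
constexpr std::uint64_t kV36Bits = static_cast<std::uint64_t>(kV36);

static_assert(!addOverflowsUnsigned(kV36Bits, shlByOne(0)),
              "-1 + 0: no unsigned wrap");
static_assert(addOverflowsUnsigned(kV36Bits, shlByOne(1)),
              "-1 + 2: unsigned wrap, so nuw is wrong");
static_assert(kV36Bits + shlByOne(1) == 1, "the wrapped unsigned sum is 1");
static_assert(!addOverflowsSigned(kV36, 2),
              "-1 + 2: no signed overflow, so nsw is fine");
```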

On Thu, Oct 1, 2020 at 12:10 PM Max Kazantsev via llvm-commits <llvm-commits at lists.llvm.org> wrote:
>
>
> Author: Max Kazantsev
> Date: 2020-10-01T17:09:38+07:00
> New Revision: 69acdfe075fa8eb18781f88f4d0cd1ea40fa6e48
>
> URL: https://github.com/llvm/llvm-project/commit/69acdfe075fa8eb18781f88f4d0cd1ea40fa6e48
> DIFF: https://github.com/llvm/llvm-project/commit/69acdfe075fa8eb18781f88f4d0cd1ea40fa6e48.diff
>
> LOG: [SCEV] Prove implications via AddRec start
>
> If we know that some predicate is true for AddRec and an invariant
> (w.r.t. this AddRec's loop), this fact is, in particular, true on the first
> iteration. We can try to prove the facts we need using the start value.
>
> The motivating example is proving things like
> ```
>   isImpliedCondOperands(>=, X, 0, {X,+,-1}, 0)
> ```
>
> Differential Revision: https://reviews.llvm.org/D88208
> Reviewed By: reames
>
> Added:
>
>
> Modified:
>     llvm/include/llvm/Analysis/ScalarEvolution.h
>     llvm/lib/Analysis/ScalarEvolution.cpp
>     llvm/unittests/Analysis/ScalarEvolutionTest.cpp
>
> Removed:
>
>
>
> ################################################################################
> diff  --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h
> index febca473776a..158257a5aa9a 100644
> --- a/llvm/include/llvm/Analysis/ScalarEvolution.h
> +++ b/llvm/include/llvm/Analysis/ScalarEvolution.h
> @@ -1677,23 +1677,30 @@ class ScalarEvolution {
>    getPredecessorWithUniqueSuccessorForBB(const BasicBlock *BB) const;
>
>    /// Test whether the condition described by Pred, LHS, and RHS is true
> -  /// whenever the given FoundCondValue value evaluates to true.
> +  /// whenever the given FoundCondValue value evaluates to true in given
> +  /// Context. If Context is nullptr, then the found predicate is true
> +  /// everywhere.
>    bool isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS,
> -                     const Value *FoundCondValue, bool Inverse);
> +                     const Value *FoundCondValue, bool Inverse,
> +                     const Instruction *Context = nullptr);
>
>    /// Test whether the condition described by Pred, LHS, and RHS is true
>    /// whenever the condition described by FoundPred, FoundLHS, FoundRHS is
> -  /// true.
> +  /// true in given Context. If Context is nullptr, then the found predicate is
> +  /// true everywhere.
>    bool isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS,
>                       ICmpInst::Predicate FoundPred, const SCEV *FoundLHS,
> -                     const SCEV *FoundRHS);
> +                     const SCEV *FoundRHS,
> +                     const Instruction *Context = nullptr);
>
>    /// Test whether the condition described by Pred, LHS, and RHS is true
>    /// whenever the condition described by Pred, FoundLHS, and FoundRHS is
> -  /// true.
> +  /// true in given Context. If Context is nullptr, then the found predicate is
> +  /// true everywhere.
>    bool isImpliedCondOperands(ICmpInst::Predicate Pred, const SCEV *LHS,
>                               const SCEV *RHS, const SCEV *FoundLHS,
> -                             const SCEV *FoundRHS);
> +                             const SCEV *FoundRHS,
> +                             const Instruction *Context = nullptr);
>
>    /// Test whether the condition described by Pred, LHS, and RHS is true
>    /// whenever the condition described by Pred, FoundLHS, and FoundRHS is
> @@ -1740,6 +1747,18 @@ class ScalarEvolution {
>                                            const SCEV *FoundLHS,
>                                            const SCEV *FoundRHS);
>
> +  /// Test whether the condition described by Pred, LHS, and RHS is true
> +  /// whenever the condition described by Pred, FoundLHS, and FoundRHS is
> +  /// true.
> +  ///
> +  /// This routine tries to weaken the known condition basing on fact that
> +  /// FoundLHS is an AddRec.
> +  bool isImpliedCondOperandsViaAddRecStart(ICmpInst::Predicate Pred,
> +                                           const SCEV *LHS, const SCEV *RHS,
> +                                           const SCEV *FoundLHS,
> +                                           const SCEV *FoundRHS,
> +                                           const Instruction *Context);
> +
>    /// Test whether the condition described by Pred, LHS, and RHS is true
>    /// whenever the condition described by Pred, FoundLHS, and FoundRHS is
>    /// true.
>
> diff  --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
> index e51b31673105..a3e454fefcf0 100644
> --- a/llvm/lib/Analysis/ScalarEvolution.cpp
> +++ b/llvm/lib/Analysis/ScalarEvolution.cpp
> @@ -9549,15 +9549,16 @@ bool ScalarEvolution::isBasicBlockEntryGuardedByCond(const BasicBlock *BB,
>
>    // Try to prove (Pred, LHS, RHS) using isImpliedCond.
>    auto ProveViaCond = [&](const Value *Condition, bool Inverse) {
> -    if (isImpliedCond(Pred, LHS, RHS, Condition, Inverse))
> +    const Instruction *Context = &BB->front();
> +    if (isImpliedCond(Pred, LHS, RHS, Condition, Inverse, Context))
>        return true;
>      if (ProvingStrictComparison) {
>        if (!ProvedNonStrictComparison)
> -        ProvedNonStrictComparison =
> -            isImpliedCond(NonStrictPredicate, LHS, RHS, Condition, Inverse);
> +        ProvedNonStrictComparison = isImpliedCond(NonStrictPredicate, LHS, RHS,
> +                                                  Condition, Inverse, Context);
>        if (!ProvedNonEquality)
> -        ProvedNonEquality =
> -            isImpliedCond(ICmpInst::ICMP_NE, LHS, RHS, Condition, Inverse);
> +        ProvedNonEquality = isImpliedCond(ICmpInst::ICMP_NE, LHS, RHS,
> +                                          Condition, Inverse, Context);
>        if (ProvedNonStrictComparison && ProvedNonEquality)
>          return true;
>      }
> @@ -9623,7 +9624,8 @@ bool ScalarEvolution::isLoopEntryGuardedByCond(const Loop *L,
>
>  bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS,
>                                      const SCEV *RHS,
> -                                    const Value *FoundCondValue, bool Inverse) {
> +                                    const Value *FoundCondValue, bool Inverse,
> +                                    const Instruction *Context) {
>    if (!PendingLoopPredicates.insert(FoundCondValue).second)
>      return false;
>
> @@ -9634,12 +9636,16 @@ bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS,
>    if (const BinaryOperator *BO = dyn_cast<BinaryOperator>(FoundCondValue)) {
>      if (BO->getOpcode() == Instruction::And) {
>        if (!Inverse)
> -        return isImpliedCond(Pred, LHS, RHS, BO->getOperand(0), Inverse) ||
> -               isImpliedCond(Pred, LHS, RHS, BO->getOperand(1), Inverse);
> +        return isImpliedCond(Pred, LHS, RHS, BO->getOperand(0), Inverse,
> +                             Context) ||
> +               isImpliedCond(Pred, LHS, RHS, BO->getOperand(1), Inverse,
> +                             Context);
>      } else if (BO->getOpcode() == Instruction::Or) {
>        if (Inverse)
> -        return isImpliedCond(Pred, LHS, RHS, BO->getOperand(0), Inverse) ||
> -               isImpliedCond(Pred, LHS, RHS, BO->getOperand(1), Inverse);
> +        return isImpliedCond(Pred, LHS, RHS, BO->getOperand(0), Inverse,
> +                             Context) ||
> +               isImpliedCond(Pred, LHS, RHS, BO->getOperand(1), Inverse,
> +                             Context);
>      }
>    }
>
> @@ -9657,14 +9663,14 @@ bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS,
>    const SCEV *FoundLHS = getSCEV(ICI->getOperand(0));
>    const SCEV *FoundRHS = getSCEV(ICI->getOperand(1));
>
> -  return isImpliedCond(Pred, LHS, RHS, FoundPred, FoundLHS, FoundRHS);
> +  return isImpliedCond(Pred, LHS, RHS, FoundPred, FoundLHS, FoundRHS, Context);
>  }
>
>  bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS,
>                                      const SCEV *RHS,
>                                      ICmpInst::Predicate FoundPred,
> -                                    const SCEV *FoundLHS,
> -                                    const SCEV *FoundRHS) {
> +                                    const SCEV *FoundLHS, const SCEV *FoundRHS,
> +                                    const Instruction *Context) {
>    // Balance the types.
>    if (getTypeSizeInBits(LHS->getType()) <
>        getTypeSizeInBits(FoundLHS->getType())) {
> @@ -9708,16 +9714,16 @@ bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS,
>
>    // Check whether the found predicate is the same as the desired predicate.
>    if (FoundPred == Pred)
> -    return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS);
> +    return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, Context);
>
>    // Check whether swapping the found predicate makes it the same as the
>    // desired predicate.
>    if (ICmpInst::getSwappedPredicate(FoundPred) == Pred) {
>      if (isa<SCEVConstant>(RHS))
> -      return isImpliedCondOperands(Pred, LHS, RHS, FoundRHS, FoundLHS);
> +      return isImpliedCondOperands(Pred, LHS, RHS, FoundRHS, FoundLHS, Context);
>      else
> -      return isImpliedCondOperands(ICmpInst::getSwappedPredicate(Pred),
> -                                   RHS, LHS, FoundLHS, FoundRHS);
> +      return isImpliedCondOperands(ICmpInst::getSwappedPredicate(Pred), RHS,
> +                                   LHS, FoundLHS, FoundRHS, Context);
>    }
>
>    // Unsigned comparison is the same as signed comparison when both the operands
> @@ -9725,7 +9731,7 @@ bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS,
>    if (CmpInst::isUnsigned(FoundPred) &&
>        CmpInst::getSignedPredicate(FoundPred) == Pred &&
>        isKnownNonNegative(FoundLHS) && isKnownNonNegative(FoundRHS))
> -    return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS);
> +    return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, Context);
>
>    // Check if we can make progress by sharpening ranges.
>    if (FoundPred == ICmpInst::ICMP_NE &&
> @@ -9762,8 +9768,8 @@ bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS,
>          case ICmpInst::ICMP_UGE:
>            // We know V `Pred` SharperMin.  If this implies LHS `Pred`
>            // RHS, we're done.
> -          if (isImpliedCondOperands(Pred, LHS, RHS, V,
> -                                    getConstant(SharperMin)))
> +          if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(SharperMin),
> +                                    Context))
>              return true;
>            LLVM_FALLTHROUGH;
>
> @@ -9778,7 +9784,8 @@ bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS,
>            //
>            // If V `Pred` Min implies LHS `Pred` RHS, we're done.
>
> -          if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(Min)))
> +          if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(Min),
> +                                    Context))
>              return true;
>            break;
>
> @@ -9786,14 +9793,14 @@ bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS,
>          case ICmpInst::ICMP_SLE:
>          case ICmpInst::ICMP_ULE:
>            if (isImpliedCondOperands(CmpInst::getSwappedPredicate(Pred), RHS,
> -                                    LHS, V, getConstant(SharperMin)))
> +                                    LHS, V, getConstant(SharperMin), Context))
>              return true;
>            LLVM_FALLTHROUGH;
>
>          case ICmpInst::ICMP_SLT:
>          case ICmpInst::ICMP_ULT:
>            if (isImpliedCondOperands(CmpInst::getSwappedPredicate(Pred), RHS,
> -                                    LHS, V, getConstant(Min)))
> +                                    LHS, V, getConstant(Min), Context))
>              return true;
>            break;
>
> @@ -9807,11 +9814,12 @@ bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS,
>    // Check whether the actual condition is beyond sufficient.
>    if (FoundPred == ICmpInst::ICMP_EQ)
>      if (ICmpInst::isTrueWhenEqual(Pred))
> -      if (isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS))
> +      if (isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, Context))
>          return true;
>    if (Pred == ICmpInst::ICMP_NE)
>      if (!ICmpInst::isTrueWhenEqual(FoundPred))
> -      if (isImpliedCondOperands(FoundPred, LHS, RHS, FoundLHS, FoundRHS))
> +      if (isImpliedCondOperands(FoundPred, LHS, RHS, FoundLHS, FoundRHS,
> +                                Context))
>          return true;
>
>    // Otherwise assume the worst.
> @@ -9890,6 +9898,44 @@ Optional<APInt> ScalarEvolution::computeConstantDifference(const SCEV *More,
>    return None;
>  }
>
> +bool ScalarEvolution::isImpliedCondOperandsViaAddRecStart(
> +    ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS,
> +    const SCEV *FoundLHS, const SCEV *FoundRHS, const Instruction *Context) {
> +  // Try to recognize the following pattern:
> +  //
> +  //   FoundRHS = ...
> +  // ...
> +  // loop:
> +  //   FoundLHS = {Start,+,W}
> +  // context_bb: // Basic block from the same loop
> +  //   known(Pred, FoundLHS, FoundRHS)
> +  //
> +  // If some predicate is known in the context of a loop, it is also known on
> +  // each iteration of this loop, including the first iteration. Therefore, in
> +  // this case, `FoundLHS Pred FoundRHS` implies `Start Pred FoundRHS`. Try to
> +  // prove the original pred using this fact.
> +  if (!Context)
> +    return false;
> +  // Make sure AR varies in the context block.
> +  if (auto *AR = dyn_cast<SCEVAddRecExpr>(FoundLHS)) {
> +    if (!AR->getLoop()->contains(Context->getParent()))
> +      return false;
> +    if (!isAvailableAtLoopEntry(FoundRHS, AR->getLoop()))
> +      return false;
> +    return isImpliedCondOperands(Pred, LHS, RHS, AR->getStart(), FoundRHS);
> +  }
> +
> +  if (auto *AR = dyn_cast<SCEVAddRecExpr>(FoundRHS)) {
> +    if (!AR->getLoop()->contains(Context))
> +      return false;
> +    if (!isAvailableAtLoopEntry(FoundLHS, AR->getLoop()))
> +      return false;
> +    return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, AR->getStart());
> +  }
> +
> +  return false;
> +}
> +
>  bool ScalarEvolution::isImpliedCondOperandsViaNoOverflow(
>      ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS,
>      const SCEV *FoundLHS, const SCEV *FoundRHS) {
> @@ -10080,13 +10126,18 @@ bool ScalarEvolution::isImpliedViaMerge(ICmpInst::Predicate Pred,
>  bool ScalarEvolution::isImpliedCondOperands(ICmpInst::Predicate Pred,
>                                              const SCEV *LHS, const SCEV *RHS,
>                                              const SCEV *FoundLHS,
> -                                            const SCEV *FoundRHS) {
> +                                            const SCEV *FoundRHS,
> +                                            const Instruction *Context) {
>    if (isImpliedCondOperandsViaRanges(Pred, LHS, RHS, FoundLHS, FoundRHS))
>      return true;
>
>    if (isImpliedCondOperandsViaNoOverflow(Pred, LHS, RHS, FoundLHS, FoundRHS))
>      return true;
>
> +  if (isImpliedCondOperandsViaAddRecStart(Pred, LHS, RHS, FoundLHS, FoundRHS,
> +                                          Context))
> +    return true;
> +
>    return isImpliedCondOperandsHelper(Pred, LHS, RHS,
>                                       FoundLHS, FoundRHS) ||
>           // ~x < ~y --> x > y
>
> diff  --git a/llvm/unittests/Analysis/ScalarEvolutionTest.cpp b/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
> index ff33495f2271..e5ffc21fb664 100644
> --- a/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
> +++ b/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
> @@ -1251,4 +1251,36 @@ TEST_F(ScalarEvolutionsTest, SCEVgetExitLimitForGuardedLoop) {
>    });
>  }
>
> +TEST_F(ScalarEvolutionsTest, ImpliedViaAddRecStart) {
> +  LLVMContext C;
> +  SMDiagnostic Err;
> +  std::unique_ptr<Module> M = parseAssemblyString(
> +      "define void @foo(i32* %p) { "
> +      "entry: "
> +      "  %x = load i32, i32* %p, !range !0 "
> +      "  br label %loop "
> +      "loop: "
> +      "  %iv = phi i32 [ %x, %entry], [%iv.next, %backedge] "
> +      "  %ne.check = icmp ne i32 %iv, 0 "
> +      "  br i1 %ne.check, label %backedge, label %exit "
> +      "backedge: "
> +      "  %iv.next = add i32 %iv, -1 "
> +      "  br label %loop "
> +      "exit:"
> +      "  ret void "
> +      "} "
> +      "!0 = !{i32 0, i32 2147483647}",
> +      Err, C);
> +
> +  ASSERT_TRUE(M && "Could not parse module?");
> +  ASSERT_TRUE(!verifyModule(*M) && "Must have been well formed!");
> +
> +  runWithSE(*M, "foo", [](Function &F, LoopInfo &LI, ScalarEvolution &SE) {
> +    auto *X = SE.getSCEV(getInstructionByName(F, "x"));
> +    auto *Context = getInstructionByName(F, "iv.next");
> +    EXPECT_TRUE(SE.isKnownPredicateAt(ICmpInst::ICMP_NE, X,
> +                                      SE.getZero(X->getType()), Context));
> +  });
> +}
> +
>  }  // end namespace llvm
>
>
>
> _______________________________________________
> llvm-commits mailing list
> llvm-commits at lists.llvm.org
> https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-commits
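The comment in `isImpliedCondOperandsViaAddRecStart` argues that a fact known at Context holds on every iteration of the loop, including the first. That only follows if the context block runs on the first iteration whenever it runs at all: a fact that comes from a guard inside the loop is known only on the iterations that pass the guard. Below is a hypothetical toy model in plain C++ (not SCEV, and not a reduction of the attached reproducer; every name in it is invented for illustration) that checks whether deducing the fact for the start value can go wrong for a given loop shape. It flags a loop over {-1,+,1} whose context block sits behind `if (iv >= 0)`, and finds no problem for the shape used in the new `ImpliedViaAddRecStart` unit test, where the context (%iv.next in %backedge) runs on the first iteration or not at all.

```cpp
#include <cstdint>

// Predicate on an induction-variable value; plain function pointers keep the
// model usable in constant expressions.
using ValuePred = bool (*)(std::int64_t);

// Toy model (plain integers, not SCEV) of a loop whose induction variable is
// the AddRec {Start,+,Step}, so IV == Start + I * Step on iteration I.
//   ExitIf(IV): check in the loop header; if true, the loop exits before
//               reaching the context block on that iteration.
//   Known(IV):  guard in front of the context block; the context block runs
//               only when Known(IV) holds, so Known is "known at Context".
// Returns true if, within MaxIters iterations, the context block runs while
// Known(Start) is false, i.e. deducing Known(Start) from "Known(IV) at
// Context" would be wrong for this loop. Keep the values small: signed
// overflow in Start + I * Step is undefined behavior.
constexpr bool startInferenceFails(std::int64_t Start, std::int64_t Step,
                                   std::int64_t MaxIters, ValuePred ExitIf,
                                   ValuePred Known) {
  for (std::int64_t I = 0; I < MaxIters; ++I) {
    const std::int64_t IV = Start + I * Step;
    if (ExitIf(IV))
      return false; // left the loop before reaching the context block
    if (Known(IV) && !Known(Start))
      return true;  // context reached, yet the fact is false for Start
  }
  return false;
}

constexpr bool neverExits(std::int64_t) { return false; }
constexpr bool isZero(std::int64_t V) { return V == 0; }
constexpr bool isNonZero(std::int64_t V) { return V != 0; }
constexpr bool isNonNegative(std::int64_t V) { return V >= 0; }

// Hypothetical loop: IV = {-1,+,1} and the context block sits behind
// `if (IV >= 0)`, so it is skipped on the first iteration. The inference fails.
static_assert(startInferenceFails(-1, 1, 4, neverExits, isNonNegative),
              "guarded context: Start >= 0 does not follow");

// Shape of the new unit test: IV = {X,+,-1}, the header exits when IV == 0 and
// the context is the backedge, reached on the first iteration or not at all.
static_assert(!startInferenceFails(5, -1, 10, isZero, isNonZero),
              "backedge context: X != 0 follows");
static_assert(!startInferenceFails(0, -1, 10, isZero, isNonZero),
              "X == 0 never reaches the backedge");
```

In this model, one way to rule out the first shape while keeping the second would be to also require that the context block dominate the loop latch, so that every iteration reaching the backedge has passed through it. That is only a suggestion drawn from the toy model, not something the quoted patch does.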
-------------- next part --------------
A non-text attachment was scrubbed...
Name: addrec.tar.gz
Type: application/x-gzip
Size: 14150 bytes
Desc: not available
URL: <http://lists.llvm.org/pipermail/llvm-commits/attachments/20201005/2d0724ca/attachment.bin>

