[llvm-branch-commits] [llvm] be31896 - Fix incorrect SimplifyWithOpReplaced transform (PR47322)

Hans Wennborg via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Tue Sep 15 01:44:04 PDT 2020


Author: Nikita Popov
Date: 2020-09-15T10:21:08+02:00
New Revision: be318969e245db0cd5471bff2a7cbfa3fad2b075

URL: https://github.com/llvm/llvm-project/commit/be318969e245db0cd5471bff2a7cbfa3fad2b075
DIFF: https://github.com/llvm/llvm-project/commit/be318969e245db0cd5471bff2a7cbfa3fad2b075.diff

LOG: Fix incorrect SimplifyWithOpReplaced transform (PR47322)

This is a follow-up to D86834, which partially fixed this issue in
InstSimplify. However, InstCombine repeats the same transform while
dropping poison flags, and dropping the flags does not cover cases
where poison is introduced in some other way.

The fix here is a bit more comprehensive, because things are quite
entangled, and it's hard to only partially address it without
regressing optimization. There are really two changes here:

 * Export the SimplifyWithOpReplaced API from InstSimplify, with an
   added AllowRefinement flag. For replacements inside the TrueVal
   we don't actually care whether refinement occurs or not; the
   replacement is always legal. This part of the transform is now
   done in InstSimplify only. (It should be noted that the current
   AllowRefinement check is not sufficient -- that's an issue we
   need to address separately.)
 * Change the InstCombine fold to work by temporarily dropping
   poison generating flags, running the fold and then restoring the
   flags if it didn't work out. This will ensure that the InstCombine
   fold is correct as long as the InstSimplify fold is correct.

Differential Revision: https://reviews.llvm.org/D87445

Added: 
    

Modified: 
    llvm/include/llvm/Analysis/InstructionSimplify.h
    llvm/lib/Analysis/InstructionSimplify.cpp
    llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
    llvm/test/Transforms/InstCombine/select.ll

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Analysis/InstructionSimplify.h b/llvm/include/llvm/Analysis/InstructionSimplify.h
index 2a39a4e09087..b5ae54fb98bc 100644
--- a/llvm/include/llvm/Analysis/InstructionSimplify.h
+++ b/llvm/include/llvm/Analysis/InstructionSimplify.h
@@ -268,6 +268,12 @@ Value *SimplifyFreezeInst(Value *Op, const SimplifyQuery &Q);
 Value *SimplifyInstruction(Instruction *I, const SimplifyQuery &Q,
                            OptimizationRemarkEmitter *ORE = nullptr);
 
+/// See if V simplifies when its operand Op is replaced with RepOp.
+/// AllowRefinement specifies whether the simplification can be a refinement,
+/// or whether it needs to be strictly identical.
+Value *SimplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
+                              const SimplifyQuery &Q, bool AllowRefinement);
+
 /// Replace all uses of 'I' with 'SimpleV' and simplify the uses recursively.
 ///
 /// This first performs a normal RAUW of I with SimpleV. It then recursively

diff  --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
index 0a76979f93e2..e744a966a104 100644
--- a/llvm/lib/Analysis/InstructionSimplify.cpp
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -3810,10 +3810,10 @@ Value *llvm::SimplifyFCmpInst(unsigned Predicate, Value *LHS, Value *RHS,
   return ::SimplifyFCmpInst(Predicate, LHS, RHS, FMF, Q, RecursionLimit);
 }
 
-/// See if V simplifies when its operand Op is replaced with RepOp.
-static const Value *SimplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
-                                           const SimplifyQuery &Q,
-                                           unsigned MaxRecurse) {
+static Value *SimplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
+                                     const SimplifyQuery &Q,
+                                     bool AllowRefinement,
+                                     unsigned MaxRecurse) {
   // Trivial replacement.
   if (V == Op)
     return RepOp;
@@ -3826,20 +3826,19 @@ static const Value *SimplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
   if (!I)
     return nullptr;
 
+  // Consider:
+  //   %cmp = icmp eq i32 %x, 2147483647
+  //   %add = add nsw i32 %x, 1
+  //   %sel = select i1 %cmp, i32 -2147483648, i32 %add
+  //
+  // We can't replace %sel with %add unless we strip away the flags (which will
+  // be done in InstCombine).
+  // TODO: This is unsound, because it only catches some forms of refinement.
+  if (!AllowRefinement && canCreatePoison(I))
+    return nullptr;
+
   // If this is a binary operator, try to simplify it with the replaced op.
   if (auto *B = dyn_cast<BinaryOperator>(I)) {
-    // Consider:
-    //   %cmp = icmp eq i32 %x, 2147483647
-    //   %add = add nsw i32 %x, 1
-    //   %sel = select i1 %cmp, i32 -2147483648, i32 %add
-    //
-    // We can't replace %sel with %add unless we strip away the flags.
-    // TODO: This is an unusual limitation because better analysis results in
-    //       worse simplification. InstCombine can do this fold more generally
-    //       by dropping the flags. Remove this fold to save compile-time?
-    if (canCreatePoison(I))
-      return nullptr;
-
     if (MaxRecurse) {
       if (B->getOperand(0) == Op)
         return SimplifyBinOp(B->getOpcode(), RepOp, B->getOperand(1), Q,
@@ -3906,6 +3905,13 @@ static const Value *SimplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
   return nullptr;
 }
 
+Value *llvm::SimplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
+                                    const SimplifyQuery &Q,
+                                    bool AllowRefinement) {
+  return ::SimplifyWithOpReplaced(V, Op, RepOp, Q, AllowRefinement,
+                                  RecursionLimit);
+}
+
 /// Try to simplify a select instruction when its condition operand is an
 /// integer comparison where one operand of the compare is a constant.
 static Value *simplifySelectBitTest(Value *TrueVal, Value *FalseVal, Value *X,
@@ -4017,14 +4023,18 @@ static Value *simplifySelectWithICmpCond(Value *CondVal, Value *TrueVal,
   // arms of the select. See if substituting this value into the arm and
   // simplifying the result yields the same value as the other arm.
   if (Pred == ICmpInst::ICMP_EQ) {
-    if (SimplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, Q, MaxRecurse) ==
+    if (SimplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, Q,
+                               /* AllowRefinement */ false, MaxRecurse) ==
             TrueVal ||
-        SimplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, Q, MaxRecurse) ==
+        SimplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, Q,
+                               /* AllowRefinement */ false, MaxRecurse) ==
             TrueVal)
       return FalseVal;
-    if (SimplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, Q, MaxRecurse) ==
+    if (SimplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, Q,
+                               /* AllowRefinement */ true, MaxRecurse) ==
             FalseVal ||
-        SimplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, Q, MaxRecurse) ==
+        SimplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, Q,
+                               /* AllowRefinement */ true, MaxRecurse) ==
             FalseVal)
       return FalseVal;
   }

diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index db27711f29b1..fa695c39cd1e 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -1148,22 +1148,6 @@ static Instruction *canonicalizeAbsNabs(SelectInst &Sel, ICmpInst &Cmp,
   return &Sel;
 }
 
-static Value *simplifyWithOpReplaced(Value *V, Value *Op, Value *ReplaceOp,
-                                     const SimplifyQuery &Q) {
-  // If this is a binary operator, try to simplify it with the replaced op
-  // because we know Op and ReplaceOp are equivalant.
-  // For example: V = X + 1, Op = X, ReplaceOp = 42
-  // Simplifies as: add(42, 1) --> 43
-  if (auto *BO = dyn_cast<BinaryOperator>(V)) {
-    if (BO->getOperand(0) == Op)
-      return SimplifyBinOp(BO->getOpcode(), ReplaceOp, BO->getOperand(1), Q);
-    if (BO->getOperand(1) == Op)
-      return SimplifyBinOp(BO->getOpcode(), BO->getOperand(0), ReplaceOp, Q);
-  }
-
-  return nullptr;
-}
-
 /// If we have a select with an equality comparison, then we know the value in
 /// one of the arms of the select. See if substituting this value into an arm
 /// and simplifying the result yields the same value as the other arm.
@@ -1190,20 +1174,45 @@ static Value *foldSelectValueEquivalence(SelectInst &Sel, ICmpInst &Cmp,
   if (Cmp.getPredicate() == ICmpInst::ICMP_NE)
     std::swap(TrueVal, FalseVal);
 
+  auto *FalseInst = dyn_cast<Instruction>(FalseVal);
+  if (!FalseInst)
+    return nullptr;
+
+  // InstSimplify already performed this fold if it was possible subject to
+  // current poison-generating flags. Try the transform again with
+  // poison-generating flags temporarily dropped.
+  bool WasNUW = false, WasNSW = false, WasExact = false;
+  if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(FalseVal)) {
+    WasNUW = OBO->hasNoUnsignedWrap();
+    WasNSW = OBO->hasNoSignedWrap();
+    FalseInst->setHasNoUnsignedWrap(false);
+    FalseInst->setHasNoSignedWrap(false);
+  }
+  if (auto *PEO = dyn_cast<PossiblyExactOperator>(FalseVal)) {
+    WasExact = PEO->isExact();
+    FalseInst->setIsExact(false);
+  }
+
   // Try each equivalence substitution possibility.
   // We have an 'EQ' comparison, so the select's false value will propagate.
   // Example:
   // (X == 42) ? 43 : (X + 1) --> (X == 42) ? (X + 1) : (X + 1) --> X + 1
-  // (X == 42) ? (X + 1) : 43 --> (X == 42) ? (42 + 1) : 43 --> 43
   Value *CmpLHS = Cmp.getOperand(0), *CmpRHS = Cmp.getOperand(1);
-  if (simplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, Q) == TrueVal ||
-      simplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, Q) == TrueVal ||
-      simplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, Q) == FalseVal ||
-      simplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, Q) == FalseVal) {
-    if (auto *FalseInst = dyn_cast<Instruction>(FalseVal))
-      FalseInst->dropPoisonGeneratingFlags();
+  if (SimplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, Q,
+                             /* AllowRefinement */ false) == TrueVal ||
+      SimplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, Q,
+                             /* AllowRefinement */ false) == TrueVal) {
     return FalseVal;
   }
+
+  // Restore poison-generating flags if the transform did not apply.
+  if (WasNUW)
+    FalseInst->setHasNoUnsignedWrap();
+  if (WasNSW)
+    FalseInst->setHasNoSignedWrap();
+  if (WasExact)
+    FalseInst->setIsExact();
+
   return nullptr;
 }
 

diff  --git a/llvm/test/Transforms/InstCombine/select.ll b/llvm/test/Transforms/InstCombine/select.ll
index abdb36ab7bd4..c23587b606ce 100644
--- a/llvm/test/Transforms/InstCombine/select.ll
+++ b/llvm/test/Transforms/InstCombine/select.ll
@@ -2491,9 +2491,11 @@ define <2 x i32> @true_undef_vec(i1 %cond, <2 x i32> %x) {
 ; FIXME: This is a miscompile!
 define i32 @pr47322_more_poisonous_replacement(i32 %arg) {
 ; CHECK-LABEL: @pr47322_more_poisonous_replacement(
-; CHECK-NEXT:    [[TRAILING:%.*]] = call i32 @llvm.cttz.i32(i32 [[ARG:%.*]], i1 immarg true), [[RNG0:!range !.*]]
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i32 [[ARG:%.*]], 0
+; CHECK-NEXT:    [[TRAILING:%.*]] = call i32 @llvm.cttz.i32(i32 [[ARG]], i1 immarg true), [[RNG0:!range !.*]]
 ; CHECK-NEXT:    [[SHIFTED:%.*]] = lshr i32 [[ARG]], [[TRAILING]]
-; CHECK-NEXT:    ret i32 [[SHIFTED]]
+; CHECK-NEXT:    [[R1_SROA_0_1:%.*]] = select i1 [[CMP]], i32 0, i32 [[SHIFTED]]
+; CHECK-NEXT:    ret i32 [[R1_SROA_0_1]]
 ;
   %cmp = icmp eq i32 %arg, 0
   %trailing = call i32 @llvm.cttz.i32(i32 %arg, i1 immarg true)


        


More information about the llvm-branch-commits mailing list