[llvm] 36e2e2e - [InstCombine] Fix incorrect SimplifyWithOpReplaced transform (PR47322)

Nikita Popov via llvm-commits llvm-commits at lists.llvm.org
Sat Sep 12 05:45:19 PDT 2020


Author: Nikita Popov
Date: 2020-09-12T14:45:06+02:00
New Revision: 36e2e2e12efb6b02ad07f502d61b9a95937edb08

URL: https://github.com/llvm/llvm-project/commit/36e2e2e12efb6b02ad07f502d61b9a95937edb08
DIFF: https://github.com/llvm/llvm-project/commit/36e2e2e12efb6b02ad07f502d61b9a95937edb08.diff

LOG: [InstCombine] Fix incorrect SimplifyWithOpReplaced transform (PR47322)

This is a follow-up to D86834, which partially fixed this issue in
InstSimplify. However, InstCombine repeats the same transform while
dropping poison-generating flags -- which does not cover cases where
poison is introduced in some other way.

The fix here is a bit more comprehensive, because things are quite
entangled, and it's hard to only partially address it without
regressing optimization. There are really two changes here:

 * Export the SimplifyWithOpReplaced API from InstSimplify, with an
   added AllowRefinement flag. For replacements inside the TrueVal
   we don't actually care whether refinement occurs or not; the
   replacement is always legal. This part of the transform is now
   done in InstSimplify only. (It should be noted that the current
   AllowRefinement check is not sufficient -- that's an issue we
   need to address separately.)
 * Change the InstCombine fold to work by temporarily dropping
   poison generating flags, running the fold and then restoring the
   flags if it didn't work out. This will ensure that the InstCombine
   fold is correct as long as the InstSimplify fold is correct.

Differential Revision: https://reviews.llvm.org/D87445

Added: 
    

Modified: 
    llvm/include/llvm/Analysis/InstructionSimplify.h
    llvm/lib/Analysis/InstructionSimplify.cpp
    llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
    llvm/test/Transforms/InstCombine/select.ll

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Analysis/InstructionSimplify.h b/llvm/include/llvm/Analysis/InstructionSimplify.h
index 6f3d16846621..e0251e7c8bbf 100644
--- a/llvm/include/llvm/Analysis/InstructionSimplify.h
+++ b/llvm/include/llvm/Analysis/InstructionSimplify.h
@@ -292,6 +292,12 @@ Value *SimplifyFreezeInst(Value *Op, const SimplifyQuery &Q);
 Value *SimplifyInstruction(Instruction *I, const SimplifyQuery &Q,
                            OptimizationRemarkEmitter *ORE = nullptr);
 
+/// See if V simplifies when its operand Op is replaced with RepOp.
+/// AllowRefinement specifies whether the simplification can be a refinement,
+/// or whether it needs to be strictly identical.
+Value *SimplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
+                              const SimplifyQuery &Q, bool AllowRefinement);
+
 /// Replace all uses of 'I' with 'SimpleV' and simplify the uses recursively.
 ///
 /// This first performs a normal RAUW of I with SimpleV. It then recursively

diff  --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
index e59c0a84044a..f7f5105f9383 100644
--- a/llvm/lib/Analysis/InstructionSimplify.cpp
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -3769,10 +3769,10 @@ Value *llvm::SimplifyFCmpInst(unsigned Predicate, Value *LHS, Value *RHS,
   return ::SimplifyFCmpInst(Predicate, LHS, RHS, FMF, Q, RecursionLimit);
 }
 
-/// See if V simplifies when its operand Op is replaced with RepOp.
-static const Value *SimplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
-                                           const SimplifyQuery &Q,
-                                           unsigned MaxRecurse) {
+static Value *SimplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
+                                     const SimplifyQuery &Q,
+                                     bool AllowRefinement,
+                                     unsigned MaxRecurse) {
   // Trivial replacement.
   if (V == Op)
     return RepOp;
@@ -3785,20 +3785,19 @@ static const Value *SimplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
   if (!I)
     return nullptr;
 
+  // Consider:
+  //   %cmp = icmp eq i32 %x, 2147483647
+  //   %add = add nsw i32 %x, 1
+  //   %sel = select i1 %cmp, i32 -2147483648, i32 %add
+  //
+  // We can't replace %sel with %add unless we strip away the flags (which will
+  // be done in InstCombine).
+  // TODO: This is unsound, because it only catches some forms of refinement.
+  if (!AllowRefinement && canCreatePoison(cast<Operator>(I)))
+    return nullptr;
+
   // If this is a binary operator, try to simplify it with the replaced op.
   if (auto *B = dyn_cast<BinaryOperator>(I)) {
-    // Consider:
-    //   %cmp = icmp eq i32 %x, 2147483647
-    //   %add = add nsw i32 %x, 1
-    //   %sel = select i1 %cmp, i32 -2147483648, i32 %add
-    //
-    // We can't replace %sel with %add unless we strip away the flags.
-    // TODO: This is an unusual limitation because better analysis results in
-    //       worse simplification. InstCombine can do this fold more generally
-    //       by dropping the flags. Remove this fold to save compile-time?
-    if (canCreatePoison(cast<Operator>(I)))
-      return nullptr;
-
     if (MaxRecurse) {
       if (B->getOperand(0) == Op)
         return SimplifyBinOp(B->getOpcode(), RepOp, B->getOperand(1), Q,
@@ -3865,6 +3864,13 @@ static const Value *SimplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
   return nullptr;
 }
 
+Value *llvm::SimplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
+                                    const SimplifyQuery &Q,
+                                    bool AllowRefinement) {
+  return ::SimplifyWithOpReplaced(V, Op, RepOp, Q, AllowRefinement,
+                                  RecursionLimit);
+}
+
 /// Try to simplify a select instruction when its condition operand is an
 /// integer comparison where one operand of the compare is a constant.
 static Value *simplifySelectBitTest(Value *TrueVal, Value *FalseVal, Value *X,
@@ -3985,14 +3991,18 @@ static Value *simplifySelectWithICmpCond(Value *CondVal, Value *TrueVal,
   // arms of the select. See if substituting this value into the arm and
   // simplifying the result yields the same value as the other arm.
   if (Pred == ICmpInst::ICMP_EQ) {
-    if (SimplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, Q, MaxRecurse) ==
+    if (SimplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, Q,
+                               /* AllowRefinement */ false, MaxRecurse) ==
             TrueVal ||
-        SimplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, Q, MaxRecurse) ==
+        SimplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, Q,
+                               /* AllowRefinement */ false, MaxRecurse) ==
             TrueVal)
       return FalseVal;
-    if (SimplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, Q, MaxRecurse) ==
+    if (SimplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, Q,
+                               /* AllowRefinement */ true, MaxRecurse) ==
             FalseVal ||
-        SimplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, Q, MaxRecurse) ==
+        SimplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, Q,
+                               /* AllowRefinement */ true, MaxRecurse) ==
             FalseVal)
       return FalseVal;
   }

diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index c05c16b4bdb1..378132011aba 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -1149,22 +1149,6 @@ static Instruction *canonicalizeAbsNabs(SelectInst &Sel, ICmpInst &Cmp,
   return &Sel;
 }
 
-static Value *simplifyWithOpReplaced(Value *V, Value *Op, Value *ReplaceOp,
-                                     const SimplifyQuery &Q) {
-  // If this is a binary operator, try to simplify it with the replaced op
-  // because we know Op and ReplaceOp are equivalant.
-  // For example: V = X + 1, Op = X, ReplaceOp = 42
-  // Simplifies as: add(42, 1) --> 43
-  if (auto *BO = dyn_cast<BinaryOperator>(V)) {
-    if (BO->getOperand(0) == Op)
-      return SimplifyBinOp(BO->getOpcode(), ReplaceOp, BO->getOperand(1), Q);
-    if (BO->getOperand(1) == Op)
-      return SimplifyBinOp(BO->getOpcode(), BO->getOperand(0), ReplaceOp, Q);
-  }
-
-  return nullptr;
-}
-
 /// If we have a select with an equality comparison, then we know the value in
 /// one of the arms of the select. See if substituting this value into an arm
 /// and simplifying the result yields the same value as the other arm.
@@ -1191,20 +1175,45 @@ static Value *foldSelectValueEquivalence(SelectInst &Sel, ICmpInst &Cmp,
   if (Cmp.getPredicate() == ICmpInst::ICMP_NE)
     std::swap(TrueVal, FalseVal);
 
+  auto *FalseInst = dyn_cast<Instruction>(FalseVal);
+  if (!FalseInst)
+    return nullptr;
+
+  // InstSimplify already performed this fold if it was possible subject to
+  // current poison-generating flags. Try the transform again with
+  // poison-generating flags temporarily dropped.
+  bool WasNUW = false, WasNSW = false, WasExact = false;
+  if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(FalseVal)) {
+    WasNUW = OBO->hasNoUnsignedWrap();
+    WasNSW = OBO->hasNoSignedWrap();
+    FalseInst->setHasNoUnsignedWrap(false);
+    FalseInst->setHasNoSignedWrap(false);
+  }
+  if (auto *PEO = dyn_cast<PossiblyExactOperator>(FalseVal)) {
+    WasExact = PEO->isExact();
+    FalseInst->setIsExact(false);
+  }
+
   // Try each equivalence substitution possibility.
   // We have an 'EQ' comparison, so the select's false value will propagate.
   // Example:
   // (X == 42) ? 43 : (X + 1) --> (X == 42) ? (X + 1) : (X + 1) --> X + 1
-  // (X == 42) ? (X + 1) : 43 --> (X == 42) ? (42 + 1) : 43 --> 43
   Value *CmpLHS = Cmp.getOperand(0), *CmpRHS = Cmp.getOperand(1);
-  if (simplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, Q) == TrueVal ||
-      simplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, Q) == TrueVal ||
-      simplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, Q) == FalseVal ||
-      simplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, Q) == FalseVal) {
-    if (auto *FalseInst = dyn_cast<Instruction>(FalseVal))
-      FalseInst->dropPoisonGeneratingFlags();
+  if (SimplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, Q,
+                             /* AllowRefinement */ false) == TrueVal ||
+      SimplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, Q,
+                             /* AllowRefinement */ false) == TrueVal) {
     return FalseVal;
   }
+
+  // Restore poison-generating flags if the transform did not apply.
+  if (WasNUW)
+    FalseInst->setHasNoUnsignedWrap();
+  if (WasNSW)
+    FalseInst->setHasNoSignedWrap();
+  if (WasExact)
+    FalseInst->setIsExact();
+
   return nullptr;
 }
 

diff  --git a/llvm/test/Transforms/InstCombine/select.ll b/llvm/test/Transforms/InstCombine/select.ll
index 570f92866d89..d9a4f4bdbd47 100644
--- a/llvm/test/Transforms/InstCombine/select.ll
+++ b/llvm/test/Transforms/InstCombine/select.ll
@@ -2588,12 +2588,13 @@ define void @select_freeze_icmp_multuses(i32 %x, i32 %y) {
   ret void
 }
 
-; FIXME: This is a miscompile!
 define i32 @pr47322_more_poisonous_replacement(i32 %arg) {
 ; CHECK-LABEL: @pr47322_more_poisonous_replacement(
-; CHECK-NEXT:    [[TRAILING:%.*]] = call i32 @llvm.cttz.i32(i32 [[ARG:%.*]], i1 immarg true), [[RNG0:!range !.*]]
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i32 [[ARG:%.*]], 0
+; CHECK-NEXT:    [[TRAILING:%.*]] = call i32 @llvm.cttz.i32(i32 [[ARG]], i1 immarg true), [[RNG0:!range !.*]]
 ; CHECK-NEXT:    [[SHIFTED:%.*]] = lshr i32 [[ARG]], [[TRAILING]]
-; CHECK-NEXT:    ret i32 [[SHIFTED]]
+; CHECK-NEXT:    [[R1_SROA_0_1:%.*]] = select i1 [[CMP]], i32 0, i32 [[SHIFTED]]
+; CHECK-NEXT:    ret i32 [[R1_SROA_0_1]]
 ;
   %cmp = icmp eq i32 %arg, 0
   %trailing = call i32 @llvm.cttz.i32(i32 %arg, i1 immarg true)


        


More information about the llvm-commits mailing list