[llvm] c41b4b6 - [InstCombine] Make flag drop during select equiv fold more generic

Nikita Popov via llvm-commits llvm-commits at lists.llvm.org
Tue Sep 19 05:55:45 PDT 2023


Author: Nikita Popov
Date: 2023-09-19T14:54:25+02:00
New Revision: c41b4b6397675c0d75b316aac3bae380f1ac493c

URL: https://github.com/llvm/llvm-project/commit/c41b4b6397675c0d75b316aac3bae380f1ac493c
DIFF: https://github.com/llvm/llvm-project/commit/c41b4b6397675c0d75b316aac3bae380f1ac493c.diff

LOG: [InstCombine] Make flag drop during select equiv fold more generic

Instead of unsetting flags on the instruction, attempting the
fold, and then resetting the flags if it failed, add support to
simplifyWithOpReplaced() to ignore poison-generating flags/metadata
and collect all instructions where they may need to be dropped.

This allows us to perform the fold a) with poison-generating
metadata, which was previously not handled, and b) with poison-generating
flags/metadata that are not on the root instruction.

Proof for the ctpop case: https://alive2.llvm.org/ce/z/3H3HFs

Fixes https://github.com/llvm/llvm-project/issues/62450.

Added: 
    

Modified: 
    llvm/include/llvm/Analysis/InstructionSimplify.h
    llvm/lib/Analysis/InstructionSimplify.cpp
    llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
    llvm/test/Transforms/InstCombine/bit_ceil.ll
    llvm/test/Transforms/InstCombine/ctpop.ll
    llvm/test/Transforms/InstCombine/ispow2.ll
    llvm/test/Transforms/LoopVectorize/reduction-inloop.ll

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Analysis/InstructionSimplify.h b/llvm/include/llvm/Analysis/InstructionSimplify.h
index df0784664eadc86..401119b3cb85247 100644
--- a/llvm/include/llvm/Analysis/InstructionSimplify.h
+++ b/llvm/include/llvm/Analysis/InstructionSimplify.h
@@ -339,8 +339,14 @@ simplifyInstructionWithOperands(Instruction *I, ArrayRef<Value *> NewOps,
 /// AllowRefinement specifies whether the simplification can be a refinement
 /// (e.g. 0 instead of poison), or whether it needs to be strictly identical.
 /// Op and RepOp can be assumed to not be poison when determining refinement.
-Value *simplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
-                              const SimplifyQuery &Q, bool AllowRefinement);
+///
+/// If DropFlags is passed, then the replacement result is only valid if
+/// poison-generating flags/metadata on those instructions are dropped. This
+/// is only useful in conjunction with AllowRefinement=false.
+Value *
+simplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
+                       const SimplifyQuery &Q, bool AllowRefinement,
+                       SmallVectorImpl<Instruction *> *DropFlags = nullptr);
 
 /// Replace all uses of 'I' with 'SimpleV' and simplify the uses recursively.
 ///

diff  --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
index e8f96e9f681f2d5..d8aa614cae53b10 100644
--- a/llvm/lib/Analysis/InstructionSimplify.cpp
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -4299,6 +4299,7 @@ Value *llvm::simplifyFCmpInst(unsigned Predicate, Value *LHS, Value *RHS,
 static Value *simplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
                                      const SimplifyQuery &Q,
                                      bool AllowRefinement,
+                                     SmallVectorImpl<Instruction *> *DropFlags,
                                      unsigned MaxRecurse) {
   // Trivial replacement.
   if (V == Op)
@@ -4333,7 +4334,7 @@ static Value *simplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
   bool AnyReplaced = false;
   for (Value *InstOp : I->operands()) {
     if (Value *NewInstOp = simplifyWithOpReplaced(
-            InstOp, Op, RepOp, Q, AllowRefinement, MaxRecurse)) {
+            InstOp, Op, RepOp, Q, AllowRefinement, DropFlags, MaxRecurse)) {
       NewOps.push_back(NewInstOp);
       AnyReplaced = InstOp != NewInstOp;
     } else {
@@ -4427,16 +4428,23 @@ static Value *simplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
   // will be done in InstCombine).
   // TODO: This may be unsound, because it only catches some forms of
   // refinement.
-  if (!AllowRefinement && canCreatePoison(cast<Operator>(I)))
-    return nullptr;
+  if (!AllowRefinement) {
+    if (canCreatePoison(cast<Operator>(I), !DropFlags))
+      return nullptr;
+    Constant *Res = ConstantFoldInstOperands(I, ConstOps, Q.DL, Q.TLI);
+    if (DropFlags && Res && I->hasPoisonGeneratingFlagsOrMetadata())
+      DropFlags->push_back(I);
+    return Res;
+  }
 
   return ConstantFoldInstOperands(I, ConstOps, Q.DL, Q.TLI);
 }
 
 Value *llvm::simplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
                                     const SimplifyQuery &Q,
-                                    bool AllowRefinement) {
-  return ::simplifyWithOpReplaced(V, Op, RepOp, Q, AllowRefinement,
+                                    bool AllowRefinement,
+                                    SmallVectorImpl<Instruction *> *DropFlags) {
+  return ::simplifyWithOpReplaced(V, Op, RepOp, Q, AllowRefinement, DropFlags,
                                   RecursionLimit);
 }
 
@@ -4569,11 +4577,11 @@ static Value *simplifySelectWithICmpEq(Value *CmpLHS, Value *CmpRHS,
                                        unsigned MaxRecurse) {
   if (simplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, Q,
                              /* AllowRefinement */ false,
-                             MaxRecurse) == TrueVal)
+                             /* DropFlags */ nullptr, MaxRecurse) == TrueVal)
     return FalseVal;
   if (simplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, Q,
                              /* AllowRefinement */ true,
-                             MaxRecurse) == FalseVal)
+                             /* DropFlags */ nullptr, MaxRecurse) == FalseVal)
     return FalseVal;
 
   return nullptr;

diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index 291e6382f898c5b..05b3eaacc7b4d06 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -1309,45 +1309,28 @@ Instruction *InstCombinerImpl::foldSelectValueEquivalence(SelectInst &Sel,
     return nullptr;
 
   // InstSimplify already performed this fold if it was possible subject to
-  // current poison-generating flags. Try the transform again with
-  // poison-generating flags temporarily dropped.
-  bool WasNUW = false, WasNSW = false, WasExact = false, WasInBounds = false;
-  if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(FalseVal)) {
-    WasNUW = OBO->hasNoUnsignedWrap();
-    WasNSW = OBO->hasNoSignedWrap();
-    FalseInst->setHasNoUnsignedWrap(false);
-    FalseInst->setHasNoSignedWrap(false);
-  }
-  if (auto *PEO = dyn_cast<PossiblyExactOperator>(FalseVal)) {
-    WasExact = PEO->isExact();
-    FalseInst->setIsExact(false);
-  }
-  if (auto *GEP = dyn_cast<GetElementPtrInst>(FalseVal)) {
-    WasInBounds = GEP->isInBounds();
-    GEP->setIsInBounds(false);
-  }
+  // current poison-generating flags. Check whether dropping poison-generating
+  // flags enables the transform.
 
   // Try each equivalence substitution possibility.
   // We have an 'EQ' comparison, so the select's false value will propagate.
   // Example:
   // (X == 42) ? 43 : (X + 1) --> (X == 42) ? (X + 1) : (X + 1) --> X + 1
+  SmallVector<Instruction *> DropFlags;
   if (simplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, SQ,
-                             /* AllowRefinement */ false) == TrueVal ||
+                             /* AllowRefinement */ false,
+                             &DropFlags) == TrueVal ||
       simplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, SQ,
-                             /* AllowRefinement */ false) == TrueVal) {
+                             /* AllowRefinement */ false,
+                             &DropFlags) == TrueVal) {
+    for (Instruction *I : DropFlags) {
+      I->dropPoisonGeneratingFlagsAndMetadata();
+      Worklist.add(I);
+    }
+
     return replaceInstUsesWith(Sel, FalseVal);
   }
 
-  // Restore poison-generating flags if the transform did not apply.
-  if (WasNUW)
-    FalseInst->setHasNoUnsignedWrap();
-  if (WasNSW)
-    FalseInst->setHasNoSignedWrap();
-  if (WasExact)
-    FalseInst->setIsExact();
-  if (WasInBounds)
-    cast<GetElementPtrInst>(FalseInst)->setIsInBounds();
-
   return nullptr;
 }
 

diff  --git a/llvm/test/Transforms/InstCombine/bit_ceil.ll b/llvm/test/Transforms/InstCombine/bit_ceil.ll
index 6f714153a598ada..52e70c78ba54289 100644
--- a/llvm/test/Transforms/InstCombine/bit_ceil.ll
+++ b/llvm/test/Transforms/InstCombine/bit_ceil.ll
@@ -148,10 +148,9 @@ define i32 @bit_ceil_commuted_operands(i32 %x) {
 ; CHECK-LABEL: @bit_ceil_commuted_operands(
 ; CHECK-NEXT:    [[DEC:%.*]] = add i32 [[X:%.*]], -1
 ; CHECK-NEXT:    [[CTLZ:%.*]] = tail call i32 @llvm.ctlz.i32(i32 [[DEC]], i1 false), !range [[RNG0]]
-; CHECK-NEXT:    [[TMP1:%.*]] = sub nsw i32 0, [[CTLZ]]
-; CHECK-NEXT:    [[TMP2:%.*]] = and i32 [[TMP1]], 31
-; CHECK-NEXT:    [[SEL:%.*]] = shl nuw i32 1, [[TMP2]]
-; CHECK-NEXT:    ret i32 [[SEL]]
+; CHECK-NEXT:    [[SUB:%.*]] = sub nuw nsw i32 32, [[CTLZ]]
+; CHECK-NEXT:    [[SHL:%.*]] = shl nuw i32 1, [[SUB]]
+; CHECK-NEXT:    ret i32 [[SHL]]
 ;
   %dec = add i32 %x, -1
   %ctlz = tail call i32 @llvm.ctlz.i32(i32 %dec, i1 false)

diff  --git a/llvm/test/Transforms/InstCombine/ctpop.ll b/llvm/test/Transforms/InstCombine/ctpop.ll
index a3db7715d183453..f3419768bbd0285 100644
--- a/llvm/test/Transforms/InstCombine/ctpop.ll
+++ b/llvm/test/Transforms/InstCombine/ctpop.ll
@@ -479,9 +479,7 @@ define i32 @parity_xor_extra_use2(i32 %arg, i32 %arg1) {
 define i32 @select_ctpop_zero(i32 %x) {
 ; CHECK-LABEL: @select_ctpop_zero(
 ; CHECK-NEXT:    [[CTPOP:%.*]] = call i32 @llvm.ctpop.i32(i32 [[X:%.*]]), !range [[RNG1]]
-; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i32 [[X]], 0
-; CHECK-NEXT:    [[RES:%.*]] = select i1 [[CMP]], i32 0, i32 [[CTPOP]]
-; CHECK-NEXT:    ret i32 [[RES]]
+; CHECK-NEXT:    ret i32 [[CTPOP]]
 ;
   %ctpop = call i32 @llvm.ctpop.i32(i32 %x)
   %cmp = icmp eq i32 %x, 0

diff  --git a/llvm/test/Transforms/InstCombine/ispow2.ll b/llvm/test/Transforms/InstCombine/ispow2.ll
index a67a3fde1e661d8..1fa0d4c4af054b3 100644
--- a/llvm/test/Transforms/InstCombine/ispow2.ll
+++ b/llvm/test/Transforms/InstCombine/ispow2.ll
@@ -345,9 +345,7 @@ define i1 @is_pow2_ctpop_wrong_pred1_logical(i32 %x) {
 ; CHECK-LABEL: @is_pow2_ctpop_wrong_pred1_logical(
 ; CHECK-NEXT:    [[T0:%.*]] = tail call i32 @llvm.ctpop.i32(i32 [[X:%.*]]), !range [[RNG0]]
 ; CHECK-NEXT:    [[CMP:%.*]] = icmp ugt i32 [[T0]], 2
-; CHECK-NEXT:    [[NOTZERO:%.*]] = icmp ne i32 [[X]], 0
-; CHECK-NEXT:    [[R:%.*]] = select i1 [[NOTZERO]], i1 [[CMP]], i1 false
-; CHECK-NEXT:    ret i1 [[R]]
+; CHECK-NEXT:    ret i1 [[CMP]]
 ;
   %t0 = tail call i32 @llvm.ctpop.i32(i32 %x)
   %cmp = icmp ugt i32 %t0, 2

diff  --git a/llvm/test/Transforms/LoopVectorize/reduction-inloop.ll b/llvm/test/Transforms/LoopVectorize/reduction-inloop.ll
index aca02f37abe2ef3..18b05c05d9b9d21 100644
--- a/llvm/test/Transforms/LoopVectorize/reduction-inloop.ll
+++ b/llvm/test/Transforms/LoopVectorize/reduction-inloop.ll
@@ -1122,12 +1122,10 @@ define i32 @predicated_not_dominates_reduction(ptr nocapture noundef readonly %h
 ; CHECK-NEXT:    [[TMP0:%.*]] = sext i32 [[INDEX]] to i64
 ; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr inbounds i8, ptr [[H:%.*]], i64 [[TMP0]]
 ; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <4 x i8>, ptr [[TMP1]], align 1
-; CHECK-NEXT:    [[DOTNOT:%.*]] = icmp eq <4 x i8> [[WIDE_LOAD]], zeroinitializer
 ; CHECK-NEXT:    [[TMP2:%.*]] = udiv <4 x i8> [[WIDE_LOAD]], <i8 31, i8 31, i8 31, i8 31>
 ; CHECK-NEXT:    [[TMP3:%.*]] = shl nuw nsw <4 x i8> [[TMP2]], <i8 3, i8 3, i8 3, i8 3>
 ; CHECK-NEXT:    [[TMP4:%.*]] = udiv <4 x i8> [[TMP3]], <i8 31, i8 31, i8 31, i8 31>
-; CHECK-NEXT:    [[NARROW:%.*]] = select <4 x i1> [[DOTNOT]], <4 x i8> zeroinitializer, <4 x i8> [[TMP4]]
-; CHECK-NEXT:    [[TMP5:%.*]] = zext <4 x i8> [[NARROW]] to <4 x i32>
+; CHECK-NEXT:    [[TMP5:%.*]] = zext <4 x i8> [[TMP4]] to <4 x i32>
 ; CHECK-NEXT:    [[TMP6:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP5]])
 ; CHECK-NEXT:    [[TMP7]] = add i32 [[TMP6]], [[VEC_PHI]]
 ; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i32 [[INDEX]], 4


        


More information about the llvm-commits mailing list