[llvm] 11484cb - [InstCombine] Pass SimplifyQuery to SimplifyDemandedBits()

Nikita Popov via llvm-commits llvm-commits at lists.llvm.org
Mon Jul 1 03:41:30 PDT 2024


Author: Nikita Popov
Date: 2024-07-01T12:41:21+02:00
New Revision: 11484cb817bcc2a6e2ef9572be982a1a5a4964ec

URL: https://github.com/llvm/llvm-project/commit/11484cb817bcc2a6e2ef9572be982a1a5a4964ec
DIFF: https://github.com/llvm/llvm-project/commit/11484cb817bcc2a6e2ef9572be982a1a5a4964ec.diff

LOG: [InstCombine] Pass SimplifyQuery to SimplifyDemandedBits()

This will enable calling SimplifyDemandedBits() with a SimplifyQuery
that has CondContext set in the future.

Additionally, this marginally strengthens the analysis by
retaining the original context instruction for one-use chains.

Added: 
    

Modified: 
    llvm/include/llvm/Transforms/InstCombine/InstCombiner.h
    llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
    llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
    llvm/lib/Transforms/InstCombine/InstCombineInternal.h
    llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
    llvm/test/Transforms/InstCombine/known-bits.ll

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h b/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h
index 855d1aeddfaee..ebcbd5d9e8880 100644
--- a/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h
+++ b/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h
@@ -504,7 +504,14 @@ class LLVM_LIBRARY_VISIBILITY InstCombiner {
 
   virtual bool SimplifyDemandedBits(Instruction *I, unsigned OpNo,
                                     const APInt &DemandedMask, KnownBits &Known,
-                                    unsigned Depth = 0) = 0;
+                                    unsigned Depth, const SimplifyQuery &Q) = 0;
+
+  bool SimplifyDemandedBits(Instruction *I, unsigned OpNo,
+                            const APInt &DemandedMask, KnownBits &Known) {
+    return SimplifyDemandedBits(I, OpNo, DemandedMask, Known,
+                                /*Depth=*/0, SQ.getWithInstruction(I));
+  }
+
   virtual Value *
   SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, APInt &UndefElts,
                              unsigned Depth = 0,

diff  --git a/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp b/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
index 4dc4d28724ef9..b5ca045058cbc 100644
--- a/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
+++ b/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
@@ -187,7 +187,7 @@ ARMTTIImpl::instCombineIntrinsic(InstCombiner &IC, IntrinsicInst &II) const {
     }
     KnownBits ScalarKnown(32);
     if (IC.SimplifyDemandedBits(&II, 0, APInt::getLowBitsSet(32, 16),
-                                ScalarKnown, 0)) {
+                                ScalarKnown)) {
       return &II;
     }
     break;

diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 93157a64bfe3f..abadf54a96767 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -6451,13 +6451,13 @@ Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) {
     // Don't use dominating conditions when folding icmp using known bits. This
     // may convert signed into unsigned predicates in ways that other passes
     // (especially IndVarSimplify) may not be able to reliably undo.
-    SQ.DC = nullptr;
-    auto _ = make_scope_exit([&]() { SQ.DC = &DC; });
+    SimplifyQuery Q = SQ.getWithoutDomCondCache().getWithInstruction(&I);
     if (SimplifyDemandedBits(&I, 0, getDemandedBitsLHSMask(I, BitWidth),
-                             Op0Known, 0))
+                             Op0Known, /*Depth=*/0, Q))
       return &I;
 
-    if (SimplifyDemandedBits(&I, 1, APInt::getAllOnes(BitWidth), Op1Known, 0))
+    if (SimplifyDemandedBits(&I, 1, APInt::getAllOnes(BitWidth), Op1Known,
+                             /*Depth=*/0, Q))
       return &I;
   }
 

diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index 984f02bcccad7..318c455fd7ef1 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -548,18 +548,19 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
   /// Attempts to replace V with a simpler value based on the demanded
   /// bits.
   Value *SimplifyDemandedUseBits(Value *V, APInt DemandedMask, KnownBits &Known,
-                                 unsigned Depth, Instruction *CxtI);
+                                 unsigned Depth, const SimplifyQuery &Q);
+  using InstCombiner::SimplifyDemandedBits;
   bool SimplifyDemandedBits(Instruction *I, unsigned Op,
                             const APInt &DemandedMask, KnownBits &Known,
-                            unsigned Depth = 0) override;
+                            unsigned Depth, const SimplifyQuery &Q) override;
 
   /// Helper routine of SimplifyDemandedUseBits. It computes KnownZero/KnownOne
   /// bits. It also tries to handle simplifications that can be done based on
   /// DemandedMask, but without modifying the Instruction.
   Value *SimplifyMultipleUseDemandedBits(Instruction *I,
                                          const APInt &DemandedMask,
-                                         KnownBits &Known,
-                                         unsigned Depth, Instruction *CxtI);
+                                         KnownBits &Known, unsigned Depth,
+                                         const SimplifyQuery &Q);
 
   /// Helper routine of SimplifyDemandedUseBits. It tries to simplify demanded
   /// bit for "r1 = shr x, c1; r2 = shl r1, c2" instruction sequence.

diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
index bfe5db3547cd5..02003150d85da 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
@@ -69,7 +69,7 @@ bool InstCombinerImpl::SimplifyDemandedInstructionBits(Instruction &Inst,
                                                        KnownBits &Known) {
   APInt DemandedMask(APInt::getAllOnes(Known.getBitWidth()));
   Value *V = SimplifyDemandedUseBits(&Inst, DemandedMask, Known,
-                                     0, &Inst);
+                                     0, SQ.getWithInstruction(&Inst));
   if (!V) return false;
   if (V == &Inst) return true;
   replaceInstUsesWith(Inst, V);
@@ -88,10 +88,11 @@ bool InstCombinerImpl::SimplifyDemandedInstructionBits(Instruction &Inst) {
 /// change and false otherwise.
 bool InstCombinerImpl::SimplifyDemandedBits(Instruction *I, unsigned OpNo,
                                             const APInt &DemandedMask,
-                                            KnownBits &Known, unsigned Depth) {
+                                            KnownBits &Known, unsigned Depth,
+                                            const SimplifyQuery &Q) {
   Use &U = I->getOperandUse(OpNo);
   Value *NewVal = SimplifyDemandedUseBits(U.get(), DemandedMask, Known,
-                                          Depth, I);
+                                          Depth, Q);
   if (!NewVal) return false;
   if (Instruction* OpInst = dyn_cast<Instruction>(U))
     salvageDebugInfo(*OpInst);
@@ -126,7 +127,7 @@ bool InstCombinerImpl::SimplifyDemandedBits(Instruction *I, unsigned OpNo,
 Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
                                                  KnownBits &Known,
                                                  unsigned Depth,
-                                                 Instruction *CxtI) {
+                                                 const SimplifyQuery &Q) {
   assert(V != nullptr && "Null pointer of Value???");
   assert(Depth <= MaxAnalysisRecursionDepth && "Limit Search Depth");
   uint32_t BitWidth = DemandedMask.getBitWidth();
@@ -137,7 +138,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
       "Value *V, DemandedMask and Known must have same BitWidth");
 
   if (isa<Constant>(V)) {
-    computeKnownBits(V, Known, Depth, CxtI);
+    llvm::computeKnownBits(V, Known, Depth, Q);
     return nullptr;
   }
 
@@ -150,7 +151,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
 
   Instruction *I = dyn_cast<Instruction>(V);
   if (!I) {
-    computeKnownBits(V, Known, Depth, CxtI);
+    llvm::computeKnownBits(V, Known, Depth, Q);
     return nullptr;        // Only analyze instructions.
   }
 
@@ -158,7 +159,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
   // we can't do any simplifications of the operands, because DemandedMask
   // only reflects the bits demanded by *one* of the users.
   if (Depth != 0 && !I->hasOneUse())
-    return SimplifyMultipleUseDemandedBits(I, DemandedMask, Known, Depth, CxtI);
+    return SimplifyMultipleUseDemandedBits(I, DemandedMask, Known, Depth, Q);
 
   KnownBits LHSKnown(BitWidth), RHSKnown(BitWidth);
   // If this is the root being simplified, allow it to have multiple uses,
@@ -190,9 +191,9 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
     // significant bit and all those below it.
     DemandedFromOps = APInt::getLowBitsSet(BitWidth, BitWidth - NLZ);
     if (ShrinkDemandedConstant(I, 0, DemandedFromOps) ||
-        SimplifyDemandedBits(I, 0, DemandedFromOps, LHSKnown, Depth + 1) ||
+        SimplifyDemandedBits(I, 0, DemandedFromOps, LHSKnown, Depth + 1, Q) ||
         ShrinkDemandedConstant(I, 1, DemandedFromOps) ||
-        SimplifyDemandedBits(I, 1, DemandedFromOps, RHSKnown, Depth + 1)) {
+        SimplifyDemandedBits(I, 1, DemandedFromOps, RHSKnown, Depth + 1, Q)) {
       disableWrapFlagsBasedOnUnusedHighBits(I, NLZ);
       return true;
     }
@@ -201,17 +202,17 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
 
   switch (I->getOpcode()) {
   default:
-    computeKnownBits(I, Known, Depth, CxtI);
+    llvm::computeKnownBits(I, Known, Depth, Q);
     break;
   case Instruction::And: {
     // If either the LHS or the RHS are Zero, the result is zero.
-    if (SimplifyDemandedBits(I, 1, DemandedMask, RHSKnown, Depth + 1) ||
+    if (SimplifyDemandedBits(I, 1, DemandedMask, RHSKnown, Depth + 1, Q) ||
         SimplifyDemandedBits(I, 0, DemandedMask & ~RHSKnown.Zero, LHSKnown,
-                             Depth + 1))
+                             Depth + 1, Q))
       return I;
 
     Known = analyzeKnownBitsFromAndXorOr(cast<Operator>(I), LHSKnown, RHSKnown,
-                                         Depth, SQ.getWithInstruction(CxtI));
+                                         Depth, Q);
 
     // If the client is only demanding bits that we know, return the known
     // constant.
@@ -233,16 +234,16 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
   }
   case Instruction::Or: {
     // If either the LHS or the RHS are One, the result is One.
-    if (SimplifyDemandedBits(I, 1, DemandedMask, RHSKnown, Depth + 1) ||
+    if (SimplifyDemandedBits(I, 1, DemandedMask, RHSKnown, Depth + 1, Q) ||
         SimplifyDemandedBits(I, 0, DemandedMask & ~RHSKnown.One, LHSKnown,
-                             Depth + 1)) {
+                             Depth + 1, Q)) {
       // Disjoint flag may not longer hold.
       I->dropPoisonGeneratingFlags();
       return I;
     }
 
     Known = analyzeKnownBitsFromAndXorOr(cast<Operator>(I), LHSKnown, RHSKnown,
-                                         Depth, SQ.getWithInstruction(CxtI));
+                                         Depth, Q);
 
     // If the client is only demanding bits that we know, return the known
     // constant.
@@ -264,7 +265,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
     if (!cast<PossiblyDisjointInst>(I)->isDisjoint()) {
       WithCache<const Value *> LHSCache(I->getOperand(0), LHSKnown),
           RHSCache(I->getOperand(1), RHSKnown);
-      if (haveNoCommonBitsSet(LHSCache, RHSCache, SQ.getWithInstruction(I))) {
+      if (haveNoCommonBitsSet(LHSCache, RHSCache, Q)) {
         cast<PossiblyDisjointInst>(I)->setIsDisjoint(true);
         return I;
       }
@@ -273,8 +274,8 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
     break;
   }
   case Instruction::Xor: {
-    if (SimplifyDemandedBits(I, 1, DemandedMask, RHSKnown, Depth + 1) ||
-        SimplifyDemandedBits(I, 0, DemandedMask, LHSKnown, Depth + 1))
+    if (SimplifyDemandedBits(I, 1, DemandedMask, RHSKnown, Depth + 1, Q) ||
+        SimplifyDemandedBits(I, 0, DemandedMask, LHSKnown, Depth + 1, Q))
       return I;
     Value *LHS, *RHS;
     if (DemandedMask == 1 &&
@@ -288,7 +289,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
     }
 
     Known = analyzeKnownBitsFromAndXorOr(cast<Operator>(I), LHSKnown, RHSKnown,
-                                         Depth, SQ.getWithInstruction(CxtI));
+                                         Depth, Q);
 
     // If the client is only demanding bits that we know, return the known
     // constant.
@@ -365,8 +366,8 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
     break;
   }
   case Instruction::Select: {
-    if (SimplifyDemandedBits(I, 2, DemandedMask, RHSKnown, Depth + 1) ||
-        SimplifyDemandedBits(I, 1, DemandedMask, LHSKnown, Depth + 1))
+    if (SimplifyDemandedBits(I, 2, DemandedMask, RHSKnown, Depth + 1, Q) ||
+        SimplifyDemandedBits(I, 1, DemandedMask, LHSKnown, Depth + 1, Q))
       return I;
 
     // If the operands are constants, see if we can simplify them.
@@ -434,7 +435,8 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
 
     APInt InputDemandedMask = DemandedMask.zextOrTrunc(SrcBitWidth);
     KnownBits InputKnown(SrcBitWidth);
-    if (SimplifyDemandedBits(I, 0, InputDemandedMask, InputKnown, Depth + 1)) {
+    if (SimplifyDemandedBits(I, 0, InputDemandedMask, InputKnown, Depth + 1,
+                             Q)) {
       // For zext nneg, we may have dropped the instruction which made the
       // input non-negative.
       I->dropPoisonGeneratingFlags();
@@ -460,7 +462,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
       InputDemandedBits.setBit(SrcBitWidth-1);
 
     KnownBits InputKnown(SrcBitWidth);
-    if (SimplifyDemandedBits(I, 0, InputDemandedBits, InputKnown, Depth + 1))
+    if (SimplifyDemandedBits(I, 0, InputDemandedBits, InputKnown, Depth + 1, Q))
       return I;
 
     // If the input sign bit is known zero, or if the NewBits are not demanded
@@ -521,7 +523,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
     unsigned NLZ = DemandedMask.countl_zero();
     APInt DemandedFromOps = APInt::getLowBitsSet(BitWidth, BitWidth - NLZ);
     if (ShrinkDemandedConstant(I, 1, DemandedFromOps) ||
-        SimplifyDemandedBits(I, 1, DemandedFromOps, RHSKnown, Depth + 1))
+        SimplifyDemandedBits(I, 1, DemandedFromOps, RHSKnown, Depth + 1, Q))
       return disableWrapFlagsBasedOnUnusedHighBits(I, NLZ);
 
     // If low order bits are not demanded and known to be zero in one operand,
@@ -531,7 +533,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
     APInt DemandedFromLHS = DemandedFromOps;
     DemandedFromLHS.clearLowBits(NTZ);
     if (ShrinkDemandedConstant(I, 0, DemandedFromLHS) ||
-        SimplifyDemandedBits(I, 0, DemandedFromLHS, LHSKnown, Depth + 1))
+        SimplifyDemandedBits(I, 0, DemandedFromLHS, LHSKnown, Depth + 1, Q))
       return disableWrapFlagsBasedOnUnusedHighBits(I, NLZ);
 
     // If we are known to be adding zeros to every bit below
@@ -564,7 +566,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
     unsigned NLZ = DemandedMask.countl_zero();
     APInt DemandedFromOps = APInt::getLowBitsSet(BitWidth, BitWidth - NLZ);
     if (ShrinkDemandedConstant(I, 1, DemandedFromOps) ||
-        SimplifyDemandedBits(I, 1, DemandedFromOps, RHSKnown, Depth + 1))
+        SimplifyDemandedBits(I, 1, DemandedFromOps, RHSKnown, Depth + 1, Q))
       return disableWrapFlagsBasedOnUnusedHighBits(I, NLZ);
 
     // If low order bits are not demanded and are known to be zero in RHS,
@@ -574,7 +576,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
     APInt DemandedFromLHS = DemandedFromOps;
     DemandedFromLHS.clearLowBits(NTZ);
     if (ShrinkDemandedConstant(I, 0, DemandedFromLHS) ||
-        SimplifyDemandedBits(I, 0, DemandedFromLHS, LHSKnown, Depth + 1))
+        SimplifyDemandedBits(I, 0, DemandedFromLHS, LHSKnown, Depth + 1, Q))
       return disableWrapFlagsBasedOnUnusedHighBits(I, NLZ);
 
     // If we are known to be subtracting zeros from every bit below
@@ -618,7 +620,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
       return InsertNewInstWith(And1, I->getIterator());
     }
 
-    computeKnownBits(I, Known, Depth, CxtI);
+    llvm::computeKnownBits(I, Known, Depth, Q);
     break;
   }
   case Instruction::Shl: {
@@ -639,7 +641,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
             auto [IID, FShiftArgs] = *Opt;
             if ((IID == Intrinsic::fshl || IID == Intrinsic::fshr) &&
                 FShiftArgs[0] == FShiftArgs[1]) {
-              computeKnownBits(I, Known, Depth, CxtI);
+              llvm::computeKnownBits(I, Known, Depth, Q);
               break;
             }
           }
@@ -653,7 +655,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
         if (I->hasNoSignedWrap()) {
           unsigned NumHiDemandedBits = BitWidth - DemandedMask.countr_zero();
           unsigned SignBits =
-              ComputeNumSignBits(I->getOperand(0), Depth + 1, CxtI);
+              ComputeNumSignBits(I->getOperand(0), Depth + 1, Q.CxtI);
           if (SignBits > ShiftAmt && SignBits - ShiftAmt >= NumHiDemandedBits)
             return I->getOperand(0);
         }
@@ -685,7 +687,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
       else if (IOp->hasNoUnsignedWrap())
         DemandedMaskIn.setHighBits(ShiftAmt);
 
-      if (SimplifyDemandedBits(I, 0, DemandedMaskIn, Known, Depth + 1))
+      if (SimplifyDemandedBits(I, 0, DemandedMaskIn, Known, Depth + 1, Q))
         return I;
 
       Known = KnownBits::shl(Known,
@@ -698,13 +700,13 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
       // demanding those bits from the pre-shifted operand either.
       if (unsigned CTLZ = DemandedMask.countl_zero()) {
         APInt DemandedFromOp(APInt::getLowBitsSet(BitWidth, BitWidth - CTLZ));
-        if (SimplifyDemandedBits(I, 0, DemandedFromOp, Known, Depth + 1)) {
+        if (SimplifyDemandedBits(I, 0, DemandedFromOp, Known, Depth + 1, Q)) {
           // We can't guarantee that nsw/nuw hold after simplifying the operand.
           I->dropPoisonGeneratingFlags();
           return I;
         }
       }
-      computeKnownBits(I, Known, Depth, CxtI);
+      llvm::computeKnownBits(I, Known, Depth, Q);
     }
     break;
   }
@@ -721,7 +723,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
             auto [IID, FShiftArgs] = *Opt;
             if ((IID == Intrinsic::fshl || IID == Intrinsic::fshr) &&
                 FShiftArgs[0] == FShiftArgs[1]) {
-              computeKnownBits(I, Known, Depth, CxtI);
+              llvm::computeKnownBits(I, Known, Depth, Q);
               break;
             }
           }
@@ -735,7 +737,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
         // need to shift.
         unsigned NumHiDemandedBits = BitWidth - DemandedMask.countr_zero();
         unsigned SignBits =
-            ComputeNumSignBits(I->getOperand(0), Depth + 1, CxtI);
+            ComputeNumSignBits(I->getOperand(0), Depth + 1, Q.CxtI);
         if (SignBits >= NumHiDemandedBits)
           return I->getOperand(0);
 
@@ -759,7 +761,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
 
       // Unsigned shift right.
       APInt DemandedMaskIn(DemandedMask.shl(ShiftAmt));
-      if (SimplifyDemandedBits(I, 0, DemandedMaskIn, Known, Depth + 1)) {
+      if (SimplifyDemandedBits(I, 0, DemandedMaskIn, Known, Depth + 1, Q)) {
         // exact flag may not longer hold.
         I->dropPoisonGeneratingFlags();
         return I;
@@ -769,12 +771,12 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
       if (ShiftAmt)
         Known.Zero.setHighBits(ShiftAmt);  // high bits known zero.
     } else {
-      computeKnownBits(I, Known, Depth, CxtI);
+      llvm::computeKnownBits(I, Known, Depth, Q);
     }
     break;
   }
   case Instruction::AShr: {
-    unsigned SignBits = ComputeNumSignBits(I->getOperand(0), Depth + 1, CxtI);
+    unsigned SignBits = ComputeNumSignBits(I->getOperand(0), Depth + 1, Q.CxtI);
 
     // If we only want bits that already match the signbit then we don't need
     // to shift.
@@ -804,7 +806,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
       if (DemandedMask.countl_zero() <= ShiftAmt)
         DemandedMaskIn.setSignBit();
 
-      if (SimplifyDemandedBits(I, 0, DemandedMaskIn, Known, Depth + 1)) {
+      if (SimplifyDemandedBits(I, 0, DemandedMaskIn, Known, Depth + 1, Q)) {
         // exact flag may not longer hold.
         I->dropPoisonGeneratingFlags();
         return I;
@@ -833,7 +835,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
         Known.Zero &= ~HighBits;
       }
     } else {
-      computeKnownBits(I, Known, Depth, CxtI);
+      llvm::computeKnownBits(I, Known, Depth, Q);
     }
     break;
   }
@@ -845,7 +847,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
       unsigned RHSTrailingZeros = SA->countr_zero();
       APInt DemandedMaskIn =
           APInt::getHighBitsSet(BitWidth, BitWidth - RHSTrailingZeros);
-      if (SimplifyDemandedBits(I, 0, DemandedMaskIn, LHSKnown, Depth + 1)) {
+      if (SimplifyDemandedBits(I, 0, DemandedMaskIn, LHSKnown, Depth + 1, Q)) {
         // We can't guarantee that "exact" is still true after changing the
         // the dividend.
         I->dropPoisonGeneratingFlags();
@@ -855,7 +857,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
       Known = KnownBits::udiv(LHSKnown, KnownBits::makeConstant(*SA),
                               cast<BinaryOperator>(I)->isExact());
     } else {
-      computeKnownBits(I, Known, Depth, CxtI);
+      llvm::computeKnownBits(I, Known, Depth, Q);
     }
     break;
   }
@@ -873,7 +875,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
 
         APInt LowBits = RA - 1;
         APInt Mask2 = LowBits | APInt::getSignMask(BitWidth);
-        if (SimplifyDemandedBits(I, 0, Mask2, LHSKnown, Depth + 1))
+        if (SimplifyDemandedBits(I, 0, Mask2, LHSKnown, Depth + 1, Q))
           return I;
 
         // The low bits of LHS are unchanged by the srem.
@@ -894,7 +896,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
       }
     }
 
-    computeKnownBits(I, Known, Depth, CxtI);
+    llvm::computeKnownBits(I, Known, Depth, Q);
     break;
   }
   case Instruction::Call: {
@@ -950,10 +952,10 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
         unsigned MaskWidth = I->getOperand(1)->getType()->getScalarSizeInBits();
         RHSKnown = KnownBits(MaskWidth);
         // If either the LHS or the RHS are Zero, the result is zero.
-        if (SimplifyDemandedBits(I, 0, DemandedMask, LHSKnown, Depth + 1) ||
+        if (SimplifyDemandedBits(I, 0, DemandedMask, LHSKnown, Depth + 1, Q) ||
             SimplifyDemandedBits(
                 I, 1, (DemandedMask & ~LHSKnown.Zero).zextOrTrunc(MaskWidth),
-                RHSKnown, Depth + 1))
+                RHSKnown, Depth + 1, Q))
           return I;
 
         // TODO: Should be 1-extend
@@ -1040,8 +1042,9 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
         APInt DemandedMaskRHS(DemandedMask.shl(BitWidth - ShiftAmt));
         if (I->getOperand(0) != I->getOperand(1)) {
           if (SimplifyDemandedBits(I, 0, DemandedMaskLHS, LHSKnown,
-                                   Depth + 1) ||
-              SimplifyDemandedBits(I, 1, DemandedMaskRHS, RHSKnown, Depth + 1))
+                                   Depth + 1, Q) ||
+              SimplifyDemandedBits(I, 1, DemandedMaskRHS, RHSKnown, Depth + 1,
+                                   Q))
             return I;
         } else { // fshl is a rotate
           // Avoid converting rotate into funnel shift.
@@ -1103,7 +1106,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
     }
 
     if (!KnownBitsComputed)
-      computeKnownBits(V, Known, Depth, CxtI);
+      llvm::computeKnownBits(V, Known, Depth, Q);
     break;
   }
   }
@@ -1121,7 +1124,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
     return Constant::getIntegerValue(VTy, Known.One);
 
   if (VerifyKnownBits) {
-    KnownBits ReferenceKnown = computeKnownBits(V, Depth, CxtI);
+    KnownBits ReferenceKnown = llvm::computeKnownBits(V, Depth, Q);
     if (Known != ReferenceKnown) {
       errs() << "Mismatched known bits for " << *V << " in "
              << I->getFunction()->getName() << "\n";
@@ -1139,7 +1142,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
 /// DemandedMask, but without modifying the Instruction.
 Value *InstCombinerImpl::SimplifyMultipleUseDemandedBits(
     Instruction *I, const APInt &DemandedMask, KnownBits &Known, unsigned Depth,
-    Instruction *CxtI) {
+    const SimplifyQuery &Q) {
   unsigned BitWidth = DemandedMask.getBitWidth();
   Type *ITy = I->getType();
 
@@ -1152,11 +1155,11 @@ Value *InstCombinerImpl::SimplifyMultipleUseDemandedBits(
   // this instruction has a simpler value in that context.
   switch (I->getOpcode()) {
   case Instruction::And: {
-    computeKnownBits(I->getOperand(1), RHSKnown, Depth + 1, CxtI);
-    computeKnownBits(I->getOperand(0), LHSKnown, Depth + 1, CxtI);
+    llvm::computeKnownBits(I->getOperand(1), RHSKnown, Depth + 1, Q);
+    llvm::computeKnownBits(I->getOperand(0), LHSKnown, Depth + 1, Q);
     Known = analyzeKnownBitsFromAndXorOr(cast<Operator>(I), LHSKnown, RHSKnown,
-                                         Depth, SQ.getWithInstruction(CxtI));
-    computeKnownBitsFromContext(I, Known, Depth, SQ.getWithInstruction(CxtI));
+                                         Depth, Q);
+    computeKnownBitsFromContext(I, Known, Depth, Q);
 
     // If the client is only demanding bits that we know, return the known
     // constant.
@@ -1173,11 +1176,11 @@ Value *InstCombinerImpl::SimplifyMultipleUseDemandedBits(
     break;
   }
   case Instruction::Or: {
-    computeKnownBits(I->getOperand(1), RHSKnown, Depth + 1, CxtI);
-    computeKnownBits(I->getOperand(0), LHSKnown, Depth + 1, CxtI);
+    llvm::computeKnownBits(I->getOperand(1), RHSKnown, Depth + 1, Q);
+    llvm::computeKnownBits(I->getOperand(0), LHSKnown, Depth + 1, Q);
     Known = analyzeKnownBitsFromAndXorOr(cast<Operator>(I), LHSKnown, RHSKnown,
-                                         Depth, SQ.getWithInstruction(CxtI));
-    computeKnownBitsFromContext(I, Known, Depth, SQ.getWithInstruction(CxtI));
+                                         Depth, Q);
+    computeKnownBitsFromContext(I, Known, Depth, Q);
 
     // If the client is only demanding bits that we know, return the known
     // constant.
@@ -1196,11 +1199,11 @@ Value *InstCombinerImpl::SimplifyMultipleUseDemandedBits(
     break;
   }
   case Instruction::Xor: {
-    computeKnownBits(I->getOperand(1), RHSKnown, Depth + 1, CxtI);
-    computeKnownBits(I->getOperand(0), LHSKnown, Depth + 1, CxtI);
+    llvm::computeKnownBits(I->getOperand(1), RHSKnown, Depth + 1, Q);
+    llvm::computeKnownBits(I->getOperand(0), LHSKnown, Depth + 1, Q);
     Known = analyzeKnownBitsFromAndXorOr(cast<Operator>(I), LHSKnown, RHSKnown,
-                                         Depth, SQ.getWithInstruction(CxtI));
-    computeKnownBitsFromContext(I, Known, Depth, SQ.getWithInstruction(CxtI));
+                                         Depth, Q);
+    computeKnownBitsFromContext(I, Known, Depth, Q);
 
     // If the client is only demanding bits that we know, return the known
     // constant.
@@ -1223,11 +1226,11 @@ Value *InstCombinerImpl::SimplifyMultipleUseDemandedBits(
 
     // If an operand adds zeros to every bit below the highest demanded bit,
     // that operand doesn't change the result. Return the other side.
-    computeKnownBits(I->getOperand(1), RHSKnown, Depth + 1, CxtI);
+    llvm::computeKnownBits(I->getOperand(1), RHSKnown, Depth + 1, Q);
     if (DemandedFromOps.isSubsetOf(RHSKnown.Zero))
       return I->getOperand(0);
 
-    computeKnownBits(I->getOperand(0), LHSKnown, Depth + 1, CxtI);
+    llvm::computeKnownBits(I->getOperand(0), LHSKnown, Depth + 1, Q);
     if (DemandedFromOps.isSubsetOf(LHSKnown.Zero))
       return I->getOperand(1);
 
@@ -1235,7 +1238,7 @@ Value *InstCombinerImpl::SimplifyMultipleUseDemandedBits(
     bool NUW = cast<OverflowingBinaryOperator>(I)->hasNoUnsignedWrap();
     Known =
         KnownBits::computeForAddSub(/*Add=*/true, NSW, NUW, LHSKnown, RHSKnown);
-    computeKnownBitsFromContext(I, Known, Depth, SQ.getWithInstruction(CxtI));
+    computeKnownBitsFromContext(I, Known, Depth, Q);
     break;
   }
   case Instruction::Sub: {
@@ -1244,21 +1247,21 @@ Value *InstCombinerImpl::SimplifyMultipleUseDemandedBits(
 
     // If an operand subtracts zeros from every bit below the highest demanded
     // bit, that operand doesn't change the result. Return the other side.
-    computeKnownBits(I->getOperand(1), RHSKnown, Depth + 1, CxtI);
+    llvm::computeKnownBits(I->getOperand(1), RHSKnown, Depth + 1, Q);
     if (DemandedFromOps.isSubsetOf(RHSKnown.Zero))
       return I->getOperand(0);
 
     bool NSW = cast<OverflowingBinaryOperator>(I)->hasNoSignedWrap();
     bool NUW = cast<OverflowingBinaryOperator>(I)->hasNoUnsignedWrap();
-    computeKnownBits(I->getOperand(0), LHSKnown, Depth + 1, CxtI);
+    llvm::computeKnownBits(I->getOperand(0), LHSKnown, Depth + 1, Q);
     Known = KnownBits::computeForAddSub(/*Add=*/false, NSW, NUW, LHSKnown,
                                         RHSKnown);
-    computeKnownBitsFromContext(I, Known, Depth, SQ.getWithInstruction(CxtI));
+    computeKnownBitsFromContext(I, Known, Depth, Q);
     break;
   }
   case Instruction::AShr: {
     // Compute the Known bits to simplify things downstream.
-    computeKnownBits(I, Known, Depth, CxtI);
+    llvm::computeKnownBits(I, Known, Depth, Q);
 
     // If this user is only demanding bits that we know, return the known
     // constant.
@@ -1285,7 +1288,7 @@ Value *InstCombinerImpl::SimplifyMultipleUseDemandedBits(
   }
   default:
     // Compute the Known bits to simplify things downstream.
-    computeKnownBits(I, Known, Depth, CxtI);
+    llvm::computeKnownBits(I, Known, Depth, Q);
 
     // If this user is only demanding bits that we know, return the known
     // constant.

diff  --git a/llvm/test/Transforms/InstCombine/known-bits.ll b/llvm/test/Transforms/InstCombine/known-bits.ll
index dafc37db0086e..d7a8386552067 100644
--- a/llvm/test/Transforms/InstCombine/known-bits.ll
+++ b/llvm/test/Transforms/InstCombine/known-bits.ll
@@ -2010,13 +2010,11 @@ if.else:
 
 define i8 @simplifydemanded_context(i8 %x, i8 %y) {
 ; CHECK-LABEL: @simplifydemanded_context(
-; CHECK-NEXT:    [[AND1:%.*]] = and i8 [[X:%.*]], 1
 ; CHECK-NEXT:    call void @dummy()
-; CHECK-NEXT:    [[X_LOBITS:%.*]] = and i8 [[X]], 3
+; CHECK-NEXT:    [[X_LOBITS:%.*]] = and i8 [[X:%.*]], 3
 ; CHECK-NEXT:    [[PRECOND:%.*]] = icmp eq i8 [[X_LOBITS]], 0
 ; CHECK-NEXT:    call void @llvm.assume(i1 [[PRECOND]])
-; CHECK-NEXT:    [[AND2:%.*]] = and i8 [[AND1]], [[Y:%.*]]
-; CHECK-NEXT:    ret i8 [[AND2]]
+; CHECK-NEXT:    ret i8 0
 ;
   %and1 = and i8 %x, 1
   call void @dummy() ; may unwind


        


More information about the llvm-commits mailing list