[llvm] 734ee0e - [LVI] Support using block values when handling conditions (#75311)

via llvm-commits llvm-commits at lists.llvm.org
Tue Jan 2 01:49:49 PST 2024


Author: Nikita Popov
Date: 2024-01-02T10:49:45+01:00
New Revision: 734ee0e01feeadd75bdbed35acc08f242623a212

URL: https://github.com/llvm/llvm-project/commit/734ee0e01feeadd75bdbed35acc08f242623a212
DIFF: https://github.com/llvm/llvm-project/commit/734ee0e01feeadd75bdbed35acc08f242623a212.diff

LOG: [LVI] Support using block values when handling conditions (#75311)

Currently, LVI will only use conditions like "X < C" to constrain the
value of X on the relevant edge. This patch extends it to handle
conditions like "X < Y" by querying the known range of Y.

This means that getValueFromCondition() and various related APIs can now
return nullopt to indicate that they have pushed to the worklist, and
need to be called again later. This behavior is currently controlled by
a UseBlockValue option, and only enabled for actual edge value handling.
All other places deriving constraints from conditions keep using the
previous logic for now.

This change was originally motivated as a fix for the regression
reported in
https://github.com/llvm/llvm-project/pull/73662#issuecomment-1849281758.
Unfortunately, it doesn't actually fix it, because we run into another
issue there (LVI currently is really bad at handling values used in
loops).

This change has some compile-time impact, but it's fairly small,
in the 0.05% range.

Added: 
    

Modified: 
    llvm/lib/Analysis/LazyValueInfo.cpp
    llvm/test/Transforms/CorrelatedValuePropagation/cond-using-block-value.ll

Removed: 
    


################################################################################
diff --git a/llvm/lib/Analysis/LazyValueInfo.cpp b/llvm/lib/Analysis/LazyValueInfo.cpp
index 89cc7ea15ec1d7..f7d87716482278 100644
--- a/llvm/lib/Analysis/LazyValueInfo.cpp
+++ b/llvm/lib/Analysis/LazyValueInfo.cpp
@@ -434,6 +434,28 @@ class LazyValueInfoImpl {
 
   void solve();
 
+  // For the following methods, if UseBlockValue is true, the function may
+  // push additional values to the worklist and return nullopt. If
+  // UseBlockValue is false, it will never return nullopt.
+
+  std::optional<ValueLatticeElement>
+  getValueFromSimpleICmpCondition(CmpInst::Predicate Pred, Value *RHS,
+                                  const APInt &Offset, Instruction *CxtI,
+                                  bool UseBlockValue);
+
+  std::optional<ValueLatticeElement>
+  getValueFromICmpCondition(Value *Val, ICmpInst *ICI, bool isTrueDest,
+                            bool UseBlockValue);
+
+  std::optional<ValueLatticeElement>
+  getValueFromCondition(Value *Val, Value *Cond, bool IsTrueDest,
+                        bool UseBlockValue, unsigned Depth = 0);
+
+  std::optional<ValueLatticeElement> getEdgeValueLocal(Value *Val,
+                                                       BasicBlock *BBFrom,
+                                                       BasicBlock *BBTo,
+                                                       bool UseBlockValue);
+
 public:
   /// This is the query interface to determine the lattice value for the
   /// specified Value* at the context instruction (if specified) or at the
@@ -755,14 +777,10 @@ LazyValueInfoImpl::solveBlockValuePHINode(PHINode *PN, BasicBlock *BB) {
   return Result;
 }
 
-static ValueLatticeElement getValueFromCondition(Value *Val, Value *Cond,
-                                                 bool isTrueDest = true,
-                                                 unsigned Depth = 0);
-
 // If we can determine a constraint on the value given conditions assumed by
 // the program, intersect those constraints with BBLV
 void LazyValueInfoImpl::intersectAssumeOrGuardBlockValueConstantRange(
-        Value *Val, ValueLatticeElement &BBLV, Instruction *BBI) {
+    Value *Val, ValueLatticeElement &BBLV, Instruction *BBI) {
   BBI = BBI ? BBI : dyn_cast<Instruction>(Val);
   if (!BBI)
     return;
@@ -779,17 +797,21 @@ void LazyValueInfoImpl::intersectAssumeOrGuardBlockValueConstantRange(
     if (I->getParent() != BB || !isValidAssumeForContext(I, BBI))
       continue;
 
-    BBLV = intersect(BBLV, getValueFromCondition(Val, I->getArgOperand(0)));
+    BBLV = intersect(BBLV, *getValueFromCondition(Val, I->getArgOperand(0),
+                                                  /*IsTrueDest*/ true,
+                                                  /*UseBlockValue*/ false));
   }
 
   // If guards are not used in the module, don't spend time looking for them
   if (GuardDecl && !GuardDecl->use_empty() &&
       BBI->getIterator() != BB->begin()) {
-    for (Instruction &I : make_range(std::next(BBI->getIterator().getReverse()),
-                                     BB->rend())) {
+    for (Instruction &I :
+         make_range(std::next(BBI->getIterator().getReverse()), BB->rend())) {
       Value *Cond = nullptr;
       if (match(&I, m_Intrinsic<Intrinsic::experimental_guard>(m_Value(Cond))))
-        BBLV = intersect(BBLV, getValueFromCondition(Val, Cond));
+        BBLV = intersect(BBLV,
+                         *getValueFromCondition(Val, Cond, /*IsTrueDest*/ true,
+                                                /*UseBlockValue*/ false));
     }
   }
 
@@ -886,10 +908,14 @@ LazyValueInfoImpl::solveBlockValueSelect(SelectInst *SI, BasicBlock *BB) {
   // If the value is undef, a different value may be chosen in
   // the select condition.
   if (isGuaranteedNotToBeUndef(Cond, AC)) {
-    TrueVal = intersect(TrueVal,
-                        getValueFromCondition(SI->getTrueValue(), Cond, true));
-    FalseVal = intersect(
-        FalseVal, getValueFromCondition(SI->getFalseValue(), Cond, false));
+    TrueVal =
+        intersect(TrueVal, *getValueFromCondition(SI->getTrueValue(), Cond,
+                                                  /*IsTrueDest*/ true,
+                                                  /*UseBlockValue*/ false));
+    FalseVal =
+        intersect(FalseVal, *getValueFromCondition(SI->getFalseValue(), Cond,
+                                                   /*IsTrueDest*/ false,
+                                                   /*UseBlockValue*/ false));
   }
 
   ValueLatticeElement Result = TrueVal;
@@ -1068,15 +1094,26 @@ static bool matchICmpOperand(APInt &Offset, Value *LHS, Value *Val,
 }
 
 /// Get value range for a "(Val + Offset) Pred RHS" condition.
-static ValueLatticeElement getValueFromSimpleICmpCondition(
-    CmpInst::Predicate Pred, Value *RHS, const APInt &Offset) {
+std::optional<ValueLatticeElement>
+LazyValueInfoImpl::getValueFromSimpleICmpCondition(CmpInst::Predicate Pred,
+                                                   Value *RHS,
+                                                   const APInt &Offset,
+                                                   Instruction *CxtI,
+                                                   bool UseBlockValue) {
   ConstantRange RHSRange(RHS->getType()->getIntegerBitWidth(),
                          /*isFullSet=*/true);
-  if (ConstantInt *CI = dyn_cast<ConstantInt>(RHS))
+  if (ConstantInt *CI = dyn_cast<ConstantInt>(RHS)) {
     RHSRange = ConstantRange(CI->getValue());
-  else if (Instruction *I = dyn_cast<Instruction>(RHS))
+  } else if (UseBlockValue) {
+    std::optional<ValueLatticeElement> R =
+        getBlockValue(RHS, CxtI->getParent(), CxtI);
+    if (!R)
+      return std::nullopt;
+    RHSRange = toConstantRange(*R, RHS->getType());
+  } else if (Instruction *I = dyn_cast<Instruction>(RHS)) {
     if (auto *Ranges = I->getMetadata(LLVMContext::MD_range))
       RHSRange = getConstantRangeFromMetadata(*Ranges);
+  }
 
   ConstantRange TrueValues =
       ConstantRange::makeAllowedICmpRegion(Pred, RHSRange);
@@ -1103,8 +1140,8 @@ getRangeViaSLT(CmpInst::Predicate Pred, APInt RHS,
   return std::nullopt;
 }
 
-static ValueLatticeElement getValueFromICmpCondition(Value *Val, ICmpInst *ICI,
-                                                     bool isTrueDest) {
+std::optional<ValueLatticeElement> LazyValueInfoImpl::getValueFromICmpCondition(
+    Value *Val, ICmpInst *ICI, bool isTrueDest, bool UseBlockValue) {
   Value *LHS = ICI->getOperand(0);
   Value *RHS = ICI->getOperand(1);
 
@@ -1128,11 +1165,13 @@ static ValueLatticeElement getValueFromICmpCondition(Value *Val, ICmpInst *ICI,
   unsigned BitWidth = Ty->getScalarSizeInBits();
   APInt Offset(BitWidth, 0);
   if (matchICmpOperand(Offset, LHS, Val, EdgePred))
-    return getValueFromSimpleICmpCondition(EdgePred, RHS, Offset);
+    return getValueFromSimpleICmpCondition(EdgePred, RHS, Offset, ICI,
+                                           UseBlockValue);
 
   CmpInst::Predicate SwappedPred = CmpInst::getSwappedPredicate(EdgePred);
   if (matchICmpOperand(Offset, RHS, Val, SwappedPred))
-    return getValueFromSimpleICmpCondition(SwappedPred, LHS, Offset);
+    return getValueFromSimpleICmpCondition(SwappedPred, LHS, Offset, ICI,
+                                           UseBlockValue);
 
   const APInt *Mask, *C;
   if (match(LHS, m_And(m_Specific(Val), m_APInt(Mask))) &&
@@ -1212,10 +1251,12 @@ static ValueLatticeElement getValueFromOverflowCondition(
   return ValueLatticeElement::getRange(NWR);
 }
 
-static ValueLatticeElement getValueFromCondition(
-    Value *Val, Value *Cond, bool IsTrueDest, unsigned Depth) {
+std::optional<ValueLatticeElement>
+LazyValueInfoImpl::getValueFromCondition(Value *Val, Value *Cond,
+                                         bool IsTrueDest, bool UseBlockValue,
+                                         unsigned Depth) {
   if (ICmpInst *ICI = dyn_cast<ICmpInst>(Cond))
-    return getValueFromICmpCondition(Val, ICI, IsTrueDest);
+    return getValueFromICmpCondition(Val, ICI, IsTrueDest, UseBlockValue);
 
   if (auto *EVI = dyn_cast<ExtractValueInst>(Cond))
     if (auto *WO = dyn_cast<WithOverflowInst>(EVI->getAggregateOperand()))
@@ -1227,7 +1268,7 @@ static ValueLatticeElement getValueFromCondition(
 
   Value *N;
   if (match(Cond, m_Not(m_Value(N))))
-    return getValueFromCondition(Val, N, !IsTrueDest, Depth);
+    return getValueFromCondition(Val, N, !IsTrueDest, UseBlockValue, Depth);
 
   Value *L, *R;
   bool IsAnd;
@@ -1238,19 +1279,23 @@ static ValueLatticeElement getValueFromCondition(
   else
     return ValueLatticeElement::getOverdefined();
 
-  ValueLatticeElement LV = getValueFromCondition(Val, L, IsTrueDest, Depth);
-  ValueLatticeElement RV = getValueFromCondition(Val, R, IsTrueDest, Depth);
+  std::optional<ValueLatticeElement> LV =
+      getValueFromCondition(Val, L, IsTrueDest, UseBlockValue, Depth);
+  std::optional<ValueLatticeElement> RV =
+      getValueFromCondition(Val, R, IsTrueDest, UseBlockValue, Depth);
+  if (!LV || !RV)
+    return std::nullopt;
 
   // if (L && R) -> intersect L and R
   // if (!(L || R)) -> intersect !L and !R
   // if (L || R) -> union L and R
   // if (!(L && R)) -> union !L and !R
   if (IsTrueDest ^ IsAnd) {
-    LV.mergeIn(RV);
-    return LV;
+    LV->mergeIn(*RV);
+    return *LV;
   }
 
-  return intersect(LV, RV);
+  return intersect(*LV, *RV);
 }
 
 // Return true if Usr has Op as an operand, otherwise false.
@@ -1302,8 +1347,9 @@ static ValueLatticeElement constantFoldUser(User *Usr, Value *Op,
 }
 
 /// Compute the value of Val on the edge BBFrom -> BBTo.
-static ValueLatticeElement getEdgeValueLocal(Value *Val, BasicBlock *BBFrom,
-                                             BasicBlock *BBTo) {
+std::optional<ValueLatticeElement>
+LazyValueInfoImpl::getEdgeValueLocal(Value *Val, BasicBlock *BBFrom,
+                                     BasicBlock *BBTo, bool UseBlockValue) {
   // TODO: Handle more complex conditionals. If (v == 0 || v2 < 1) is false, we
   // know that v != 0.
   if (BranchInst *BI = dyn_cast<BranchInst>(BBFrom->getTerminator())) {
@@ -1324,13 +1370,16 @@ static ValueLatticeElement getEdgeValueLocal(Value *Val, BasicBlock *BBFrom,
 
       // If the condition of the branch is an equality comparison, we may be
       // able to infer the value.
-      ValueLatticeElement Result = getValueFromCondition(Val, Condition,
-                                                         isTrueDest);
-      if (!Result.isOverdefined())
+      std::optional<ValueLatticeElement> Result =
+          getValueFromCondition(Val, Condition, isTrueDest, UseBlockValue);
+      if (!Result)
+        return std::nullopt;
+
+      if (!Result->isOverdefined())
         return Result;
 
       if (User *Usr = dyn_cast<User>(Val)) {
-        assert(Result.isOverdefined() && "Result isn't overdefined");
+        assert(Result->isOverdefined() && "Result isn't overdefined");
         // Check with isOperationFoldable() first to avoid linearly iterating
         // over the operands unnecessarily which can be expensive for
         // instructions with many operands.
@@ -1356,8 +1405,8 @@ static ValueLatticeElement getEdgeValueLocal(Value *Val, BasicBlock *BBFrom,
             //    br i1 %Condition, label %then, label %else
             for (unsigned i = 0; i < Usr->getNumOperands(); ++i) {
               Value *Op = Usr->getOperand(i);
-              ValueLatticeElement OpLatticeVal =
-                  getValueFromCondition(Op, Condition, isTrueDest);
+              ValueLatticeElement OpLatticeVal = *getValueFromCondition(
+                  Op, Condition, isTrueDest, /*UseBlockValue*/ false);
               if (std::optional<APInt> OpConst =
                       OpLatticeVal.asConstantInteger()) {
                 Result = constantFoldUser(Usr, Op, *OpConst, DL);
@@ -1367,7 +1416,7 @@ static ValueLatticeElement getEdgeValueLocal(Value *Val, BasicBlock *BBFrom,
           }
         }
       }
-      if (!Result.isOverdefined())
+      if (!Result->isOverdefined())
         return Result;
     }
   }
@@ -1432,8 +1481,12 @@ LazyValueInfoImpl::getEdgeValue(Value *Val, BasicBlock *BBFrom,
   if (Constant *VC = dyn_cast<Constant>(Val))
     return ValueLatticeElement::get(VC);
 
-  ValueLatticeElement LocalResult = getEdgeValueLocal(Val, BBFrom, BBTo);
-  if (hasSingleValue(LocalResult))
+  std::optional<ValueLatticeElement> LocalResult =
+      getEdgeValueLocal(Val, BBFrom, BBTo, /*UseBlockValue*/ true);
+  if (!LocalResult)
+    return std::nullopt;
+
+  if (hasSingleValue(*LocalResult))
     // Can't get any more precise here
     return LocalResult;
 
@@ -1453,7 +1506,7 @@ LazyValueInfoImpl::getEdgeValue(Value *Val, BasicBlock *BBFrom,
   // but then the result is not cached.
   intersectAssumeOrGuardBlockValueConstantRange(Val, InBlock, CxtI);
 
-  return intersect(LocalResult, InBlock);
+  return intersect(*LocalResult, InBlock);
 }
 
 ValueLatticeElement LazyValueInfoImpl::getValueInBlock(Value *V, BasicBlock *BB,
@@ -1499,10 +1552,12 @@ getValueOnEdge(Value *V, BasicBlock *FromBB, BasicBlock *ToBB,
 
   std::optional<ValueLatticeElement> Result =
       getEdgeValue(V, FromBB, ToBB, CxtI);
-  if (!Result) {
+  while (!Result) {
+    // As the worklist only explicitly tracks block values (but not edge values)
+    // we may have to call solve() multiple times, as the edge value calculation
+    // may request additional block values.
     solve();
     Result = getEdgeValue(V, FromBB, ToBB, CxtI);
-    assert(Result && "More work to do after problem solved?");
   }
 
   LLVM_DEBUG(dbgs() << "  Result = " << *Result << "\n");
@@ -1528,13 +1583,17 @@ ValueLatticeElement LazyValueInfoImpl::getValueAtUse(const Use &U) {
       if (!isGuaranteedNotToBeUndef(SI->getCondition(), AC))
         break;
       if (CurrU->getOperandNo() == 1)
-        CondVal = getValueFromCondition(V, SI->getCondition(), true);
+        CondVal =
+            *getValueFromCondition(V, SI->getCondition(), /*IsTrueDest*/ true,
+                                   /*UseBlockValue*/ false);
       else if (CurrU->getOperandNo() == 2)
-        CondVal = getValueFromCondition(V, SI->getCondition(), false);
+        CondVal =
+            *getValueFromCondition(V, SI->getCondition(), /*IsTrueDest*/ false,
+                                   /*UseBlockValue*/ false);
     } else if (auto *PHI = dyn_cast<PHINode>(CurrI)) {
       // TODO: Use non-local query?
-      CondVal =
-          getEdgeValueLocal(V, PHI->getIncomingBlock(*CurrU), PHI->getParent());
+      CondVal = *getEdgeValueLocal(V, PHI->getIncomingBlock(*CurrU),
+                                   PHI->getParent(), /*UseBlockValue*/ false);
     }
     if (CondVal)
       VL = intersect(VL, *CondVal);

diff --git a/llvm/test/Transforms/CorrelatedValuePropagation/cond-using-block-value.ll b/llvm/test/Transforms/CorrelatedValuePropagation/cond-using-block-value.ll
index d30b31d317a6de..252f6596cedc5e 100644
--- a/llvm/test/Transforms/CorrelatedValuePropagation/cond-using-block-value.ll
+++ b/llvm/test/Transforms/CorrelatedValuePropagation/cond-using-block-value.ll
@@ -12,8 +12,7 @@ define void @test_icmp_from_implied_cond(i32 %a, i32 %b) {
 ; CHECK-NEXT:    [[COND:%.*]] = icmp ult i32 [[B]], [[A]]
 ; CHECK-NEXT:    br i1 [[COND]], label [[L2:%.*]], label [[END]]
 ; CHECK:       l2:
-; CHECK-NEXT:    [[B_CMP1:%.*]] = icmp ult i32 [[B]], 32
-; CHECK-NEXT:    call void @use(i1 [[B_CMP1]])
+; CHECK-NEXT:    call void @use(i1 true)
 ; CHECK-NEXT:    [[B_CMP2:%.*]] = icmp ult i32 [[B]], 31
 ; CHECK-NEXT:    call void @use(i1 [[B_CMP2]])
 ; CHECK-NEXT:    ret void
@@ -47,7 +46,7 @@ define i64 @test_sext_from_implied_cond(i32 %a, i32 %b) {
 ; CHECK-NEXT:    [[COND:%.*]] = icmp ult i32 [[B]], [[A]]
 ; CHECK-NEXT:    br i1 [[COND]], label [[L2:%.*]], label [[END]]
 ; CHECK:       l2:
-; CHECK-NEXT:    [[SEXT:%.*]] = sext i32 [[B]] to i64
+; CHECK-NEXT:    [[SEXT:%.*]] = zext nneg i32 [[B]] to i64
 ; CHECK-NEXT:    ret i64 [[SEXT]]
 ; CHECK:       end:
 ; CHECK-NEXT:    ret i64 0
@@ -74,8 +73,7 @@ define void @test_icmp_from_implied_range(i16 %x, i32 %b) {
 ; CHECK-NEXT:    [[COND:%.*]] = icmp ult i32 [[B]], [[A]]
 ; CHECK-NEXT:    br i1 [[COND]], label [[L1:%.*]], label [[END:%.*]]
 ; CHECK:       l1:
-; CHECK-NEXT:    [[B_CMP1:%.*]] = icmp ult i32 [[B]], 65535
-; CHECK-NEXT:    call void @use(i1 [[B_CMP1]])
+; CHECK-NEXT:    call void @use(i1 true)
 ; CHECK-NEXT:    [[B_CMP2:%.*]] = icmp ult i32 [[B]], 65534
 ; CHECK-NEXT:    call void @use(i1 [[B_CMP2]])
 ; CHECK-NEXT:    ret void


        


More information about the llvm-commits mailing list