[llvm] [LVI] Support using block values when handling conditions (PR #75311)

Nikita Popov via llvm-commits llvm-commits at lists.llvm.org
Wed Dec 20 01:29:58 PST 2023


https://github.com/nikic updated https://github.com/llvm/llvm-project/pull/75311

>From fada1451052c0b67c2f8411566f9707b6b814541 Mon Sep 17 00:00:00 2001
From: Nikita Popov <npopov at redhat.com>
Date: Tue, 12 Dec 2023 15:55:29 +0100
Subject: [PATCH] [LVI] Support using block values when handling conditions

Currently, LVI will only use conditions like "X < C" to constrain
the value of X on the relevant edge. This patch extends it to
handle conditions like "X < Y" by querying the known range of Y.

This means that getValueFromCondition() and various related APIs
can now return nullopt to indicate that they have pushed values to
the worklist and must be called again after the worklist has been
solved. This behavior is currently controlled by a UseBlockValue
option, and is only enabled for actual edge value handling. All
other places deriving constraints from conditions keep using the
previous logic for now.
---
 llvm/lib/Analysis/LazyValueInfo.cpp           | 155 ++++++++++++------
 .../cond-using-block-value.ll                 |   8 +-
 2 files changed, 110 insertions(+), 53 deletions(-)

diff --git a/llvm/lib/Analysis/LazyValueInfo.cpp b/llvm/lib/Analysis/LazyValueInfo.cpp
index 89cc7ea15ec1d7..f7d87716482278 100644
--- a/llvm/lib/Analysis/LazyValueInfo.cpp
+++ b/llvm/lib/Analysis/LazyValueInfo.cpp
@@ -434,6 +434,28 @@ class LazyValueInfoImpl {
 
   void solve();
 
+  // For the following methods, if UseBlockValue is true, the function may
+  // push additional values to the worklist and return nullopt. If
+  // UseBlockValue is false, it will never return nullopt.
+
+  std::optional<ValueLatticeElement>
+  getValueFromSimpleICmpCondition(CmpInst::Predicate Pred, Value *RHS,
+                                  const APInt &Offset, Instruction *CxtI,
+                                  bool UseBlockValue);
+
+  std::optional<ValueLatticeElement>
+  getValueFromICmpCondition(Value *Val, ICmpInst *ICI, bool isTrueDest,
+                            bool UseBlockValue);
+
+  std::optional<ValueLatticeElement>
+  getValueFromCondition(Value *Val, Value *Cond, bool IsTrueDest,
+                        bool UseBlockValue, unsigned Depth = 0);
+
+  std::optional<ValueLatticeElement> getEdgeValueLocal(Value *Val,
+                                                       BasicBlock *BBFrom,
+                                                       BasicBlock *BBTo,
+                                                       bool UseBlockValue);
+
 public:
   /// This is the query interface to determine the lattice value for the
   /// specified Value* at the context instruction (if specified) or at the
@@ -755,14 +777,10 @@ LazyValueInfoImpl::solveBlockValuePHINode(PHINode *PN, BasicBlock *BB) {
   return Result;
 }
 
-static ValueLatticeElement getValueFromCondition(Value *Val, Value *Cond,
-                                                 bool isTrueDest = true,
-                                                 unsigned Depth = 0);
-
 // If we can determine a constraint on the value given conditions assumed by
 // the program, intersect those constraints with BBLV
 void LazyValueInfoImpl::intersectAssumeOrGuardBlockValueConstantRange(
-        Value *Val, ValueLatticeElement &BBLV, Instruction *BBI) {
+    Value *Val, ValueLatticeElement &BBLV, Instruction *BBI) {
   BBI = BBI ? BBI : dyn_cast<Instruction>(Val);
   if (!BBI)
     return;
@@ -779,17 +797,21 @@ void LazyValueInfoImpl::intersectAssumeOrGuardBlockValueConstantRange(
     if (I->getParent() != BB || !isValidAssumeForContext(I, BBI))
       continue;
 
-    BBLV = intersect(BBLV, getValueFromCondition(Val, I->getArgOperand(0)));
+    BBLV = intersect(BBLV, *getValueFromCondition(Val, I->getArgOperand(0),
+                                                  /*IsTrueDest*/ true,
+                                                  /*UseBlockValue*/ false));
   }
 
   // If guards are not used in the module, don't spend time looking for them
   if (GuardDecl && !GuardDecl->use_empty() &&
       BBI->getIterator() != BB->begin()) {
-    for (Instruction &I : make_range(std::next(BBI->getIterator().getReverse()),
-                                     BB->rend())) {
+    for (Instruction &I :
+         make_range(std::next(BBI->getIterator().getReverse()), BB->rend())) {
       Value *Cond = nullptr;
       if (match(&I, m_Intrinsic<Intrinsic::experimental_guard>(m_Value(Cond))))
-        BBLV = intersect(BBLV, getValueFromCondition(Val, Cond));
+        BBLV = intersect(BBLV,
+                         *getValueFromCondition(Val, Cond, /*IsTrueDest*/ true,
+                                                /*UseBlockValue*/ false));
     }
   }
 
@@ -886,10 +908,14 @@ LazyValueInfoImpl::solveBlockValueSelect(SelectInst *SI, BasicBlock *BB) {
   // If the value is undef, a different value may be chosen in
   // the select condition.
   if (isGuaranteedNotToBeUndef(Cond, AC)) {
-    TrueVal = intersect(TrueVal,
-                        getValueFromCondition(SI->getTrueValue(), Cond, true));
-    FalseVal = intersect(
-        FalseVal, getValueFromCondition(SI->getFalseValue(), Cond, false));
+    TrueVal =
+        intersect(TrueVal, *getValueFromCondition(SI->getTrueValue(), Cond,
+                                                  /*IsTrueDest*/ true,
+                                                  /*UseBlockValue*/ false));
+    FalseVal =
+        intersect(FalseVal, *getValueFromCondition(SI->getFalseValue(), Cond,
+                                                   /*IsTrueDest*/ false,
+                                                   /*UseBlockValue*/ false));
   }
 
   ValueLatticeElement Result = TrueVal;
@@ -1068,15 +1094,26 @@ static bool matchICmpOperand(APInt &Offset, Value *LHS, Value *Val,
 }
 
 /// Get value range for a "(Val + Offset) Pred RHS" condition.
-static ValueLatticeElement getValueFromSimpleICmpCondition(
-    CmpInst::Predicate Pred, Value *RHS, const APInt &Offset) {
+std::optional<ValueLatticeElement>
+LazyValueInfoImpl::getValueFromSimpleICmpCondition(CmpInst::Predicate Pred,
+                                                   Value *RHS,
+                                                   const APInt &Offset,
+                                                   Instruction *CxtI,
+                                                   bool UseBlockValue) {
   ConstantRange RHSRange(RHS->getType()->getIntegerBitWidth(),
                          /*isFullSet=*/true);
-  if (ConstantInt *CI = dyn_cast<ConstantInt>(RHS))
+  if (ConstantInt *CI = dyn_cast<ConstantInt>(RHS)) {
     RHSRange = ConstantRange(CI->getValue());
-  else if (Instruction *I = dyn_cast<Instruction>(RHS))
+  } else if (UseBlockValue) {
+    std::optional<ValueLatticeElement> R =
+        getBlockValue(RHS, CxtI->getParent(), CxtI);
+    if (!R)
+      return std::nullopt;
+    RHSRange = toConstantRange(*R, RHS->getType());
+  } else if (Instruction *I = dyn_cast<Instruction>(RHS)) {
     if (auto *Ranges = I->getMetadata(LLVMContext::MD_range))
       RHSRange = getConstantRangeFromMetadata(*Ranges);
+  }
 
   ConstantRange TrueValues =
       ConstantRange::makeAllowedICmpRegion(Pred, RHSRange);
@@ -1103,8 +1140,8 @@ getRangeViaSLT(CmpInst::Predicate Pred, APInt RHS,
   return std::nullopt;
 }
 
-static ValueLatticeElement getValueFromICmpCondition(Value *Val, ICmpInst *ICI,
-                                                     bool isTrueDest) {
+std::optional<ValueLatticeElement> LazyValueInfoImpl::getValueFromICmpCondition(
+    Value *Val, ICmpInst *ICI, bool isTrueDest, bool UseBlockValue) {
   Value *LHS = ICI->getOperand(0);
   Value *RHS = ICI->getOperand(1);
 
@@ -1128,11 +1165,13 @@ static ValueLatticeElement getValueFromICmpCondition(Value *Val, ICmpInst *ICI,
   unsigned BitWidth = Ty->getScalarSizeInBits();
   APInt Offset(BitWidth, 0);
   if (matchICmpOperand(Offset, LHS, Val, EdgePred))
-    return getValueFromSimpleICmpCondition(EdgePred, RHS, Offset);
+    return getValueFromSimpleICmpCondition(EdgePred, RHS, Offset, ICI,
+                                           UseBlockValue);
 
   CmpInst::Predicate SwappedPred = CmpInst::getSwappedPredicate(EdgePred);
   if (matchICmpOperand(Offset, RHS, Val, SwappedPred))
-    return getValueFromSimpleICmpCondition(SwappedPred, LHS, Offset);
+    return getValueFromSimpleICmpCondition(SwappedPred, LHS, Offset, ICI,
+                                           UseBlockValue);
 
   const APInt *Mask, *C;
   if (match(LHS, m_And(m_Specific(Val), m_APInt(Mask))) &&
@@ -1212,10 +1251,12 @@ static ValueLatticeElement getValueFromOverflowCondition(
   return ValueLatticeElement::getRange(NWR);
 }
 
-static ValueLatticeElement getValueFromCondition(
-    Value *Val, Value *Cond, bool IsTrueDest, unsigned Depth) {
+std::optional<ValueLatticeElement>
+LazyValueInfoImpl::getValueFromCondition(Value *Val, Value *Cond,
+                                         bool IsTrueDest, bool UseBlockValue,
+                                         unsigned Depth) {
   if (ICmpInst *ICI = dyn_cast<ICmpInst>(Cond))
-    return getValueFromICmpCondition(Val, ICI, IsTrueDest);
+    return getValueFromICmpCondition(Val, ICI, IsTrueDest, UseBlockValue);
 
   if (auto *EVI = dyn_cast<ExtractValueInst>(Cond))
     if (auto *WO = dyn_cast<WithOverflowInst>(EVI->getAggregateOperand()))
@@ -1227,7 +1268,7 @@ static ValueLatticeElement getValueFromCondition(
 
   Value *N;
   if (match(Cond, m_Not(m_Value(N))))
-    return getValueFromCondition(Val, N, !IsTrueDest, Depth);
+    return getValueFromCondition(Val, N, !IsTrueDest, UseBlockValue, Depth);
 
   Value *L, *R;
   bool IsAnd;
@@ -1238,19 +1279,23 @@ static ValueLatticeElement getValueFromCondition(
   else
     return ValueLatticeElement::getOverdefined();
 
-  ValueLatticeElement LV = getValueFromCondition(Val, L, IsTrueDest, Depth);
-  ValueLatticeElement RV = getValueFromCondition(Val, R, IsTrueDest, Depth);
+  std::optional<ValueLatticeElement> LV =
+      getValueFromCondition(Val, L, IsTrueDest, UseBlockValue, Depth);
+  std::optional<ValueLatticeElement> RV =
+      getValueFromCondition(Val, R, IsTrueDest, UseBlockValue, Depth);
+  if (!LV || !RV)
+    return std::nullopt;
 
   // if (L && R) -> intersect L and R
   // if (!(L || R)) -> intersect !L and !R
   // if (L || R) -> union L and R
   // if (!(L && R)) -> union !L and !R
   if (IsTrueDest ^ IsAnd) {
-    LV.mergeIn(RV);
-    return LV;
+    LV->mergeIn(*RV);
+    return *LV;
   }
 
-  return intersect(LV, RV);
+  return intersect(*LV, *RV);
 }
 
 // Return true if Usr has Op as an operand, otherwise false.
@@ -1302,8 +1347,9 @@ static ValueLatticeElement constantFoldUser(User *Usr, Value *Op,
 }
 
 /// Compute the value of Val on the edge BBFrom -> BBTo.
-static ValueLatticeElement getEdgeValueLocal(Value *Val, BasicBlock *BBFrom,
-                                             BasicBlock *BBTo) {
+std::optional<ValueLatticeElement>
+LazyValueInfoImpl::getEdgeValueLocal(Value *Val, BasicBlock *BBFrom,
+                                     BasicBlock *BBTo, bool UseBlockValue) {
   // TODO: Handle more complex conditionals. If (v == 0 || v2 < 1) is false, we
   // know that v != 0.
   if (BranchInst *BI = dyn_cast<BranchInst>(BBFrom->getTerminator())) {
@@ -1324,13 +1370,16 @@ static ValueLatticeElement getEdgeValueLocal(Value *Val, BasicBlock *BBFrom,
 
       // If the condition of the branch is an equality comparison, we may be
       // able to infer the value.
-      ValueLatticeElement Result = getValueFromCondition(Val, Condition,
-                                                         isTrueDest);
-      if (!Result.isOverdefined())
+      std::optional<ValueLatticeElement> Result =
+          getValueFromCondition(Val, Condition, isTrueDest, UseBlockValue);
+      if (!Result)
+        return std::nullopt;
+
+      if (!Result->isOverdefined())
         return Result;
 
       if (User *Usr = dyn_cast<User>(Val)) {
-        assert(Result.isOverdefined() && "Result isn't overdefined");
+        assert(Result->isOverdefined() && "Result isn't overdefined");
         // Check with isOperationFoldable() first to avoid linearly iterating
         // over the operands unnecessarily which can be expensive for
         // instructions with many operands.
@@ -1356,8 +1405,8 @@ static ValueLatticeElement getEdgeValueLocal(Value *Val, BasicBlock *BBFrom,
             //    br i1 %Condition, label %then, label %else
             for (unsigned i = 0; i < Usr->getNumOperands(); ++i) {
               Value *Op = Usr->getOperand(i);
-              ValueLatticeElement OpLatticeVal =
-                  getValueFromCondition(Op, Condition, isTrueDest);
+              ValueLatticeElement OpLatticeVal = *getValueFromCondition(
+                  Op, Condition, isTrueDest, /*UseBlockValue*/ false);
               if (std::optional<APInt> OpConst =
                       OpLatticeVal.asConstantInteger()) {
                 Result = constantFoldUser(Usr, Op, *OpConst, DL);
@@ -1367,7 +1416,7 @@ static ValueLatticeElement getEdgeValueLocal(Value *Val, BasicBlock *BBFrom,
           }
         }
       }
-      if (!Result.isOverdefined())
+      if (!Result->isOverdefined())
         return Result;
     }
   }
@@ -1432,8 +1481,12 @@ LazyValueInfoImpl::getEdgeValue(Value *Val, BasicBlock *BBFrom,
   if (Constant *VC = dyn_cast<Constant>(Val))
     return ValueLatticeElement::get(VC);
 
-  ValueLatticeElement LocalResult = getEdgeValueLocal(Val, BBFrom, BBTo);
-  if (hasSingleValue(LocalResult))
+  std::optional<ValueLatticeElement> LocalResult =
+      getEdgeValueLocal(Val, BBFrom, BBTo, /*UseBlockValue*/ true);
+  if (!LocalResult)
+    return std::nullopt;
+
+  if (hasSingleValue(*LocalResult))
     // Can't get any more precise here
     return LocalResult;
 
@@ -1453,7 +1506,7 @@ LazyValueInfoImpl::getEdgeValue(Value *Val, BasicBlock *BBFrom,
   // but then the result is not cached.
   intersectAssumeOrGuardBlockValueConstantRange(Val, InBlock, CxtI);
 
-  return intersect(LocalResult, InBlock);
+  return intersect(*LocalResult, InBlock);
 }
 
 ValueLatticeElement LazyValueInfoImpl::getValueInBlock(Value *V, BasicBlock *BB,
@@ -1499,10 +1552,12 @@ getValueOnEdge(Value *V, BasicBlock *FromBB, BasicBlock *ToBB,
 
   std::optional<ValueLatticeElement> Result =
       getEdgeValue(V, FromBB, ToBB, CxtI);
-  if (!Result) {
+  while (!Result) {
+    // As the worklist only explicitly tracks block values (but not edge values)
+    // we may have to call solve() multiple times, as the edge value calculation
+    // may request additional block values.
     solve();
     Result = getEdgeValue(V, FromBB, ToBB, CxtI);
-    assert(Result && "More work to do after problem solved?");
   }
 
   LLVM_DEBUG(dbgs() << "  Result = " << *Result << "\n");
@@ -1528,13 +1583,17 @@ ValueLatticeElement LazyValueInfoImpl::getValueAtUse(const Use &U) {
       if (!isGuaranteedNotToBeUndef(SI->getCondition(), AC))
         break;
       if (CurrU->getOperandNo() == 1)
-        CondVal = getValueFromCondition(V, SI->getCondition(), true);
+        CondVal =
+            *getValueFromCondition(V, SI->getCondition(), /*IsTrueDest*/ true,
+                                   /*UseBlockValue*/ false);
       else if (CurrU->getOperandNo() == 2)
-        CondVal = getValueFromCondition(V, SI->getCondition(), false);
+        CondVal =
+            *getValueFromCondition(V, SI->getCondition(), /*IsTrueDest*/ false,
+                                   /*UseBlockValue*/ false);
     } else if (auto *PHI = dyn_cast<PHINode>(CurrI)) {
       // TODO: Use non-local query?
-      CondVal =
-          getEdgeValueLocal(V, PHI->getIncomingBlock(*CurrU), PHI->getParent());
+      CondVal = *getEdgeValueLocal(V, PHI->getIncomingBlock(*CurrU),
+                                   PHI->getParent(), /*UseBlockValue*/ false);
     }
     if (CondVal)
       VL = intersect(VL, *CondVal);
diff --git a/llvm/test/Transforms/CorrelatedValuePropagation/cond-using-block-value.ll b/llvm/test/Transforms/CorrelatedValuePropagation/cond-using-block-value.ll
index d30b31d317a6de..252f6596cedc5e 100644
--- a/llvm/test/Transforms/CorrelatedValuePropagation/cond-using-block-value.ll
+++ b/llvm/test/Transforms/CorrelatedValuePropagation/cond-using-block-value.ll
@@ -12,8 +12,7 @@ define void @test_icmp_from_implied_cond(i32 %a, i32 %b) {
 ; CHECK-NEXT:    [[COND:%.*]] = icmp ult i32 [[B]], [[A]]
 ; CHECK-NEXT:    br i1 [[COND]], label [[L2:%.*]], label [[END]]
 ; CHECK:       l2:
-; CHECK-NEXT:    [[B_CMP1:%.*]] = icmp ult i32 [[B]], 32
-; CHECK-NEXT:    call void @use(i1 [[B_CMP1]])
+; CHECK-NEXT:    call void @use(i1 true)
 ; CHECK-NEXT:    [[B_CMP2:%.*]] = icmp ult i32 [[B]], 31
 ; CHECK-NEXT:    call void @use(i1 [[B_CMP2]])
 ; CHECK-NEXT:    ret void
@@ -47,7 +46,7 @@ define i64 @test_sext_from_implied_cond(i32 %a, i32 %b) {
 ; CHECK-NEXT:    [[COND:%.*]] = icmp ult i32 [[B]], [[A]]
 ; CHECK-NEXT:    br i1 [[COND]], label [[L2:%.*]], label [[END]]
 ; CHECK:       l2:
-; CHECK-NEXT:    [[SEXT:%.*]] = sext i32 [[B]] to i64
+; CHECK-NEXT:    [[SEXT:%.*]] = zext nneg i32 [[B]] to i64
 ; CHECK-NEXT:    ret i64 [[SEXT]]
 ; CHECK:       end:
 ; CHECK-NEXT:    ret i64 0
@@ -74,8 +73,7 @@ define void @test_icmp_from_implied_range(i16 %x, i32 %b) {
 ; CHECK-NEXT:    [[COND:%.*]] = icmp ult i32 [[B]], [[A]]
 ; CHECK-NEXT:    br i1 [[COND]], label [[L1:%.*]], label [[END:%.*]]
 ; CHECK:       l1:
-; CHECK-NEXT:    [[B_CMP1:%.*]] = icmp ult i32 [[B]], 65535
-; CHECK-NEXT:    call void @use(i1 [[B_CMP1]])
+; CHECK-NEXT:    call void @use(i1 true)
 ; CHECK-NEXT:    [[B_CMP2:%.*]] = icmp ult i32 [[B]], 65534
 ; CHECK-NEXT:    call void @use(i1 [[B_CMP2]])
 ; CHECK-NEXT:    ret void



More information about the llvm-commits mailing list