[llvm] c6e1366 - [PredicateInfo] Add a method to interpret predicate as cmp constraint

Nikita Popov via llvm-commits llvm-commits at lists.llvm.org
Sun Jul 19 06:34:40 PDT 2020


Author: Nikita Popov
Date: 2020-07-19T15:34:32+02:00
New Revision: c6e13667e787b3a72b794422ab506d5403ddcd21

URL: https://github.com/llvm/llvm-project/commit/c6e13667e787b3a72b794422ab506d5403ddcd21
DIFF: https://github.com/llvm/llvm-project/commit/c6e13667e787b3a72b794422ab506d5403ddcd21.diff

LOG: [PredicateInfo] Add a method to interpret predicate as cmp constraint

Both users of PredicateInfo (NewGVN and SCCP) are interested in
getting a cmp constraint on the predicated value. They currently
implement separate logic for this; this patch adds a common method,
PredicateBase::getConstraint(), that both can use.

This enables a missing bit of PredicateInfo handling in SCCP: now
the predicate on the condition itself is also used. For switches,
this means we know that the switched-on value is equal to the case
value. For assumes/branches, we know that the condition is true
(assume or true edge) or false (false edge).

Differential Revision: https://reviews.llvm.org/D83640

Added: 
    

Modified: 
    llvm/include/llvm/Transforms/Utils/PredicateInfo.h
    llvm/lib/Transforms/Scalar/NewGVN.cpp
    llvm/lib/Transforms/Scalar/SCCP.cpp
    llvm/lib/Transforms/Utils/PredicateInfo.cpp
    llvm/test/Transforms/SCCP/predicateinfo-cond.ll

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Transforms/Utils/PredicateInfo.h b/llvm/include/llvm/Transforms/Utils/PredicateInfo.h
index cdac4142555d..c922476ac79d 100644
--- a/llvm/include/llvm/Transforms/Utils/PredicateInfo.h
+++ b/llvm/include/llvm/Transforms/Utils/PredicateInfo.h
@@ -70,6 +70,13 @@ class raw_ostream;
 
 enum PredicateType { PT_Branch, PT_Assume, PT_Switch };
 
+/// Constraint for a predicate of the form "cmp Pred Op, OtherOp", where Op
+/// is the value the constraint applies to (the ssa.copy result).
+struct PredicateConstraint {
+  CmpInst::Predicate Predicate; ///< Pred, with Op as the left-hand operand.
+  Value *OtherOp;               ///< Right-hand operand; may be a constant.
+};
+
 // Base class for all predicate information we provide.
 // All of our predicate information has at least a comparison.
 class PredicateBase : public ilist_node<PredicateBase> {
@@ -95,6 +102,9 @@ class PredicateBase : public ilist_node<PredicateBase> {
            PB->Type == PT_Switch;
   }
 
+  /// Fetch condition as a PredicateConstraint, or None if not expressible.
+  Optional<PredicateConstraint> getConstraint() const;
+
 protected:
   PredicateBase(PredicateType PT, Value *Op, Value *Condition)
       : Type(PT), OriginalOp(Op), Condition(Condition) {}

diff  --git a/llvm/lib/Transforms/Scalar/NewGVN.cpp b/llvm/lib/Transforms/Scalar/NewGVN.cpp
index 45d01cc1b584..cfadfbb585b9 100644
--- a/llvm/lib/Transforms/Scalar/NewGVN.cpp
+++ b/llvm/lib/Transforms/Scalar/NewGVN.cpp
@@ -1539,86 +1539,39 @@ NewGVN::performSymbolicPredicateInfoEvaluation(Instruction *I) const {
 
   LLVM_DEBUG(dbgs() << "Found predicate info from instruction !\n");
 
-  auto *CopyOf = I->getOperand(0);
-  auto *Cond = PI->Condition;
-
-  // If this a copy of the condition, it must be either true or false depending
-  // on the predicate info type and edge.
-  if (CopyOf == Cond) {
-    // We should not need to add predicate users because the predicate info is
-    // already a use of this operand.
-    if (isa<PredicateAssume>(PI))
-      return createConstantExpression(ConstantInt::getTrue(Cond->getType()));
-    if (auto *PBranch = dyn_cast<PredicateBranch>(PI)) {
-      if (PBranch->TrueEdge)
-        return createConstantExpression(ConstantInt::getTrue(Cond->getType()));
-      return createConstantExpression(ConstantInt::getFalse(Cond->getType()));
-    }
-    if (auto *PSwitch = dyn_cast<PredicateSwitch>(PI))
-      return createConstantExpression(cast<Constant>(PSwitch->CaseValue));
-  }
-
-  // Not a copy of the condition, so see what the predicates tell us about this
-  // value.  First, though, we check to make sure the value is actually a copy
-  // of one of the condition operands. It's possible, in certain cases, for it
-  // to be a copy of a predicateinfo copy. In particular, if two branch
-  // operations use the same condition, and one branch dominates the other, we
-  // will end up with a copy of a copy.  This is currently a small deficiency in
-  // predicateinfo.  What will end up happening here is that we will value
-  // number both copies the same anyway.
-
-  // Everything below relies on the condition being a comparison.
-  auto *Cmp = dyn_cast<CmpInst>(Cond);
-  if (!Cmp)
+  const Optional<PredicateConstraint> &Constraint = PI->getConstraint();
+  if (!Constraint)
     return nullptr;
 
-  if (CopyOf != Cmp->getOperand(0) && CopyOf != Cmp->getOperand(1)) {
-    LLVM_DEBUG(dbgs() << "Copy is not of any condition operands!\n");
-    return nullptr;
-  }
-  Value *FirstOp = lookupOperandLeader(Cmp->getOperand(0));
-  Value *SecondOp = lookupOperandLeader(Cmp->getOperand(1));
-  bool SwappedOps = false;
+  CmpInst::Predicate Predicate = Constraint->Predicate;
+  Value *CmpOp0 = I->getOperand(0);
+  Value *CmpOp1 = Constraint->OtherOp;
+
+  Value *FirstOp = lookupOperandLeader(CmpOp0);
+  Value *SecondOp = lookupOperandLeader(CmpOp1);
+  Value *AdditionallyUsedValue = CmpOp0;
+
   // Sort the ops.
   if (shouldSwapOperands(FirstOp, SecondOp)) {
     std::swap(FirstOp, SecondOp);
-    SwappedOps = true;
+    Predicate = CmpInst::getSwappedPredicate(Predicate);
+    AdditionallyUsedValue = CmpOp1;
   }
-  CmpInst::Predicate Predicate =
-      SwappedOps ? Cmp->getSwappedPredicate() : Cmp->getPredicate();
-
-  if (isa<PredicateAssume>(PI)) {
-    // If we assume the operands are equal, then they are equal.
-    if (Predicate == CmpInst::ICMP_EQ) {
-      addPredicateUsers(PI, I);
-      addAdditionalUsers(SwappedOps ? Cmp->getOperand(1) : Cmp->getOperand(0),
-                         I);
-      return createVariableOrConstant(FirstOp);
-    }
+
+  if (Predicate == CmpInst::ICMP_EQ) {
+    addPredicateUsers(PI, I);
+    addAdditionalUsers(AdditionallyUsedValue, I);
+    return createVariableOrConstant(FirstOp);
   }
-  if (const auto *PBranch = dyn_cast<PredicateBranch>(PI)) {
-    // If we are *not* a copy of the comparison, we may equal to the other
-    // operand when the predicate implies something about equality of
-    // operations.  In particular, if the comparison is true/false when the
-    // operands are equal, and we are on the right edge, we know this operation
-    // is equal to something.
-    if ((PBranch->TrueEdge && Predicate == CmpInst::ICMP_EQ) ||
-        (!PBranch->TrueEdge && Predicate == CmpInst::ICMP_NE)) {
-      addPredicateUsers(PI, I);
-      addAdditionalUsers(SwappedOps ? Cmp->getOperand(1) : Cmp->getOperand(0),
-                         I);
-      return createVariableOrConstant(FirstOp);
-    }
-    // Handle the special case of floating point.
-    if (((PBranch->TrueEdge && Predicate == CmpInst::FCMP_OEQ) ||
-         (!PBranch->TrueEdge && Predicate == CmpInst::FCMP_UNE)) &&
-        isa<ConstantFP>(FirstOp) && !cast<ConstantFP>(FirstOp)->isZero()) {
-      addPredicateUsers(PI, I);
-      addAdditionalUsers(SwappedOps ? Cmp->getOperand(1) : Cmp->getOperand(0),
-                         I);
-      return createConstantExpression(cast<Constant>(FirstOp));
-    }
+
+  // Handle the special case of floating point.
+  if (Predicate == CmpInst::FCMP_OEQ && isa<ConstantFP>(FirstOp) &&
+      !cast<ConstantFP>(FirstOp)->isZero()) {
+    addPredicateUsers(PI, I);
+    addAdditionalUsers(AdditionallyUsedValue, I);
+    return createConstantExpression(cast<Constant>(FirstOp));
   }
+
   return nullptr;
 }
 

diff  --git a/llvm/lib/Transforms/Scalar/SCCP.cpp b/llvm/lib/Transforms/Scalar/SCCP.cpp
index 2a5fcfc09268..11ac7d7e1584 100644
--- a/llvm/lib/Transforms/Scalar/SCCP.cpp
+++ b/llvm/lib/Transforms/Scalar/SCCP.cpp
@@ -1262,55 +1262,22 @@ void SCCPSolver::handleCallResult(CallBase &CB) {
       auto *PI = getPredicateInfoFor(&CB);
       assert(PI && "Missing predicate info for ssa.copy");
 
-      CmpInst *Cmp;
-      bool TrueEdge;
-      if (auto *PBranch = dyn_cast<PredicateBranch>(PI)) {
-        Cmp = dyn_cast<CmpInst>(PBranch->Condition);
-        TrueEdge = PBranch->TrueEdge;
-      } else if (auto *PAssume = dyn_cast<PredicateAssume>(PI)) {
-        Cmp = dyn_cast<CmpInst>(PAssume->Condition);
-        TrueEdge = true;
-      } else {
+      const Optional<PredicateConstraint> &Constraint = PI->getConstraint();
+      if (!Constraint) {
         mergeInValue(ValueState[&CB], &CB, CopyOfVal);
         return;
       }
 
-      // Everything below relies on the condition being a comparison.
-      if (!Cmp) {
-        mergeInValue(ValueState[&CB], &CB, CopyOfVal);
-        return;
-      }
+      CmpInst::Predicate Pred = Constraint->Predicate;
+      Value *OtherOp = Constraint->OtherOp;
 
-      Value *RenamedOp = PI->RenamedOp;
-      Value *CmpOp0 = Cmp->getOperand(0);
-      Value *CmpOp1 = Cmp->getOperand(1);
-      // Bail out if neither of the operands matches RenamedOp.
-      if (CmpOp0 != RenamedOp && CmpOp1 != RenamedOp) {
-        mergeInValue(ValueState[&CB], &CB, getValueState(CopyOf));
+      // Wait until OtherOp is resolved.
+      if (getValueState(OtherOp).isUnknown()) {
+        addAdditionalUser(OtherOp, &CB);
         return;
       }
 
-      auto Pred = Cmp->getPredicate();
-      if (CmpOp1 == RenamedOp) {
-        std::swap(CmpOp0, CmpOp1);
-        Pred = Cmp->getSwappedPredicate();
-      }
-
-      // Wait until CmpOp1 is resolved.
-      if (getValueState(CmpOp1).isUnknown()) {
-        addAdditionalUser(CmpOp1, &CB);
-        return;
-      }
-
-      // The code below relies on PredicateInfo only inserting copies for the
-      // true branch when the branch condition is an AND and only inserting
-      // copies for the false branch when the branch condition is an OR. This
-      // ensures we can intersect the range from the condition with the range of
-      // CopyOf.
-      if (!TrueEdge)
-        Pred = CmpInst::getInversePredicate(Pred);
-
-      ValueLatticeElement CondVal = getValueState(CmpOp1);
+      ValueLatticeElement CondVal = getValueState(OtherOp);
       ValueLatticeElement &IV = ValueState[&CB];
       if (CondVal.isConstantRange() || CopyOfVal.isConstantRange()) {
         auto ImposedCR =
@@ -1334,7 +1301,7 @@ void SCCPSolver::handleCallResult(CallBase &CB) {
         if (!CopyOfCR.contains(NewCR) && CopyOfCR.getSingleMissingElement())
           NewCR = CopyOfCR;
 
-        addAdditionalUser(CmpOp1, &CB);
+        addAdditionalUser(OtherOp, &CB);
         // TODO: Actually filp MayIncludeUndef for the created range to false,
         // once most places in the optimizer respect the branches on
         // undef/poison are UB rule. The reason why the new range cannot be
@@ -1351,7 +1318,7 @@ void SCCPSolver::handleCallResult(CallBase &CB) {
       } else if (Pred == CmpInst::ICMP_EQ && CondVal.isConstant()) {
         // For non-integer values or integer constant expressions, only
         // propagate equal constants.
-        addAdditionalUser(CmpOp1, &CB);
+        addAdditionalUser(OtherOp, &CB);
         mergeInValue(IV, &CB, CondVal);
         return;
       }

diff  --git a/llvm/lib/Transforms/Utils/PredicateInfo.cpp b/llvm/lib/Transforms/Utils/PredicateInfo.cpp
index 99b64a7462f6..280d3a996d50 100644
--- a/llvm/lib/Transforms/Utils/PredicateInfo.cpp
+++ b/llvm/lib/Transforms/Utils/PredicateInfo.cpp
@@ -822,6 +822,66 @@ PredicateInfo::~PredicateInfo() {
   }
 }
 
+/// Express this predicate as a constraint "cmp Pred RenamedOp, OtherOp" that
+/// holds for the ssa.copy result. Returns None when the predicate cannot be
+/// put in that form (e.g. RenamedOp is not an operand of Condition).
+Optional<PredicateConstraint> PredicateBase::getConstraint() const {
+  switch (Type) {
+  case PT_Assume:
+  case PT_Branch: {
+    // An assume constrains its condition like the true edge of a branch.
+    bool TrueEdge = true;
+    if (auto *PBranch = dyn_cast<PredicateBranch>(this))
+      TrueEdge = PBranch->TrueEdge;
+
+    // The condition itself was renamed: it is known true/false on this edge.
+    if (Condition == RenamedOp) {
+      return {{CmpInst::ICMP_EQ,
+               TrueEdge ? ConstantInt::getTrue(Condition->getType())
+                        : ConstantInt::getFalse(Condition->getType())}};
+    }
+
+    // RenamedOp is not always accurate (e.g. a copy of an earlier copy), so a
+    // non-compare Condition can reach this point; bail out like the operand
+    // mismatch case below instead of asserting.
+    CmpInst *Cmp = dyn_cast<CmpInst>(Condition);
+    if (!Cmp) {
+      // TODO: Make this an assertion once RenamedOp is fully accurate.
+      return None;
+    }
+
+    // Normalize to "cmp Pred RenamedOp, OtherOp".
+    CmpInst::Predicate Pred;
+    Value *OtherOp;
+    if (Cmp->getOperand(0) == RenamedOp) {
+      Pred = Cmp->getPredicate();
+      OtherOp = Cmp->getOperand(1);
+    } else if (Cmp->getOperand(1) == RenamedOp) {
+      Pred = Cmp->getSwappedPredicate();
+      OtherOp = Cmp->getOperand(0);
+    } else {
+      // TODO: Make this an assertion once RenamedOp is fully accurate.
+      return None;
+    }
+
+    // Invert predicate along false edge.
+    if (!TrueEdge)
+      Pred = CmpInst::getInversePredicate(Pred);
+
+    return {{Pred, OtherOp}};
+  }
+  case PT_Switch:
+    // Only the switched-on value itself can be equated with the case value.
+    if (Condition != RenamedOp) {
+      // TODO: Make this an assertion once RenamedOp is fully accurate.
+      return None;
+    }
+
+    return {{CmpInst::ICMP_EQ, cast<PredicateSwitch>(this)->CaseValue}};
+  }
+  llvm_unreachable("Unknown predicate type");
+}
+
 void PredicateInfo::verifyPredicateInfo() const {}
 
 char PredicateInfoPrinterLegacyPass::ID = 0;

diff  --git a/llvm/test/Transforms/SCCP/predicateinfo-cond.ll b/llvm/test/Transforms/SCCP/predicateinfo-cond.ll
index d8528918babe..d98b3cc76d92 100644
--- a/llvm/test/Transforms/SCCP/predicateinfo-cond.ll
+++ b/llvm/test/Transforms/SCCP/predicateinfo-cond.ll
@@ -11,16 +11,13 @@ define i32 @switch(i32 %x) {
 ; CHECK-NEXT:    i32 2, label [[CASE_2:%.*]]
 ; CHECK-NEXT:    ]
 ; CHECK:       case.0:
-; CHECK-NEXT:    [[ADD:%.*]] = add i32 [[X]], 1
 ; CHECK-NEXT:    br label [[END:%.*]]
 ; CHECK:       case.2:
-; CHECK-NEXT:    [[SUB:%.*]] = sub i32 [[X]], 1
 ; CHECK-NEXT:    br label [[END]]
 ; CHECK:       case.default:
 ; CHECK-NEXT:    br label [[END]]
 ; CHECK:       end:
-; CHECK-NEXT:    [[PHI:%.*]] = phi i32 [ [[ADD]], [[CASE_0]] ], [ [[SUB]], [[CASE_2]] ], [ 1, [[CASE_DEFAULT]] ]
-; CHECK-NEXT:    ret i32 [[PHI]]
+; CHECK-NEXT:    ret i32 1
 ;
   switch i32 %x, label %case.default [
   i32 0, label %case.0
@@ -47,7 +44,7 @@ define i1 @assume(i32 %x) {
 ; CHECK-LABEL: @assume(
 ; CHECK-NEXT:    [[CMP:%.*]] = icmp sge i32 [[X:%.*]], 0
 ; CHECK-NEXT:    call void @llvm.assume(i1 [[CMP]])
-; CHECK-NEXT:    ret i1 [[CMP]]
+; CHECK-NEXT:    ret i1 true
 ;
   %cmp = icmp sge i32 %x, 0
   call void @llvm.assume(i1 %cmp)
@@ -59,23 +56,17 @@ define i32 @branch(i32 %x) {
 ; CHECK-NEXT:    [[CMP:%.*]] = icmp sge i32 [[X:%.*]], 0
 ; CHECK-NEXT:    br i1 [[CMP]], label [[IF_THEN1:%.*]], label [[IF_THEN2:%.*]]
 ; CHECK:       if.then1:
-; CHECK-NEXT:    br i1 [[CMP]], label [[IF2_THEN1:%.*]], label [[IF2_THEN2:%.*]]
+; CHECK-NEXT:    br label [[IF2_THEN1:%.*]]
 ; CHECK:       if2.then1:
 ; CHECK-NEXT:    br label [[IF2_END:%.*]]
-; CHECK:       if2.then2:
-; CHECK-NEXT:    br label [[IF2_END]]
 ; CHECK:       if2.end:
-; CHECK-NEXT:    [[PHI:%.*]] = phi i32 [ 0, [[IF2_THEN1]] ], [ 1, [[IF2_THEN2]] ]
-; CHECK-NEXT:    ret i32 [[PHI]]
+; CHECK-NEXT:    ret i32 0
 ; CHECK:       if.then2:
-; CHECK-NEXT:    br i1 [[CMP]], label [[IF3_THEN1:%.*]], label [[IF3_THEN2:%.*]]
-; CHECK:       if3.then1:
-; CHECK-NEXT:    br label [[IF3_END:%.*]]
+; CHECK-NEXT:    br label [[IF3_THEN2:%.*]]
 ; CHECK:       if3.then2:
-; CHECK-NEXT:    br label [[IF3_END]]
+; CHECK-NEXT:    br label [[IF3_END:%.*]]
 ; CHECK:       if3.end:
-; CHECK-NEXT:    [[PHI2:%.*]] = phi i32 [ 0, [[IF3_THEN1]] ], [ 1, [[IF3_THEN2]] ]
-; CHECK-NEXT:    ret i32 [[PHI2]]
+; CHECK-NEXT:    ret i32 1
 ;
   %cmp = icmp sge i32 %x, 0
   br i1 %cmp, label %if.then1, label %if.then2


        


More information about the llvm-commits mailing list