[llvm] 0c01b37 - Generalize getInvertibleOperand recurrence handling slightly

Philip Reames via llvm-commits llvm-commits at lists.llvm.org
Wed Apr 28 14:38:20 PDT 2021


Author: Philip Reames
Date: 2021-04-28T14:38:07-07:00
New Revision: 0c01b37eeb18a51a7e9c9153330d8009de0f600e

URL: https://github.com/llvm/llvm-project/commit/0c01b37eeb18a51a7e9c9153330d8009de0f600e
DIFF: https://github.com/llvm/llvm-project/commit/0c01b37eeb18a51a7e9c9153330d8009de0f600e.diff

LOG: Generalize getInvertibleOperand recurrence handling slightly

Follow up to D99912, specifically the revert, fix, and reapply thereof.

This generalizes the invertible recurrence logic in two ways:
* By allowing mismatching operand numbers of the phi, we can recurse through a pair of phi recurrences whose operand orders have not been canonicalized.
* By allowing recurrences through operand 1, we can invert these odd (but legal) recurrences.

Differential Revision: https://reviews.llvm.org/D100884

Added: 
    

Modified: 
    llvm/lib/Analysis/ValueTracking.cpp
    llvm/test/Analysis/ValueTracking/known-non-equal.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index 7b2cd163b6ee..e864e9903d0f 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -2513,26 +2513,31 @@ bool isKnownNonZero(const Value* V, unsigned Depth, const Query& Q) {
   return isKnownNonZero(V, DemandedElts, Depth, Q);
 }
 
-/// If the pair of operators are the same invertible function of a single
-/// operand return the index of that operand.  Otherwise, return None.  An
-/// invertible function is one that is 1-to-1 and maps every input value
-/// to exactly one output value.  This is equivalent to saying that Op1
-/// and Op2 are equal exactly when the specified pair of operands are equal,
-/// (except that Op1 and Op2 may be poison more often.)
-static Optional<unsigned> getInvertibleOperand(const Operator *Op1,
-                                               const Operator *Op2) {
+/// If the pair of operators are the same invertible function, return the
+/// the operands of the function corresponding to each input. Otherwise,
+/// return None.  An invertible function is one that is 1-to-1 and maps
+/// every input value to exactly one output value.  This is equivalent to
+/// saying that Op1 and Op2 are equal exactly when the specified pair of
+/// operands are equal, (except that Op1 and Op2 may be poison more often.)
+static Optional<std::pair<Value*, Value*>>
+getInvertibleOperands(const Operator *Op1,
+                      const Operator *Op2) {
   if (Op1->getOpcode() != Op2->getOpcode())
     return None;
 
+  auto getOperands = [&](unsigned OpNum) -> auto {
+    return std::make_pair(Op1->getOperand(OpNum), Op2->getOperand(OpNum));
+  };
+
   switch (Op1->getOpcode()) {
   default:
     break;
   case Instruction::Add:
   case Instruction::Sub:
     if (Op1->getOperand(0) == Op2->getOperand(0))
-      return 1;
+      return getOperands(1);
     if (Op1->getOperand(1) == Op2->getOperand(1))
-      return 0;
+      return getOperands(0);
     break;
   case Instruction::Mul: {
     // invertible if A * B == (A * B) mod 2^N where A, and B are integers
@@ -2548,7 +2553,7 @@ static Optional<unsigned> getInvertibleOperand(const Operator *Op1,
     if (Op1->getOperand(1) == Op2->getOperand(1) &&
         isa<ConstantInt>(Op1->getOperand(1)) &&
         !cast<ConstantInt>(Op1->getOperand(1))->isZero())
-      return 0;
+      return getOperands(0);
     break;
   }
   case Instruction::Shl: {
@@ -2561,7 +2566,7 @@ static Optional<unsigned> getInvertibleOperand(const Operator *Op1,
       break;
 
     if (Op1->getOperand(1) == Op2->getOperand(1))
-      return 0;
+      return getOperands(0);
     break;
   }
   case Instruction::AShr:
@@ -2572,13 +2577,13 @@ static Optional<unsigned> getInvertibleOperand(const Operator *Op1,
       break;
 
     if (Op1->getOperand(1) == Op2->getOperand(1))
-      return 0;
+      return getOperands(0);
     break;
   }
   case Instruction::SExt:
   case Instruction::ZExt:
     if (Op1->getOperand(0)->getType() == Op2->getOperand(0)->getType())
-      return 0;
+      return getOperands(0);
     break;
   case Instruction::PHI: {
     const PHINode *PN1 = cast<PHINode>(Op1);
@@ -2596,18 +2601,12 @@ static Optional<unsigned> getInvertibleOperand(const Operator *Op1,
         !matchSimpleRecurrence(PN2, BO2, Start2, Step2))
       break;
 
-    Optional<unsigned> Idx = getInvertibleOperand(cast<Operator>(BO1),
-                                                  cast<Operator>(BO2));
-    if (!Idx || *Idx != 0)
+    auto Values = getInvertibleOperands(cast<Operator>(BO1),
+                                        cast<Operator>(BO2));
+    if (!Values)
       break;
-    assert(BO1->getOperand(*Idx) == PN1 && BO2->getOperand(*Idx) == PN2);
-
-    // Phi operands might not be in the same order.  TODO: generalize
-    // interface to return pair of operands.
-    if (PN1->getOperand(0) == BO1 && PN2->getOperand(0) == BO2)
-      return 1;
-    if (PN1->getOperand(1) == BO1 && PN2->getOperand(1) == BO2)
-      return 0;
+    assert(Values->first == PN1 && Values->second == PN2);
+    return std::make_pair(Start1, Start2);
   }
   }
   return None;
@@ -2704,11 +2703,9 @@ static bool isKnownNonEqual(const Value *V1, const Value *V2, unsigned Depth,
   auto *O1 = dyn_cast<Operator>(V1);
   auto *O2 = dyn_cast<Operator>(V2);
   if (O1 && O2 && O1->getOpcode() == O2->getOpcode()) {
-    if (Optional<unsigned> Opt = getInvertibleOperand(O1, O2)) {
-      unsigned Idx = *Opt;
-      return isKnownNonEqual(O1->getOperand(Idx), O2->getOperand(Idx),
-                             Depth + 1, Q);
-    }
+    if (auto Values = getInvertibleOperands(O1, O2))
+      return isKnownNonEqual(Values->first, Values->second, Depth + 1, Q);
+
     if (const PHINode *PN1 = dyn_cast<PHINode>(V1)) {
       const PHINode *PN2 = cast<PHINode>(V2);
       // FIXME: This is missing a generalization to handle the case where one is

diff  --git a/llvm/test/Analysis/ValueTracking/known-non-equal.ll b/llvm/test/Analysis/ValueTracking/known-non-equal.ll
index a41d228fd403..f3e3b6dbfaf0 100644
--- a/llvm/test/Analysis/ValueTracking/known-non-equal.ll
+++ b/llvm/test/Analysis/ValueTracking/known-non-equal.ll
@@ -736,8 +736,7 @@ define i1 @recurrence_add_op_order(i8 %A) {
 ; CHECK-NEXT:    [[CMP:%.*]] = icmp ne i64 [[IV_NEXT]], 10
 ; CHECK-NEXT:    br i1 [[CMP]], label [[LOOP]], label [[EXIT:%.*]]
 ; CHECK:       exit:
-; CHECK-NEXT:    [[RES:%.*]] = icmp eq i8 [[A_IV]], [[B_IV]]
-; CHECK-NEXT:    ret i1 [[RES]]
+; CHECK-NEXT:    ret i1 false
 ;
 entry:
   %B = add i8 %A, 1
@@ -808,8 +807,7 @@ define i1 @recurrence_add_phi_different_order1(i8 %A) {
 ; CHECK-NEXT:    [[CMP:%.*]] = icmp ne i64 [[IV_NEXT]], 10
 ; CHECK-NEXT:    br i1 [[CMP]], label [[LOOP]], label [[EXIT:%.*]]
 ; CHECK:       exit:
-; CHECK-NEXT:    [[RES:%.*]] = icmp eq i8 [[A_IV]], [[B_IV]]
-; CHECK-NEXT:    ret i1 [[RES]]
+; CHECK-NEXT:    ret i1 false
 ;
 entry:
   %B = add i8 %A, 1
@@ -843,8 +841,7 @@ define i1 @recurrence_add_phi_different_order2(i8 %A) {
 ; CHECK-NEXT:    [[CMP:%.*]] = icmp ne i64 [[IV_NEXT]], 10
 ; CHECK-NEXT:    br i1 [[CMP]], label [[LOOP]], label [[EXIT:%.*]]
 ; CHECK:       exit:
-; CHECK-NEXT:    [[RES:%.*]] = icmp eq i8 [[A_IV]], [[B_IV]]
-; CHECK-NEXT:    ret i1 [[RES]]
+; CHECK-NEXT:    ret i1 false
 ;
 entry:
   %B = add i8 %A, 1
@@ -979,8 +976,7 @@ define i1 @recurrence_sub_op_order(i8 %A) {
 ; CHECK-NEXT:    [[CMP:%.*]] = icmp ne i64 [[IV_NEXT]], 10
 ; CHECK-NEXT:    br i1 [[CMP]], label [[LOOP]], label [[EXIT:%.*]]
 ; CHECK:       exit:
-; CHECK-NEXT:    [[RES:%.*]] = icmp eq i8 [[A_IV]], [[B_IV]]
-; CHECK-NEXT:    ret i1 [[RES]]
+; CHECK-NEXT:    ret i1 false
 ;
 entry:
   %B = add i8 %A, 1


        


More information about the llvm-commits mailing list