[llvm] e38ccb7 - Recommit "Generalize getInvertibleOperand recurrence handling slightly"

Philip Reames via llvm-commits llvm-commits at lists.llvm.org
Mon May 3 16:41:01 PDT 2021


Author: Philip Reames
Date: 2021-05-03T16:40:56-07:00
New Revision: e38ccb729b205b076356684e055efb7dfc673963

URL: https://github.com/llvm/llvm-project/commit/e38ccb729b205b076356684e055efb7dfc673963
DIFF: https://github.com/llvm/llvm-project/commit/e38ccb729b205b076356684e055efb7dfc673963.diff

LOG: Recommit "Generalize getInvertibleOperand recurrence handling slightly"

This was reverted because of a reported problem.  It turned out this patch didn't introduce said problem; it just exposed it more widely.  15a4233 fixes the root issue, so this simply a) rebases over that, and b) adds a much more extensive comment explaining why that weakened assert is correct.

Original commit message follows:

Follow up to D99912, specifically the revert, fix, and reapply thereof.

This generalizes the invertible recurrence logic in two ways:
* By allowing mismatching operand numbers of the phi, we can recurse through a pair of phi recurrences whose operand orders have not been canonicalized.
* By allowing recurrences through operand 1, we can invert these odd (but legal) recurrences.

Differential Revision: https://reviews.llvm.org/D100884

Added: 
    

Modified: 
    llvm/lib/Analysis/ValueTracking.cpp
    llvm/test/Analysis/ValueTracking/known-non-equal.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index 30c64f1cc4dcc..3621deef60da0 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -2521,26 +2521,31 @@ bool isKnownNonZero(const Value* V, unsigned Depth, const Query& Q) {
   return isKnownNonZero(V, DemandedElts, Depth, Q);
 }
 
-/// If the pair of operators are the same invertible function of a single
-/// operand return the index of that operand.  Otherwise, return None.  An
-/// invertible function is one that is 1-to-1 and maps every input value
-/// to exactly one output value.  This is equivalent to saying that Op1
-/// and Op2 are equal exactly when the specified pair of operands are equal,
-/// (except that Op1 and Op2 may be poison more often.)
-static Optional<unsigned> getInvertibleOperand(const Operator *Op1,
-                                               const Operator *Op2) {
+/// If the pair of operators are the same invertible function, return the
+/// operands of the function corresponding to each input. Otherwise,
+/// return None.  An invertible function is one that is 1-to-1 and maps
+/// every input value to exactly one output value.  This is equivalent to
+/// saying that Op1 and Op2 are equal exactly when the specified pair of
+/// operands are equal, (except that Op1 and Op2 may be poison more often.)
+static Optional<std::pair<Value*, Value*>>
+getInvertibleOperands(const Operator *Op1,
+                      const Operator *Op2) {
   if (Op1->getOpcode() != Op2->getOpcode())
     return None;
 
+  auto getOperands = [&](unsigned OpNum) -> auto {
+    return std::make_pair(Op1->getOperand(OpNum), Op2->getOperand(OpNum));
+  };
+
   switch (Op1->getOpcode()) {
   default:
     break;
   case Instruction::Add:
   case Instruction::Sub:
     if (Op1->getOperand(0) == Op2->getOperand(0))
-      return 1;
+      return getOperands(1);
     if (Op1->getOperand(1) == Op2->getOperand(1))
-      return 0;
+      return getOperands(0);
     break;
   case Instruction::Mul: {
     // invertible if A * B == (A * B) mod 2^N where A, and B are integers
@@ -2556,7 +2561,7 @@ static Optional<unsigned> getInvertibleOperand(const Operator *Op1,
     if (Op1->getOperand(1) == Op2->getOperand(1) &&
         isa<ConstantInt>(Op1->getOperand(1)) &&
         !cast<ConstantInt>(Op1->getOperand(1))->isZero())
-      return 0;
+      return getOperands(0);
     break;
   }
   case Instruction::Shl: {
@@ -2569,7 +2574,7 @@ static Optional<unsigned> getInvertibleOperand(const Operator *Op1,
       break;
 
     if (Op1->getOperand(1) == Op2->getOperand(1))
-      return 0;
+      return getOperands(0);
     break;
   }
   case Instruction::AShr:
@@ -2580,13 +2585,13 @@ static Optional<unsigned> getInvertibleOperand(const Operator *Op1,
       break;
 
     if (Op1->getOperand(1) == Op2->getOperand(1))
-      return 0;
+      return getOperands(0);
     break;
   }
   case Instruction::SExt:
   case Instruction::ZExt:
     if (Op1->getOperand(0)->getType() == Op2->getOperand(0)->getType())
-      return 0;
+      return getOperands(0);
     break;
   case Instruction::PHI: {
     const PHINode *PN1 = cast<PHINode>(Op1);
@@ -2604,19 +2609,20 @@ static Optional<unsigned> getInvertibleOperand(const Operator *Op1,
         !matchSimpleRecurrence(PN2, BO2, Start2, Step2))
       break;
 
-    Optional<unsigned> Idx = getInvertibleOperand(cast<Operator>(BO1),
-                                                  cast<Operator>(BO2));
-    if (!Idx || *Idx != 0)
-      break;
-    if (BO1->getOperand(*Idx) != PN1 || BO2->getOperand(*Idx) != PN2)
+    auto Values = getInvertibleOperands(cast<Operator>(BO1),
+                                        cast<Operator>(BO2));
+    if (!Values)
+       break;
+
+    // We have to be careful of mutually defined recurrences here.  Ex:
+    // * X_i = X_(i-1) OP Y_(i-1), and Y_i = X_(i-1) OP V
+    // * X_i = Y_i = X_(i-1) OP Y_(i-1)
+    // The invertibility of these is complicated, and not worth reasoning
+    // about (yet?).
+    if (Values->first != PN1 || Values->second != PN2)
       break;
 
-    // Phi operands might not be in the same order.  TODO: generalize
-    // interface to return pair of operands.
-    if (PN1->getOperand(0) == BO1 && PN2->getOperand(0) == BO2)
-      return 1;
-    if (PN1->getOperand(1) == BO1 && PN2->getOperand(1) == BO2)
-      return 0;
+    return std::make_pair(Start1, Start2);
   }
   }
   return None;
@@ -2713,11 +2719,9 @@ static bool isKnownNonEqual(const Value *V1, const Value *V2, unsigned Depth,
   auto *O1 = dyn_cast<Operator>(V1);
   auto *O2 = dyn_cast<Operator>(V2);
   if (O1 && O2 && O1->getOpcode() == O2->getOpcode()) {
-    if (Optional<unsigned> Opt = getInvertibleOperand(O1, O2)) {
-      unsigned Idx = *Opt;
-      return isKnownNonEqual(O1->getOperand(Idx), O2->getOperand(Idx),
-                             Depth + 1, Q);
-    }
+    if (auto Values = getInvertibleOperands(O1, O2))
+      return isKnownNonEqual(Values->first, Values->second, Depth + 1, Q);
+
     if (const PHINode *PN1 = dyn_cast<PHINode>(V1)) {
       const PHINode *PN2 = cast<PHINode>(V2);
       // FIXME: This is missing a generalization to handle the case where one is

diff  --git a/llvm/test/Analysis/ValueTracking/known-non-equal.ll b/llvm/test/Analysis/ValueTracking/known-non-equal.ll
index c1a8f07953a3f..2e53f8c37d36a 100644
--- a/llvm/test/Analysis/ValueTracking/known-non-equal.ll
+++ b/llvm/test/Analysis/ValueTracking/known-non-equal.ll
@@ -736,8 +736,7 @@ define i1 @recurrence_add_op_order(i8 %A) {
 ; CHECK-NEXT:    [[CMP:%.*]] = icmp ne i64 [[IV_NEXT]], 10
 ; CHECK-NEXT:    br i1 [[CMP]], label [[LOOP]], label [[EXIT:%.*]]
 ; CHECK:       exit:
-; CHECK-NEXT:    [[RES:%.*]] = icmp eq i8 [[A_IV]], [[B_IV]]
-; CHECK-NEXT:    ret i1 [[RES]]
+; CHECK-NEXT:    ret i1 false
 ;
 entry:
   %B = add i8 %A, 1
@@ -808,8 +807,7 @@ define i1 @recurrence_add_phi_different_order1(i8 %A) {
 ; CHECK-NEXT:    [[CMP:%.*]] = icmp ne i64 [[IV_NEXT]], 10
 ; CHECK-NEXT:    br i1 [[CMP]], label [[LOOP]], label [[EXIT:%.*]]
 ; CHECK:       exit:
-; CHECK-NEXT:    [[RES:%.*]] = icmp eq i8 [[A_IV]], [[B_IV]]
-; CHECK-NEXT:    ret i1 [[RES]]
+; CHECK-NEXT:    ret i1 false
 ;
 entry:
   %B = add i8 %A, 1
@@ -843,8 +841,7 @@ define i1 @recurrence_add_phi_different_order2(i8 %A) {
 ; CHECK-NEXT:    [[CMP:%.*]] = icmp ne i64 [[IV_NEXT]], 10
 ; CHECK-NEXT:    br i1 [[CMP]], label [[LOOP]], label [[EXIT:%.*]]
 ; CHECK:       exit:
-; CHECK-NEXT:    [[RES:%.*]] = icmp eq i8 [[A_IV]], [[B_IV]]
-; CHECK-NEXT:    ret i1 [[RES]]
+; CHECK-NEXT:    ret i1 false
 ;
 entry:
   %B = add i8 %A, 1
@@ -979,8 +976,7 @@ define i1 @recurrence_sub_op_order(i8 %A) {
 ; CHECK-NEXT:    [[CMP:%.*]] = icmp ne i64 [[IV_NEXT]], 10
 ; CHECK-NEXT:    br i1 [[CMP]], label [[LOOP]], label [[EXIT:%.*]]
 ; CHECK:       exit:
-; CHECK-NEXT:    [[RES:%.*]] = icmp eq i8 [[A_IV]], [[B_IV]]
-; CHECK-NEXT:    ret i1 [[RES]]
+; CHECK-NEXT:    ret i1 false
 ;
 entry:
   %B = add i8 %A, 1


        


More information about the llvm-commits mailing list