[llvm] 1d8ef76 - [NewGVN] Use ExprResult to add extra predicate users.

Florian Hahn via llvm-commits llvm-commits at lists.llvm.org
Sun Apr 25 03:13:56 PDT 2021


Author: Florian Hahn
Date: 2021-04-25T11:13:32+01:00
New Revision: 1d8ef761be68d7ba023aea0450be5355426d46f7

URL: https://github.com/llvm/llvm-project/commit/1d8ef761be68d7ba023aea0450be5355426d46f7
DIFF: https://github.com/llvm/llvm-project/commit/1d8ef761be68d7ba023aea0450be5355426d46f7.diff

LOG: [NewGVN] Use ExprResult to add extra predicate users.

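This patch updates performSymbolicPredicateInfoEvaluation to manage
registering additional dependencies using ExprResult. Similar to D99987,
this fixes an issue where we failed to track the correct dependency for
a phi-of-ops value, which is marked as temporary. Previously,
addPredicateUsers returned early for temporary instructions, so the
predicate dependency was dropped; the predicate is now returned in
ExprResult::PredDep and registered through addAdditionalUsers.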

Fixes PR49873.

Reviewed By: asbirlea, ruiling

Differential Revision: https://reviews.llvm.org/D100560

Added: 
    

Modified: 
    llvm/lib/Transforms/Scalar/NewGVN.cpp
    llvm/test/Transforms/NewGVN/phi-of-ops-simplification-dependencies.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/Scalar/NewGVN.cpp b/llvm/lib/Transforms/Scalar/NewGVN.cpp
index 6b2932f63dda7..86457b3595458 100644
--- a/llvm/lib/Transforms/Scalar/NewGVN.cpp
+++ b/llvm/lib/Transforms/Scalar/NewGVN.cpp
@@ -672,9 +672,11 @@ class NewGVN {
   struct ExprResult {
     const Expression *Expr;
     Value *ExtraDep;
+    const PredicateBase *PredDep;
 
-    ExprResult(const Expression *Expr, Value *ExtraDep = nullptr)
-        : Expr(Expr), ExtraDep(ExtraDep) {}
+    ExprResult(const Expression *Expr, Value *ExtraDep = nullptr,
+               const PredicateBase *PredDep = nullptr)
+        : Expr(Expr), ExtraDep(ExtraDep), PredDep(PredDep) {}
     ExprResult(const ExprResult &) = delete;
     ExprResult(ExprResult &&Other)
-        : Expr(Other.Expr), ExtraDep(Other.ExtraDep) {
+        : Expr(Other.Expr), ExtraDep(Other.ExtraDep), PredDep(Other.PredDep) {
@@ -688,9 +690,17 @@ class NewGVN {
 
     operator bool() const { return Expr; }
 
-    static ExprResult none() { return {nullptr, nullptr}; }
+    static ExprResult none() { return {nullptr, nullptr, nullptr}; }
     static ExprResult some(const Expression *Expr, Value *ExtraDep = nullptr) {
-      return {Expr, ExtraDep};
+      return {Expr, ExtraDep, nullptr};
+    }
+    static ExprResult some(const Expression *Expr,
+                           const PredicateBase *PredDep) {
+      return {Expr, nullptr, PredDep};
+    }
+    static ExprResult some(const Expression *Expr, Value *ExtraDep,
+                           const PredicateBase *PredDep) {
+      return {Expr, ExtraDep, PredDep};
     }
   };
 
@@ -776,14 +786,14 @@ class NewGVN {
                                                 MemoryAccess *) const;
   const Expression *performSymbolicLoadEvaluation(Instruction *) const;
   const Expression *performSymbolicStoreEvaluation(Instruction *) const;
-  const Expression *performSymbolicCallEvaluation(Instruction *) const;
+  ExprResult performSymbolicCallEvaluation(Instruction *) const;
   void sortPHIOps(MutableArrayRef<ValPair> Ops) const;
   const Expression *performSymbolicPHIEvaluation(ArrayRef<ValPair>,
                                                  Instruction *I,
                                                  BasicBlock *PHIBlock) const;
   const Expression *performSymbolicAggrValueEvaluation(Instruction *) const;
   ExprResult performSymbolicCmpEvaluation(Instruction *) const;
-  const Expression *performSymbolicPredicateInfoEvaluation(Instruction *) const;
+  ExprResult performSymbolicPredicateInfoEvaluation(Instruction *) const;
 
   // Congruence finding.
   bool someEquivalentDominates(const Instruction *, const Instruction *) const;
@@ -836,10 +846,9 @@ class NewGVN {
   void markValueLeaderChangeTouched(CongruenceClass *CC);
   void markMemoryLeaderChangeTouched(CongruenceClass *CC);
   void markPhiOfOpsChanged(const Expression *E);
-  void addPredicateUsers(const PredicateBase *, Instruction *) const;
   void addMemoryUsers(const MemoryAccess *To, MemoryAccess *U) const;
   void addAdditionalUsers(Value *To, Value *User) const;
-  void addAdditionalUsers(ExprResult &Res, Value *User) const;
+  void addAdditionalUsers(ExprResult &Res, Instruction *User) const;
 
   // Main loop of value numbering
   void iterateTouchedInstructions();
@@ -1540,17 +1549,17 @@ const Expression *NewGVN::performSymbolicLoadEvaluation(Instruction *I) const {
   return LE;
 }
 
-const Expression *
+NewGVN::ExprResult
 NewGVN::performSymbolicPredicateInfoEvaluation(Instruction *I) const {
   auto *PI = PredInfo->getPredicateInfoFor(I);
   if (!PI)
-    return nullptr;
+    return ExprResult::none();
 
   LLVM_DEBUG(dbgs() << "Found predicate info from instruction !\n");
 
   const Optional<PredicateConstraint> &Constraint = PI->getConstraint();
   if (!Constraint)
-    return nullptr;
+    return ExprResult::none();
 
   CmpInst::Predicate Predicate = Constraint->Predicate;
   Value *CmpOp0 = I->getOperand(0);
@@ -1567,45 +1576,43 @@ NewGVN::performSymbolicPredicateInfoEvaluation(Instruction *I) const {
     AdditionallyUsedValue = CmpOp1;
   }
 
-  if (Predicate == CmpInst::ICMP_EQ) {
-    addPredicateUsers(PI, I);
-    addAdditionalUsers(AdditionallyUsedValue, I);
-    return createVariableOrConstant(FirstOp);
-  }
+  if (Predicate == CmpInst::ICMP_EQ)
+    return ExprResult::some(createVariableOrConstant(FirstOp),
+                            AdditionallyUsedValue, PI);
 
   // Handle the special case of floating point.
   if (Predicate == CmpInst::FCMP_OEQ && isa<ConstantFP>(FirstOp) &&
-      !cast<ConstantFP>(FirstOp)->isZero()) {
-    addPredicateUsers(PI, I);
-    addAdditionalUsers(AdditionallyUsedValue, I);
-    return createConstantExpression(cast<Constant>(FirstOp));
-  }
+      !cast<ConstantFP>(FirstOp)->isZero())
+    return ExprResult::some(createConstantExpression(cast<Constant>(FirstOp)),
+                            AdditionallyUsedValue, PI);
 
-  return nullptr;
+  return ExprResult::none();
 }
 
 // Evaluate read only and pure calls, and create an expression result.
-const Expression *NewGVN::performSymbolicCallEvaluation(Instruction *I) const {
+NewGVN::ExprResult NewGVN::performSymbolicCallEvaluation(Instruction *I) const {
   auto *CI = cast<CallInst>(I);
   if (auto *II = dyn_cast<IntrinsicInst>(I)) {
     // Intrinsics with the returned attribute are copies of arguments.
     if (auto *ReturnedValue = II->getReturnedArgOperand()) {
       if (II->getIntrinsicID() == Intrinsic::ssa_copy)
-        if (const auto *Result = performSymbolicPredicateInfoEvaluation(I))
-          return Result;
-      return createVariableOrConstant(ReturnedValue);
+        if (auto Res = performSymbolicPredicateInfoEvaluation(I))
+          return Res;
+      return ExprResult::some(createVariableOrConstant(ReturnedValue));
     }
   }
   if (AA->doesNotAccessMemory(CI)) {
-    return createCallExpression(CI, TOPClass->getMemoryLeader());
+    return ExprResult::some(
+        createCallExpression(CI, TOPClass->getMemoryLeader()));
   } else if (AA->onlyReadsMemory(CI)) {
     if (auto *MA = MSSA->getMemoryAccess(CI)) {
       auto *DefiningAccess = MSSAWalker->getClobberingMemoryAccess(MA);
-      return createCallExpression(CI, DefiningAccess);
+      return ExprResult::some(createCallExpression(CI, DefiningAccess));
     } else // MSSA determined that CI does not access memory.
-      return createCallExpression(CI, TOPClass->getMemoryLeader());
+      return ExprResult::some(
+          createCallExpression(CI, TOPClass->getMemoryLeader()));
   }
-  return nullptr;
+  return ExprResult::none();
 }
 
 // Retrieve the memory class for a given MemoryAccess.
@@ -1880,31 +1887,31 @@ NewGVN::ExprResult NewGVN::performSymbolicCmpEvaluation(Instruction *I) const {
           // edge then we may be implied true or false.
           if (CmpInst::isImpliedTrueByMatchingCmp(BranchPredicate,
                                                   OurPredicate)) {
-            addPredicateUsers(PI, I);
             return ExprResult::some(
-                createConstantExpression(ConstantInt::getTrue(CI->getType())));
+                createConstantExpression(ConstantInt::getTrue(CI->getType())),
+                PI);
           }
 
           if (CmpInst::isImpliedFalseByMatchingCmp(BranchPredicate,
                                                    OurPredicate)) {
-            addPredicateUsers(PI, I);
             return ExprResult::some(
-                createConstantExpression(ConstantInt::getFalse(CI->getType())));
+                createConstantExpression(ConstantInt::getFalse(CI->getType())),
+                PI);
           }
         } else {
           // Just handle the ne and eq cases, where if we have the same
           // operands, we may know something.
           if (BranchPredicate == OurPredicate) {
-            addPredicateUsers(PI, I);
             // Same predicate, same ops,we know it was false, so this is false.
             return ExprResult::some(
-                createConstantExpression(ConstantInt::getFalse(CI->getType())));
+                createConstantExpression(ConstantInt::getFalse(CI->getType())),
+                PI);
           } else if (BranchPredicate ==
                      CmpInst::getInversePredicate(OurPredicate)) {
-            addPredicateUsers(PI, I);
             // Inverse predicate, we know the other was false, so this is true.
             return ExprResult::some(
-                createConstantExpression(ConstantInt::getTrue(CI->getType())));
+                createConstantExpression(ConstantInt::getTrue(CI->getType())),
+                PI);
           }
         }
       }
@@ -1944,7 +1951,7 @@ NewGVN::performSymbolicEvaluation(Value *V,
       E = performSymbolicPHIEvaluation(Ops, I, getBlockForValue(I));
     } break;
     case Instruction::Call:
-      E = performSymbolicCallEvaluation(I);
+      return performSymbolicCallEvaluation(I);
       break;
     case Instruction::Store:
       E = performSymbolicStoreEvaluation(I);
@@ -2024,10 +2031,18 @@ void NewGVN::addAdditionalUsers(Value *To, Value *User) const {
     AdditionalUsers[To].insert(User);
 }
 
-void NewGVN::addAdditionalUsers(ExprResult &Res, Value *User) const {
+void NewGVN::addAdditionalUsers(ExprResult &Res, Instruction *User) const {
   if (Res.ExtraDep && Res.ExtraDep != User)
     addAdditionalUsers(Res.ExtraDep, User);
   Res.ExtraDep = nullptr;
+
+  if (Res.PredDep) {
+    if (const auto *PBranch = dyn_cast<PredicateBranch>(Res.PredDep))
+      PredicateToUsers[PBranch->Condition].insert(User);
+    else if (const auto *PAssume = dyn_cast<PredicateAssume>(Res.PredDep))
+      PredicateToUsers[PAssume->Condition].insert(User);
+  }
+  Res.PredDep = nullptr;
 }
 
 void NewGVN::markUsersTouched(Value *V) {
@@ -2056,18 +2071,6 @@ void NewGVN::markMemoryUsersTouched(const MemoryAccess *MA) {
   touchAndErase(MemoryToUsers, MA);
 }
 
-// Add I to the set of users of a given predicate.
-void NewGVN::addPredicateUsers(const PredicateBase *PB, Instruction *I) const {
-  // Don't add temporary instructions to the user lists.
-  if (AllTempInstructions.count(I))
-    return;
-
-  if (auto *PBranch = dyn_cast<PredicateBranch>(PB))
-    PredicateToUsers[PBranch->Condition].insert(I);
-  else if (auto *PAssume = dyn_cast<PredicateAssume>(PB))
-    PredicateToUsers[PAssume->Condition].insert(I);
-}
-
 // Touch all the predicates that depend on this instruction.
 void NewGVN::markPredicateUsersTouched(Instruction *I) {
   touchAndErase(PredicateToUsers, I);

diff  --git a/llvm/test/Transforms/NewGVN/phi-of-ops-simplification-dependencies.ll b/llvm/test/Transforms/NewGVN/phi-of-ops-simplification-dependencies.ll
index c3409307d1898..26972a0ba06b7 100644
--- a/llvm/test/Transforms/NewGVN/phi-of-ops-simplification-dependencies.ll
+++ b/llvm/test/Transforms/NewGVN/phi-of-ops-simplification-dependencies.ll
@@ -116,3 +116,50 @@ latch:
 exit:
   ret void
 }
+
+define void @pr49873_cmp_simplification_dependency(i32* %ptr, i1 %c.0) {
+; CHECK-LABEL: @pr49873_cmp_simplification_dependency(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    br label [[LOOP_1:%.*]]
+; CHECK:       loop.1:
+; CHECK-NEXT:    br i1 [[C_0:%.*]], label [[LOOP_1_LATCH:%.*]], label [[LOOP_2:%.*]]
+; CHECK:       loop.2:
+; CHECK-NEXT:    [[I130:%.*]] = phi i32 [ [[I132:%.*]], [[LOOP_2]] ], [ 0, [[LOOP_1]] ]
+; CHECK-NEXT:    [[I132]] = add nuw i32 [[I130]], 1
+; CHECK-NEXT:    [[I133:%.*]] = load i32, i32* [[PTR:%.*]], align 4
+; CHECK-NEXT:    [[C_1:%.*]] = icmp ult i32 [[I132]], [[I133]]
+; CHECK-NEXT:    br i1 [[C_1]], label [[LOOP_2]], label [[LOOP_2_EXIT:%.*]]
+; CHECK:       loop.2.exit:
+; CHECK-NEXT:    br label [[LOOP_1_LATCH]]
+; CHECK:       loop.1.latch:
+; CHECK-NEXT:    [[DOTLCSSA:%.*]] = phi i32 [ 0, [[LOOP_1]] ], [ [[I133]], [[LOOP_2_EXIT]] ]
+; CHECK-NEXT:    [[C_2:%.*]] = icmp ult i32 1, [[DOTLCSSA]]
+; CHECK-NEXT:    br i1 [[C_2]], label [[LOOP_1]], label [[EXIT:%.*]]
+; CHECK:       exit:
+; CHECK-NEXT:    ret void
+;
+entry:
+  br label %loop.1
+
+loop.1:
+  %i65 = add nuw i32 0, 1
+  br i1 %c.0, label %loop.1.latch, label %loop.2
+
+loop.2:
+  %i130 = phi i32 [ %i132, %loop.2 ], [ 0, %loop.1 ]
+  %i132 = add nuw i32 %i130, 1
+  %i133 = load i32, i32* %ptr, align 4
+  %c.1 = icmp ult i32 %i132, %i133
+  br i1 %c.1, label %loop.2, label %loop.2.exit
+
+loop.2.exit:
+  br label %loop.1.latch
+
+loop.1.latch:                                      ; preds = %loop.2.exit, %loop.1
+  %.lcssa = phi i32 [ 0, %loop.1 ], [ %i133, %loop.2.exit ]
+  %c.2 = icmp ult i32 %i65, %.lcssa
+  br i1 %c.2, label %loop.1, label %exit
+
+exit:
+  ret void
+}


        


More information about the llvm-commits mailing list