[llvm] bcd8009 - [Attributor] Use the proper context instruction in genericValueTraversal

Johannes Doerfert via llvm-commits llvm-commits at lists.llvm.org
Wed Apr 1 20:21:58 PDT 2020


Author: Johannes Doerfert
Date: 2020-04-01T22:20:47-05:00
New Revision: bcd8009369f86be8649307692f5ba57aad5ca29b

URL: https://github.com/llvm/llvm-project/commit/bcd8009369f86be8649307692f5ba57aad5ca29b
DIFF: https://github.com/llvm/llvm-project/commit/bcd8009369f86be8649307692f5ba57aad5ca29b.diff

LOG: [Attributor] Use the proper context instruction in genericValueTraversal

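There was a TODO in genericValueTraversal to provide a context
instruction. Because none was provided, callers that needed one simply
used whatever instruction was at hand. Unfortunately, a fixed context
instruction is wrong in the presence of PHIs: each incoming value has to
be reasoned about at the end of its incoming block (its terminator), not
at the PHI's user. The context instruction therefore has to be updated
as the traversal walks through PHIs.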

Reviewed By: uenoku

Differential Revision: https://reviews.llvm.org/D76870

Added: 
    

Modified: 
    llvm/lib/Transforms/IPO/Attributor.cpp
    llvm/test/Transforms/Attributor/range.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/IPO/Attributor.cpp b/llvm/lib/Transforms/IPO/Attributor.cpp
index b712ef75d26f..6a148f991b4e 100644
--- a/llvm/lib/Transforms/IPO/Attributor.cpp
+++ b/llvm/lib/Transforms/IPO/Attributor.cpp
@@ -398,8 +398,10 @@ static Value *constructPointer(Type *ResTy, Value *Ptr, int64_t Offset,
 template <typename AAType, typename StateTy>
 static bool genericValueTraversal(
     Attributor &A, IRPosition IRP, const AAType &QueryingAA, StateTy &State,
-    function_ref<bool(Value &, StateTy &, bool)> VisitValueCB,
-    int MaxValues = 8, function_ref<Value *(Value *)> StripCB = nullptr) {
+    function_ref<bool(Value &, const Instruction *, StateTy &, bool)>
+        VisitValueCB,
+    const Instruction *CtxI, int MaxValues = 16,
+    function_ref<Value *(Value *)> StripCB = nullptr) {
 
   const AAIsDead *LivenessAA = nullptr;
   if (IRP.getAnchorScope())
@@ -408,20 +410,22 @@ static bool genericValueTraversal(
         /* TrackDependence */ false);
   bool AnyDead = false;
 
-  // TODO: Use Positions here to allow context sensitivity in VisitValueCB
-  SmallPtrSet<Value *, 16> Visited;
-  SmallVector<Value *, 16> Worklist;
-  Worklist.push_back(&IRP.getAssociatedValue());
+  using Item = std::pair<Value *, const Instruction *>;
+  SmallSet<Item, 16> Visited;
+  SmallVector<Item, 16> Worklist;
+  Worklist.push_back({&IRP.getAssociatedValue(), CtxI});
 
   int Iteration = 0;
   do {
-    Value *V = Worklist.pop_back_val();
+    Item I = Worklist.pop_back_val();
+    Value *V = I.first;
+    CtxI = I.second;
     if (StripCB)
       V = StripCB(V);
 
     // Check if we should process the current value. To prevent endless
     // recursion keep a record of the values we followed!
-    if (!Visited.insert(V).second)
+    if (!Visited.insert(I).second)
       continue;
 
     // Make sure we limit the compile time for complex expressions.
@@ -444,14 +448,14 @@ static bool genericValueTraversal(
       }
     }
     if (NewV && NewV != V) {
-      Worklist.push_back(NewV);
+      Worklist.push_back({NewV, CtxI});
       continue;
     }
 
     // Look through select instructions, visit both potential values.
     if (auto *SI = dyn_cast<SelectInst>(V)) {
-      Worklist.push_back(SI->getTrueValue());
-      Worklist.push_back(SI->getFalseValue());
+      Worklist.push_back({SI->getTrueValue(), CtxI});
+      Worklist.push_back({SI->getFalseValue(), CtxI});
       continue;
     }
 
@@ -460,20 +464,21 @@ static bool genericValueTraversal(
       assert(LivenessAA &&
              "Expected liveness in the presence of instructions!");
       for (unsigned u = 0, e = PHI->getNumIncomingValues(); u < e; u++) {
-        const BasicBlock *IncomingBB = PHI->getIncomingBlock(u);
+        BasicBlock *IncomingBB = PHI->getIncomingBlock(u);
         if (A.isAssumedDead(*IncomingBB->getTerminator(), &QueryingAA,
                             LivenessAA,
                             /* CheckBBLivenessOnly */ true)) {
           AnyDead = true;
           continue;
         }
-        Worklist.push_back(PHI->getIncomingValue(u));
+        Worklist.push_back(
+            {PHI->getIncomingValue(u), IncomingBB->getTerminator()});
       }
       continue;
     }
 
     // Once a leaf is reached we inform the user through the callback.
-    if (!VisitValueCB(*V, State, Iteration > 1))
+    if (!VisitValueCB(*V, CtxI, State, Iteration > 1))
       return false;
   } while (!Worklist.empty());
 
@@ -710,7 +715,7 @@ void IRPosition::getAttrs(ArrayRef<Attribute::AttrKind> AKs,
   }
   if (A)
     for (Attribute::AttrKind AK : AKs)
-     getAttrsFromAssumes(AK, Attrs, *A);
+      getAttrsFromAssumes(AK, Attrs, *A);
 }
 
 bool IRPosition::getAttrsFromIRAttr(Attribute::AttrKind AK,
@@ -1466,7 +1471,8 @@ ChangeStatus AAReturnedValuesImpl::updateImpl(Attributor &A) {
   };
 
   // Callback for a leaf value returned by the associated function.
-  auto VisitValueCB = [](Value &Val, RVState &RVS, bool) -> bool {
+  auto VisitValueCB = [](Value &Val, const Instruction *, RVState &RVS,
+                         bool) -> bool {
     auto Size = RVS.RetValsMap[&Val].size();
     RVS.RetValsMap[&Val].insert(RVS.RetInsts.begin(), RVS.RetInsts.end());
     bool Inserted = RVS.RetValsMap[&Val].size() != Size;
@@ -1480,10 +1486,11 @@ ChangeStatus AAReturnedValuesImpl::updateImpl(Attributor &A) {
   };
 
   // Helper method to invoke the generic value traversal.
-  auto VisitReturnedValue = [&](Value &RV, RVState &RVS) {
+  auto VisitReturnedValue = [&](Value &RV, RVState &RVS,
+                                const Instruction *CtxI) {
     IRPosition RetValPos = IRPosition::value(RV);
-    return genericValueTraversal<AAReturnedValues, RVState>(A, RetValPos, *this,
-                                                            RVS, VisitValueCB);
+    return genericValueTraversal<AAReturnedValues, RVState>(
+        A, RetValPos, *this, RVS, VisitValueCB, CtxI);
   };
 
   // Callback for all "return intructions" live in the associated function.
@@ -1491,7 +1498,7 @@ ChangeStatus AAReturnedValuesImpl::updateImpl(Attributor &A) {
     ReturnInst &Ret = cast<ReturnInst>(I);
     RVState RVS({ReturnedValues, Changed, {}});
     RVS.RetInsts.insert(&Ret);
-    return VisitReturnedValue(*Ret.getReturnValue(), RVS);
+    return VisitReturnedValue(*Ret.getReturnValue(), RVS, &I);
   };
 
   // Start by discovering returned values from all live returned instructions in
@@ -1576,7 +1583,7 @@ ChangeStatus AAReturnedValuesImpl::updateImpl(Attributor &A) {
         // again.
         bool Unused = false;
         RVState RVS({NewRVsMap, Unused, RetValAAIt.second});
-        VisitReturnedValue(*CB->getArgOperand(Arg->getArgNo()), RVS);
+        VisitReturnedValue(*CB->getArgOperand(Arg->getArgNo()), RVS, CB);
         continue;
       } else if (isa<CallBase>(RetVal)) {
         // Call sites are resolved by the callee attribute over time, no need to
@@ -2148,11 +2155,11 @@ struct AANonNullFloating
       AC = InfoCache.getAnalysisResultForFunction<AssumptionAnalysis>(*Fn);
     }
 
-    auto VisitValueCB = [&](Value &V, AANonNull::StateType &T,
-                            bool Stripped) -> bool {
+    auto VisitValueCB = [&](Value &V, const Instruction *CtxI,
+                            AANonNull::StateType &T, bool Stripped) -> bool {
       const auto &AA = A.getAAFor<AANonNull>(*this, IRPosition::value(V));
       if (!Stripped && this == &AA) {
-        if (!isKnownNonZero(&V, DL, 0, AC, getCtxI(), DT))
+        if (!isKnownNonZero(&V, DL, 0, AC, CtxI, DT))
           T.indicatePessimisticFixpoint();
       } else {
         // Use abstract attribute information.
@@ -2164,8 +2171,8 @@ struct AANonNullFloating
     };
 
     StateType T;
-    if (!genericValueTraversal<AANonNull, StateType>(A, getIRPosition(), *this,
-                                                     T, VisitValueCB))
+    if (!genericValueTraversal<AANonNull, StateType>(
+            A, getIRPosition(), *this, T, VisitValueCB, getCtxI()))
       return indicatePessimisticFixpoint();
 
     return clampStateAndIndicateChange(getState(), T);
@@ -3776,7 +3783,8 @@ struct AADereferenceableFloating
 
     const DataLayout &DL = A.getDataLayout();
 
-    auto VisitValueCB = [&](Value &V, DerefState &T, bool Stripped) -> bool {
+    auto VisitValueCB = [&](Value &V, const Instruction *, DerefState &T,
+                            bool Stripped) -> bool {
       unsigned IdxWidth =
           DL.getIndexSizeInBits(V.getType()->getPointerAddressSpace());
       APInt Offset(IdxWidth, 0);
@@ -3831,7 +3839,7 @@ struct AADereferenceableFloating
 
     DerefState T;
     if (!genericValueTraversal<AADereferenceable, DerefState>(
-            A, getIRPosition(), *this, T, VisitValueCB))
+            A, getIRPosition(), *this, T, VisitValueCB, getCtxI()))
       return indicatePessimisticFixpoint();
 
     return Change | clampStateAndIndicateChange(getState(), T);
@@ -4073,8 +4081,8 @@ struct AAAlignFloating : AAFromMustBeExecutedContext<AAAlign, AAAlignImpl> {
 
     const DataLayout &DL = A.getDataLayout();
 
-    auto VisitValueCB = [&](Value &V, AAAlign::StateType &T,
-                            bool Stripped) -> bool {
+    auto VisitValueCB = [&](Value &V, const Instruction *,
+                            AAAlign::StateType &T, bool Stripped) -> bool {
       const auto &AA = A.getAAFor<AAAlign>(*this, IRPosition::value(V));
       if (!Stripped && this == &AA) {
         // Use only IR information if we did not strip anything.
@@ -4092,7 +4100,7 @@ struct AAAlignFloating : AAFromMustBeExecutedContext<AAAlign, AAAlignImpl> {
 
     StateType T;
     if (!genericValueTraversal<AAAlign, StateType>(A, getIRPosition(), *this, T,
-                                                   VisitValueCB))
+                                                   VisitValueCB, getCtxI()))
       return indicatePessimisticFixpoint();
 
     // TODO: If we know we visited all incoming values, thus no are assumed
@@ -4958,7 +4966,8 @@ struct AAValueSimplifyFloating : AAValueSimplifyImpl {
   ChangeStatus updateImpl(Attributor &A) override {
     bool HasValueBefore = SimplifiedAssociatedValue.hasValue();
 
-    auto VisitValueCB = [&](Value &V, bool &, bool Stripped) -> bool {
+    auto VisitValueCB = [&](Value &V, const Instruction *CtxI, bool &,
+                            bool Stripped) -> bool {
       auto &AA = A.getAAFor<AAValueSimplify>(*this, IRPosition::value(V));
       if (!Stripped && this == &AA) {
         // TODO: Look the instruction and check recursively.
@@ -4971,8 +4980,8 @@ struct AAValueSimplifyFloating : AAValueSimplifyImpl {
     };
 
     bool Dummy = false;
-    if (!genericValueTraversal<AAValueSimplify, bool>(A, getIRPosition(), *this,
-                                                      Dummy, VisitValueCB))
+    if (!genericValueTraversal<AAValueSimplify, bool>(
+            A, getIRPosition(), *this, Dummy, VisitValueCB, getCtxI()))
       if (!askSimplifiedValueForAAValueConstantRange(A))
         return indicatePessimisticFixpoint();
 
@@ -6605,7 +6614,8 @@ void AAMemoryLocationImpl::categorizePtrValue(
     return V;
   };
 
-  auto VisitValueCB = [&](Value &V, AAMemoryLocation::StateType &T,
+  auto VisitValueCB = [&](Value &V, const Instruction *,
+                          AAMemoryLocation::StateType &T,
                           bool Stripped) -> bool {
     assert(!isa<GEPOperator>(V) && "GEPs should have been stripped.");
     if (isa<UndefValue>(V))
@@ -6652,7 +6662,7 @@ void AAMemoryLocationImpl::categorizePtrValue(
   };
 
   if (!genericValueTraversal<AAMemoryLocation, AAMemoryLocation::StateType>(
-          A, IRPosition::value(Ptr), *this, State, VisitValueCB,
+          A, IRPosition::value(Ptr), *this, State, VisitValueCB, getCtxI(),
           /* MaxValues */ 32, StripGEPCB)) {
     LLVM_DEBUG(
         dbgs() << "[AAMemoryLocation] Pointer locations not categorized\n");
@@ -7132,7 +7142,7 @@ struct AAValueConstantRangeFloating : AAValueConstantRangeImpl {
 
   bool calculateBinaryOperator(
       Attributor &A, BinaryOperator *BinOp, IntegerRangeState &T,
-      Instruction *CtxI,
+      const Instruction *CtxI,
       SmallVectorImpl<const AAValueConstantRange *> &QuerriedAAs) {
     Value *LHS = BinOp->getOperand(0);
     Value *RHS = BinOp->getOperand(1);
@@ -7160,7 +7170,8 @@ struct AAValueConstantRangeFloating : AAValueConstantRangeImpl {
   }
 
   bool calculateCastInst(
-      Attributor &A, CastInst *CastI, IntegerRangeState &T, Instruction *CtxI,
+      Attributor &A, CastInst *CastI, IntegerRangeState &T,
+      const Instruction *CtxI,
       SmallVectorImpl<const AAValueConstantRange *> &QuerriedAAs) {
     assert(CastI->getNumOperands() == 1 && "Expected cast to be unary!");
     // TODO: Allow non integers as well.
@@ -7178,7 +7189,7 @@ struct AAValueConstantRangeFloating : AAValueConstantRangeImpl {
 
   bool
   calculateCmpInst(Attributor &A, CmpInst *CmpI, IntegerRangeState &T,
-                   Instruction *CtxI,
+                   const Instruction *CtxI,
                    SmallVectorImpl<const AAValueConstantRange *> &QuerriedAAs) {
     Value *LHS = CmpI->getOperand(0);
     Value *RHS = CmpI->getOperand(1);
@@ -7233,9 +7244,8 @@ struct AAValueConstantRangeFloating : AAValueConstantRangeImpl {
 
   /// See AbstractAttribute::updateImpl(...).
   ChangeStatus updateImpl(Attributor &A) override {
-    Instruction *CtxI = getCtxI();
-    auto VisitValueCB = [&](Value &V, IntegerRangeState &T,
-                            bool Stripped) -> bool {
+    auto VisitValueCB = [&](Value &V, const Instruction *CtxI,
+                            IntegerRangeState &T, bool Stripped) -> bool {
       Instruction *I = dyn_cast<Instruction>(&V);
       if (!I || isa<CallBase>(I)) {
 
@@ -7285,7 +7295,7 @@ struct AAValueConstantRangeFloating : AAValueConstantRangeImpl {
     IntegerRangeState T(getBitWidth());
 
     if (!genericValueTraversal<AAValueConstantRange, IntegerRangeState>(
-            A, getIRPosition(), *this, T, VisitValueCB))
+            A, getIRPosition(), *this, T, VisitValueCB, getCtxI()))
       return indicatePessimisticFixpoint();
 
     return clampStateAndIndicateChange(getState(), T);

diff  --git a/llvm/test/Transforms/Attributor/range.ll b/llvm/test/Transforms/Attributor/range.ll
index a516053bfa85..efb2df556af6 100644
--- a/llvm/test/Transforms/Attributor/range.ll
+++ b/llvm/test/Transforms/Attributor/range.ll
@@ -1212,6 +1212,76 @@ define i1 @callee_range_2(i1 %c1, i1 %c2) {
 }
 
 
+define i32 @ret100() {
+; CHECK-LABEL: define {{[^@]+}}@ret100()
+; CHECK-NEXT:    ret i32 100
+;
+  ret i32 100
+}
+
+define i1 @ctx_adjustment(i32 %V) {
+; OLD_PM-LABEL: define {{[^@]+}}@ctx_adjustment
+; OLD_PM-SAME: (i32 [[V:%.*]])
+; OLD_PM-NEXT:    [[C1:%.*]] = icmp sge i32 [[V]], 100
+; OLD_PM-NEXT:    br i1 [[C1]], label [[IF_TRUE:%.*]], label [[IF_FALSE:%.*]]
+; OLD_PM:       if.true:
+; OLD_PM-NEXT:    br label [[END:%.*]]
+; OLD_PM:       if.false:
+; OLD_PM-NEXT:    br label [[END]]
+; OLD_PM:       end:
+; OLD_PM-NEXT:    [[PHI:%.*]] = phi i32 [ [[V]], [[IF_TRUE]] ], [ 100, [[IF_FALSE]] ]
+; OLD_PM-NEXT:    [[C2:%.*]] = icmp sge i32 [[PHI]], 100
+; OLD_PM-NEXT:    ret i1 [[C2]]
+;
+; NEW_PM-LABEL: define {{[^@]+}}@ctx_adjustment
+; NEW_PM-SAME: (i32 [[V:%.*]])
+; NEW_PM-NEXT:    [[C1:%.*]] = icmp sge i32 [[V]], 100
+; NEW_PM-NEXT:    br i1 [[C1]], label [[IF_TRUE:%.*]], label [[IF_FALSE:%.*]]
+; NEW_PM:       if.true:
+; NEW_PM-NEXT:    br label [[END:%.*]]
+; NEW_PM:       if.false:
+; NEW_PM-NEXT:    br label [[END]]
+; NEW_PM:       end:
+; NEW_PM-NEXT:    ret i1 true
+;
+; CGSCC_OLD_PM-LABEL: define {{[^@]+}}@ctx_adjustment
+; CGSCC_OLD_PM-SAME: (i32 [[V:%.*]])
+; CGSCC_OLD_PM-NEXT:    [[C1:%.*]] = icmp sge i32 [[V]], 100
+; CGSCC_OLD_PM-NEXT:    br i1 [[C1]], label [[IF_TRUE:%.*]], label [[IF_FALSE:%.*]]
+; CGSCC_OLD_PM:       if.true:
+; CGSCC_OLD_PM-NEXT:    br label [[END:%.*]]
+; CGSCC_OLD_PM:       if.false:
+; CGSCC_OLD_PM-NEXT:    br label [[END]]
+; CGSCC_OLD_PM:       end:
+; CGSCC_OLD_PM-NEXT:    [[PHI:%.*]] = phi i32 [ [[V]], [[IF_TRUE]] ], [ 100, [[IF_FALSE]] ]
+; CGSCC_OLD_PM-NEXT:    [[C2:%.*]] = icmp sge i32 [[PHI]], 100
+; CGSCC_OLD_PM-NEXT:    ret i1 [[C2]]
+;
+; CGSCC_NEW_PM-LABEL: define {{[^@]+}}@ctx_adjustment
+; CGSCC_NEW_PM-SAME: (i32 [[V:%.*]])
+; CGSCC_NEW_PM-NEXT:    [[C1:%.*]] = icmp sge i32 [[V]], 100
+; CGSCC_NEW_PM-NEXT:    br i1 [[C1]], label [[IF_TRUE:%.*]], label [[IF_FALSE:%.*]]
+; CGSCC_NEW_PM:       if.true:
+; CGSCC_NEW_PM-NEXT:    br label [[END:%.*]]
+; CGSCC_NEW_PM:       if.false:
+; CGSCC_NEW_PM-NEXT:    br label [[END]]
+; CGSCC_NEW_PM:       end:
+; CGSCC_NEW_PM-NEXT:    ret i1 true
+;
+  %c1 = icmp sge i32 %V, 100
+  br i1 %c1, label %if.true, label %if.false
+if.true:
+  br label %end
+if.false:
+  %call = call i32 @ret100()
+  br label %end
+end:
+  %phi = phi i32 [ %V, %if.true ], [ %call, %if.false ]
+  %c2 = icmp sge i32 %phi, 100
+  ret i1 %c2
+}
+
+
 !0 = !{i32 0, i32 10}
 !1 = !{i32 10, i32 100}
 ; CHECK: !0 = !{i32 0, i32 10}


        


More information about the llvm-commits mailing list