[llvm] d75c9e6 - [Attributor] Attributor call site specific AAValueConstantRange

via llvm-commits llvm-commits at lists.llvm.org
Wed Mar 10 14:20:21 PST 2021


Author: kuterd
Date: 2021-03-11T01:19:44+03:00
New Revision: d75c9e61a5a3a50203f68f7e329c9c670d267a98

URL: https://github.com/llvm/llvm-project/commit/d75c9e61a5a3a50203f68f7e329c9c670d267a98
DIFF: https://github.com/llvm/llvm-project/commit/d75c9e61a5a3a50203f68f7e329c9c670d267a98.diff

LOG: [Attributor] Attributor call site specific AAValueConstantRange

This patch makes use of the context bridges introduced in D83299 to make
AAValueConstantRange call site specific.

Reviewed By: jdoerfert

Differential Revision: https://reviews.llvm.org/D83744

Added: 
    llvm/test/Transforms/Attributor/cb_range_disabled.ll
    llvm/test/Transforms/Attributor/cb_range_enabled.ll

Modified: 
    llvm/include/llvm/Transforms/IPO/Attributor.h
    llvm/lib/Transforms/IPO/Attributor.cpp
    llvm/lib/Transforms/IPO/AttributorAttributes.cpp
    llvm/unittests/Transforms/IPO/AttributorTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Transforms/IPO/Attributor.h b/llvm/include/llvm/Transforms/IPO/Attributor.h
index fb831a6cb926..86f08ac563f2 100644
--- a/llvm/include/llvm/Transforms/IPO/Attributor.h
+++ b/llvm/include/llvm/Transforms/IPO/Attributor.h
@@ -229,6 +229,9 @@ struct AADepGraph {
 /// are floating values that do not have a corresponding attribute list
 /// position.
 struct IRPosition {
+  // NOTE: In the future this definition can be changed to support recursive
+  // functions.
+  using CallBaseContext = CallBase;
 
   /// The positions we distinguish in the IR.
   enum Kind : char {
@@ -249,27 +252,34 @@ struct IRPosition {
   IRPosition() : Enc(nullptr, ENC_VALUE) { verify(); }
 
   /// Create a position describing the value of \p V.
-  static const IRPosition value(const Value &V) {
+  static const IRPosition value(const Value &V,
+                                const CallBaseContext *CBContext = nullptr) {
     if (auto *Arg = dyn_cast<Argument>(&V))
-      return IRPosition::argument(*Arg);
+      return IRPosition::argument(*Arg, CBContext);
     if (auto *CB = dyn_cast<CallBase>(&V))
       return IRPosition::callsite_returned(*CB);
-    return IRPosition(const_cast<Value &>(V), IRP_FLOAT);
+    return IRPosition(const_cast<Value &>(V), IRP_FLOAT, CBContext);
   }
 
   /// Create a position describing the function scope of \p F.
-  static const IRPosition function(const Function &F) {
-    return IRPosition(const_cast<Function &>(F), IRP_FUNCTION);
+  /// \p CBContext is used for call base specific analysis.
+  static const IRPosition function(const Function &F,
+                                   const CallBaseContext *CBContext = nullptr) {
+    return IRPosition(const_cast<Function &>(F), IRP_FUNCTION, CBContext);
   }
 
   /// Create a position describing the returned value of \p F.
-  static const IRPosition returned(const Function &F) {
-    return IRPosition(const_cast<Function &>(F), IRP_RETURNED);
+  /// \p CBContext is used for call base specific analysis.
+  static const IRPosition returned(const Function &F,
+                                   const CallBaseContext *CBContext = nullptr) {
+    return IRPosition(const_cast<Function &>(F), IRP_RETURNED, CBContext);
   }
 
   /// Create a position describing the argument \p Arg.
-  static const IRPosition argument(const Argument &Arg) {
-    return IRPosition(const_cast<Argument &>(Arg), IRP_ARGUMENT);
+  /// \p CBContext is used for call base specific analysis.
+  static const IRPosition argument(const Argument &Arg,
+                                   const CallBaseContext *CBContext = nullptr) {
+    return IRPosition(const_cast<Argument &>(Arg), IRP_ARGUMENT, CBContext);
   }
 
   /// Create a position describing the function scope of \p CB.
@@ -305,16 +315,20 @@ struct IRPosition {
   /// If \p IRP is a call site (see isAnyCallSitePosition()) then the result
   /// will be a call site position, otherwise the function position of the
   /// associated function.
-  static const IRPosition function_scope(const IRPosition &IRP) {
+  static const IRPosition
+  function_scope(const IRPosition &IRP,
+                 const CallBaseContext *CBContext = nullptr) {
     if (IRP.isAnyCallSitePosition()) {
       return IRPosition::callsite_function(
           cast<CallBase>(IRP.getAnchorValue()));
     }
     assert(IRP.getAssociatedFunction());
-    return IRPosition::function(*IRP.getAssociatedFunction());
+    return IRPosition::function(*IRP.getAssociatedFunction(), CBContext);
   }
 
-  bool operator==(const IRPosition &RHS) const { return Enc == RHS.Enc; }
+  bool operator==(const IRPosition &RHS) const {
+    return Enc == RHS.Enc && RHS.CBContext == CBContext;
+  }
   bool operator!=(const IRPosition &RHS) const { return !(*this == RHS); }
 
   /// Return the value this abstract attribute is anchored with.
@@ -535,6 +549,19 @@ struct IRPosition {
     }
   }
 
+  /// Return the same position without the call base context.
+  IRPosition stripCallBaseContext() const {
+    IRPosition Result = *this;
+    Result.CBContext = nullptr;
+    return Result;
+  }
+
+  /// Get the call base context from the position.
+  const CallBaseContext *getCallBaseContext() const { return CBContext; }
+
+  /// Check if the position has any call base context.
+  bool hasCallBaseContext() const { return CBContext != nullptr; }
+
   /// Special DenseMap key values.
   ///
   ///{
@@ -547,10 +574,15 @@ struct IRPosition {
 
 private:
   /// Private constructor for special values only!
-  explicit IRPosition(void *Ptr) { Enc.setFromOpaqueValue(Ptr); }
+  explicit IRPosition(void *Ptr, const CallBaseContext *CBContext = nullptr)
+      : CBContext(CBContext) {
+    Enc.setFromOpaqueValue(Ptr);
+  }
 
   /// IRPosition anchored at \p AnchorVal with kind/argument numbet \p PK.
-  explicit IRPosition(Value &AnchorVal, Kind PK) {
+  explicit IRPosition(Value &AnchorVal, Kind PK,
+                      const CallBaseContext *CBContext = nullptr)
+      : CBContext(CBContext) {
     switch (PK) {
     case IRPosition::IRP_INVALID:
       llvm_unreachable("Cannot create invalid IRP with an anchor value!");
@@ -672,16 +704,27 @@ struct IRPosition {
   PointerIntPair<void *, NumEncodingBits, char> Enc;
   ///}
 
+  /// Call base context. Used for callsite specific analysis.
+  const CallBaseContext *CBContext = nullptr;
+
   /// Return the encoding bits.
   char getEncodingBits() const { return Enc.getInt(); }
 };
 
 /// Helper that allows IRPosition as a key in a DenseMap.
-template <> struct DenseMapInfo<IRPosition> : DenseMapInfo<void *> {
+template <> struct DenseMapInfo<IRPosition> {
   static inline IRPosition getEmptyKey() { return IRPosition::EmptyKey; }
   static inline IRPosition getTombstoneKey() {
     return IRPosition::TombstoneKey;
   }
+  static unsigned getHashValue(const IRPosition &IRP) {
+    return (DenseMapInfo<void *>::getHashValue(IRP) << 4) ^
+           (DenseMapInfo<Value *>::getHashValue(IRP.getCallBaseContext()));
+  }
+
+  static bool isEqual(const IRPosition &a, const IRPosition &b) {
+    return a == b;
+  }
 };
 
 /// A visitor class for IR positions.
@@ -1080,8 +1123,22 @@ struct Attributor {
   /// NOTE: ForceUpdate is ignored in any stage other than the update stage.
   template <typename AAType>
   const AAType &
-  getOrCreateAAFor(const IRPosition &IRP, const AbstractAttribute *QueryingAA,
+  getOrCreateAAFor(IRPosition IRP, const AbstractAttribute *QueryingAA,
                    DepClassTy DepClass, bool ForceUpdate = false) {
+#ifdef EXPENSIVE_CHECKS
+    // Don't allow call base context to leak across function boundaries.
+    if (const auto *CBContext = IRP.getCallBaseContext()) {
+      assert(
+          ((CBContext->getCalledFunction() == IRP.getAnchorScope() ||
+            !QueryingAA ||
+            !QueryingAA->getIRPosition().isAnyCallSitePosition())) &&
+          "non callsite positions are not allowed to propagate CallBaseContext "
+          "across functions");
+    }
+#endif
+    if (!shouldPropagateCallBaseContext(IRP))
+      IRP = IRP.stripCallBaseContext();
+
     if (AAType *AAPtr = lookupAAFor<AAType>(IRP, QueryingAA, DepClass)) {
       if (ForceUpdate && Phase == AttributorPhase::UPDATE)
         updateAA(*AAPtr);
@@ -1600,6 +1657,9 @@ struct Attributor {
                             const AbstractAttribute *QueryingAA,
                             bool &AllCallSitesKnown);
 
+  /// Determine if CallBase context in \p IRP should be propagated.
+  bool shouldPropagateCallBaseContext(const IRPosition &IRP);
+
   /// Apply all requested function signature rewrites
   /// (\see registerFunctionSignatureRewrite) and return Changed if the module
   /// was altered.

diff  --git a/llvm/lib/Transforms/IPO/Attributor.cpp b/llvm/lib/Transforms/IPO/Attributor.cpp
index d2b4c7bc3536..eef780d29704 100644
--- a/llvm/lib/Transforms/IPO/Attributor.cpp
+++ b/llvm/lib/Transforms/IPO/Attributor.cpp
@@ -142,6 +142,11 @@ static cl::opt<bool> PrintDependencies("attributor-print-dep", cl::Hidden,
                                        cl::desc("Print attribute dependencies"),
                                        cl::init(false));
 
+static cl::opt<bool> EnableCallSiteSpecific(
+    "attributor-enable-call-site-specific-deduction", cl::Hidden,
+    cl::desc("Allow the Attributor to do call site specific analysis"),
+    cl::init(false));
+
 /// Logic operators for the change status enum class.
 ///
 ///{
@@ -476,6 +481,8 @@ void IRPosition::verify() {
 #ifdef EXPENSIVE_CHECKS
   switch (getPositionKind()) {
   case IRP_INVALID:
+    assert((CBContext == nullptr) &&
+           "Invalid position must not have CallBaseContext!");
     assert(!Enc.getOpaqueValue() &&
            "Expected a nullptr for an invalid position!");
     return;
@@ -491,12 +498,16 @@ void IRPosition::verify() {
            "Associated value mismatch!");
     return;
   case IRP_CALL_SITE_RETURNED:
+    assert((CBContext == nullptr) &&
+           "'call site returned' position must not have CallBaseContext!");
     assert((isa<CallBase>(getAsValuePtr())) &&
            "Expected call base for 'call site returned' position!");
     assert(getAsValuePtr() == &getAssociatedValue() &&
            "Associated value mismatch!");
     return;
   case IRP_CALL_SITE:
+    assert((CBContext == nullptr) &&
+           "'call site function' position must not have CallBaseContext!");
     assert((isa<CallBase>(getAsValuePtr())) &&
            "Expected call base for 'call site function' position!");
     assert(getAsValuePtr() == &getAssociatedValue() &&
@@ -515,6 +526,8 @@ void IRPosition::verify() {
            "Associated value mismatch!");
     return;
   case IRP_CALL_SITE_ARGUMENT: {
+    assert((CBContext == nullptr) &&
+           "'call site argument' position must not have CallBaseContext!");
     Use *U = getAsUsePtr();
     assert(U && "Expected use for a 'call site argument' position!");
     assert(isa<CallBase>(U->getUser()) &&
@@ -849,6 +862,13 @@ bool Attributor::checkForAllCallSites(function_ref<bool(AbstractCallSite)> Pred,
   return true;
 }
 
+bool Attributor::shouldPropagateCallBaseContext(const IRPosition &IRP) {
+  // TODO: Maintain a cache of Values that are
+  // on the pathway from an Argument to an Instruction that would affect the
+  // liveness/return state etc.
+  return EnableCallSiteSpecific;
+}
+
 bool Attributor::checkForAllReturnedValuesAndReturnInsts(
     function_ref<bool(Value &, const SmallSetVector<ReturnInst *, 4> &)> Pred,
     const AbstractAttribute &QueryingAA) {
@@ -1125,6 +1145,9 @@ ChangeStatus Attributor::manifestAttributes() {
     if (!State.isAtFixpoint())
       State.indicateOptimisticFixpoint();
 
+    // We must not manifest Attributes that use Callbase info.
+    if (AA->hasCallBaseContext())
+      continue;
     // If the state is invalid, we do not try to manifest it.
     if (!State.isValidState())
       continue;
@@ -2221,9 +2244,12 @@ raw_ostream &llvm::operator<<(raw_ostream &OS, IRPosition::Kind AP) {
 
 raw_ostream &llvm::operator<<(raw_ostream &OS, const IRPosition &Pos) {
   const Value &AV = Pos.getAssociatedValue();
-  return OS << "{" << Pos.getPositionKind() << ":" << AV.getName() << " ["
-            << Pos.getAnchorValue().getName() << "@" << Pos.getCallSiteArgNo()
-            << "]}";
+  OS << "{" << Pos.getPositionKind() << ":" << AV.getName() << " ["
+     << Pos.getAnchorValue().getName() << "@" << Pos.getCallSiteArgNo() << "]";
+
+  if (Pos.hasCallBaseContext())
+    OS << "[cb_context:" << *Pos.getCallBaseContext() << "]";
+  return OS << "}";
 }
 
 raw_ostream &llvm::operator<<(raw_ostream &OS, const IntegerRangeState &S) {

diff  --git a/llvm/lib/Transforms/IPO/AttributorAttributes.cpp b/llvm/lib/Transforms/IPO/AttributorAttributes.cpp
index 1c0fbdb8cba0..3c1f16f799eb 100644
--- a/llvm/lib/Transforms/IPO/AttributorAttributes.cpp
+++ b/llvm/lib/Transforms/IPO/AttributorAttributes.cpp
@@ -431,8 +431,9 @@ ChangeStatus clampStateAndIndicateChange(StateType &S, const StateType &R) {
 /// Clamp the information known for all returned values of a function
 /// (identified by \p QueryingAA) into \p S.
 template <typename AAType, typename StateType = typename AAType::StateType>
-static void clampReturnedValueStates(Attributor &A, const AAType &QueryingAA,
-                                     StateType &S) {
+static void clampReturnedValueStates(
+    Attributor &A, const AAType &QueryingAA, StateType &S,
+    const IRPosition::CallBaseContext *CBContext = nullptr) {
   LLVM_DEBUG(dbgs() << "[Attributor] Clamp return value states for "
                     << QueryingAA << " into " << S << "\n");
 
@@ -449,7 +450,7 @@ static void clampReturnedValueStates(Attributor &A, const AAType &QueryingAA,
 
   // Callback for each possibly returned value.
   auto CheckReturnValue = [&](Value &RV) -> bool {
-    const IRPosition &RVPos = IRPosition::value(RV);
+    const IRPosition &RVPos = IRPosition::value(RV, CBContext);
     const AAType &AA =
         A.getAAFor<AAType>(QueryingAA, RVPos, DepClassTy::REQUIRED);
     LLVM_DEBUG(dbgs() << "[Attributor] RV: " << RV << " AA: " << AA.getAsStr()
@@ -472,7 +473,8 @@ static void clampReturnedValueStates(Attributor &A, const AAType &QueryingAA,
 
 /// Helper class for generic deduction: return value -> returned position.
 template <typename AAType, typename BaseType,
-          typename StateType = typename BaseType::StateType>
+          typename StateType = typename BaseType::StateType,
+          bool PropagateCallBaseContext = false>
 struct AAReturnedFromReturnedValues : public BaseType {
   AAReturnedFromReturnedValues(const IRPosition &IRP, Attributor &A)
       : BaseType(IRP, A) {}
@@ -480,7 +482,9 @@ struct AAReturnedFromReturnedValues : public BaseType {
   /// See AbstractAttribute::updateImpl(...).
   ChangeStatus updateImpl(Attributor &A) override {
     StateType S(StateType::getBestState(this->getState()));
-    clampReturnedValueStates<AAType, StateType>(A, *this, S);
+    clampReturnedValueStates<AAType, StateType>(
+        A, *this, S,
+        PropagateCallBaseContext ? this->getCallBaseContext() : nullptr);
     // TODO: If we know we visited all returned values, thus no are assumed
     // dead, we can take the known information from the state T.
     return clampStateAndIndicateChange<StateType>(this->getState(), S);
@@ -535,17 +539,58 @@ static void clampCallSiteArgumentStates(Attributor &A, const AAType &QueryingAA,
     S ^= *T;
 }
 
-/// Helper class for generic deduction: call site argument -> argument position.
+/// This function is the bridge between argument position and the call base
+/// context.
 template <typename AAType, typename BaseType,
           typename StateType = typename AAType::StateType>
+bool getArgumentStateFromCallBaseContext(Attributor &A,
+                                         BaseType &QueryingAttribute,
+                                         IRPosition &Pos, StateType &State) {
+  assert((Pos.getPositionKind() == IRPosition::IRP_ARGUMENT) &&
+         "Expected an 'argument' position !");
+  const CallBase *CBContext = Pos.getCallBaseContext();
+  if (!CBContext)
+    return false;
+
+  int ArgNo = Pos.getCallSiteArgNo();
+  assert(ArgNo >= 0 && "Invalid Arg No!");
+
+  const auto &AA = A.getAAFor<AAType>(
+      QueryingAttribute, IRPosition::callsite_argument(*CBContext, ArgNo),
+      DepClassTy::REQUIRED);
+  const StateType &CBArgumentState =
+      static_cast<const StateType &>(AA.getState());
+
+  LLVM_DEBUG(dbgs() << "[Attributor] Briding Call site context to argument"
+                    << "Position:" << Pos << "CB Arg state:" << CBArgumentState
+                    << "\n");
+
+  // NOTE: If we want to do call site grouping it should happen here.
+  State ^= CBArgumentState;
+  return true;
+}
+
+/// Helper class for generic deduction: call site argument -> argument position.
+template <typename AAType, typename BaseType,
+          typename StateType = typename AAType::StateType,
+          bool BridgeCallBaseContext = false>
 struct AAArgumentFromCallSiteArguments : public BaseType {
   AAArgumentFromCallSiteArguments(const IRPosition &IRP, Attributor &A)
       : BaseType(IRP, A) {}
 
   /// See AbstractAttribute::updateImpl(...).
   ChangeStatus updateImpl(Attributor &A) override {
-    StateType S(StateType::getBestState(this->getState()));
+    StateType S = StateType::getBestState(this->getState());
+
+    if (BridgeCallBaseContext) {
+      bool Success =
+          getArgumentStateFromCallBaseContext<AAType, BaseType, StateType>(
+              A, *this, this->getIRPosition(), S);
+      if (Success)
+        return clampStateAndIndicateChange<StateType>(this->getState(), S);
+    }
     clampCallSiteArgumentStates<AAType, StateType>(A, *this, S);
+
     // TODO: If we know we visited all incoming values, thus no are assumed
     // dead, we can take the known information from the state T.
     return clampStateAndIndicateChange<StateType>(this->getState(), S);
@@ -554,7 +599,8 @@ struct AAArgumentFromCallSiteArguments : public BaseType {
 
 /// Helper class for generic replication: function returned -> cs returned.
 template <typename AAType, typename BaseType,
-          typename StateType = typename BaseType::StateType>
+          typename StateType = typename BaseType::StateType,
+          bool IntroduceCallBaseContext = false>
 struct AACallSiteReturnedFromReturned : public BaseType {
   AACallSiteReturnedFromReturned(const IRPosition &IRP, Attributor &A)
       : BaseType(IRP, A) {}
@@ -572,7 +618,13 @@ struct AACallSiteReturnedFromReturned : public BaseType {
     if (!AssociatedFunction)
       return S.indicatePessimisticFixpoint();
 
-    IRPosition FnPos = IRPosition::returned(*AssociatedFunction);
+    CallBase &CBContext = static_cast<CallBase &>(this->getAnchorValue());
+    if (IntroduceCallBaseContext)
+      LLVM_DEBUG(dbgs() << "[Attributor] Introducing call base context:"
+                        << CBContext << "\n");
+
+    IRPosition FnPos = IRPosition::returned(
+        *AssociatedFunction, IntroduceCallBaseContext ? &CBContext : nullptr);
     const AAType &AA = A.getAAFor<AAType>(*this, FnPos, DepClassTy::REQUIRED);
     return clampStateAndIndicateChange(S, AA.getState());
   }
@@ -7126,9 +7178,11 @@ struct AAValueConstantRangeImpl : AAValueConstantRange {
 
 struct AAValueConstantRangeArgument final
     : AAArgumentFromCallSiteArguments<
-          AAValueConstantRange, AAValueConstantRangeImpl, IntegerRangeState> {
+          AAValueConstantRange, AAValueConstantRangeImpl, IntegerRangeState,
+          true /* BridgeCallBaseContext */> {
   using Base = AAArgumentFromCallSiteArguments<
-      AAValueConstantRange, AAValueConstantRangeImpl, IntegerRangeState>;
+      AAValueConstantRange, AAValueConstantRangeImpl, IntegerRangeState,
+      true /* BridgeCallBaseContext */>;
   AAValueConstantRangeArgument(const IRPosition &IRP, Attributor &A)
       : Base(IRP, A) {}
 
@@ -7149,9 +7203,14 @@ struct AAValueConstantRangeArgument final
 
 struct AAValueConstantRangeReturned
     : AAReturnedFromReturnedValues<AAValueConstantRange,
-                                   AAValueConstantRangeImpl> {
-  using Base = AAReturnedFromReturnedValues<AAValueConstantRange,
-                                            AAValueConstantRangeImpl>;
+                                   AAValueConstantRangeImpl,
+                                   AAValueConstantRangeImpl::StateType,
+                                   /* PropagateCallBaseContext */ true> {
+  using Base =
+      AAReturnedFromReturnedValues<AAValueConstantRange,
+                                   AAValueConstantRangeImpl,
+                                   AAValueConstantRangeImpl::StateType,
+                                   /* PropagateCallBaseContext */ true>;
   AAValueConstantRangeReturned(const IRPosition &IRP, Attributor &A)
       : Base(IRP, A) {}
 
@@ -7221,12 +7280,14 @@ struct AAValueConstantRangeFloating : AAValueConstantRangeImpl {
       return false;
 
     auto &LHSAA = A.getAAFor<AAValueConstantRange>(
-        *this, IRPosition::value(*LHS), DepClassTy::REQUIRED);
+        *this, IRPosition::value(*LHS, getCallBaseContext()),
+        DepClassTy::REQUIRED);
     QuerriedAAs.push_back(&LHSAA);
     auto LHSAARange = LHSAA.getAssumedConstantRange(A, CtxI);
 
     auto &RHSAA = A.getAAFor<AAValueConstantRange>(
-        *this, IRPosition::value(*RHS), DepClassTy::REQUIRED);
+        *this, IRPosition::value(*RHS, getCallBaseContext()),
+        DepClassTy::REQUIRED);
     QuerriedAAs.push_back(&RHSAA);
     auto RHSAARange = RHSAA.getAssumedConstantRange(A, CtxI);
 
@@ -7249,8 +7310,9 @@ struct AAValueConstantRangeFloating : AAValueConstantRangeImpl {
     if (!OpV.getType()->isIntegerTy())
       return false;
 
-    auto &OpAA = A.getAAFor<AAValueConstantRange>(*this, IRPosition::value(OpV),
-                                                  DepClassTy::REQUIRED);
+    auto &OpAA = A.getAAFor<AAValueConstantRange>(
+        *this, IRPosition::value(OpV, getCallBaseContext()),
+        DepClassTy::REQUIRED);
     QuerriedAAs.push_back(&OpAA);
     T.unionAssumed(
         OpAA.getAssumed().castOp(CastI->getOpcode(), getState().getBitWidth()));
@@ -7268,12 +7330,12 @@ struct AAValueConstantRangeFloating : AAValueConstantRangeImpl {
       return false;
 
     auto &LHSAA = A.getAAFor<AAValueConstantRange>(
-        *this, IRPosition::value(*LHS), DepClassTy::REQUIRED);
+        *this, IRPosition::value(*LHS, getCallBaseContext()),
+        DepClassTy::REQUIRED);
     QuerriedAAs.push_back(&LHSAA);
     auto &RHSAA = A.getAAFor<AAValueConstantRange>(
-        *this, IRPosition::value(*RHS), DepClassTy::REQUIRED);
-    QuerriedAAs.push_back(&RHSAA);
-
+        *this, IRPosition::value(*RHS, getCallBaseContext()),
+        DepClassTy::REQUIRED);
     auto LHSAARange = LHSAA.getAssumedConstantRange(A, CtxI);
     auto RHSAARange = RHSAA.getAssumedConstantRange(A, CtxI);
 
@@ -7402,10 +7464,16 @@ struct AAValueConstantRangeCallSite : AAValueConstantRangeFunction {
 
 struct AAValueConstantRangeCallSiteReturned
     : AACallSiteReturnedFromReturned<AAValueConstantRange,
-                                     AAValueConstantRangeImpl> {
+                                     AAValueConstantRangeImpl,
+                                     AAValueConstantRangeImpl::StateType,
+                                     /* IntroduceCallBaseContext */ true> {
   AAValueConstantRangeCallSiteReturned(const IRPosition &IRP, Attributor &A)
       : AACallSiteReturnedFromReturned<AAValueConstantRange,
-                                       AAValueConstantRangeImpl>(IRP, A) {}
+                                       AAValueConstantRangeImpl,
+                                       AAValueConstantRangeImpl::StateType,
+                                       /* IntroduceCallBaseContext */ true>(IRP,
+                                                                            A) {
+  }
 
   /// See AbstractAttribute::initialize(...).
   void initialize(Attributor &A) override {

diff  --git a/llvm/test/Transforms/Attributor/cb_range_disabled.ll b/llvm/test/Transforms/Attributor/cb_range_disabled.ll
new file mode 100644
index 000000000000..ab01f6f87ad6
--- /dev/null
+++ b/llvm/test/Transforms/Attributor/cb_range_disabled.ll
@@ -0,0 +1,132 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --function-signature --scrub-attributes
+; call site specific analysis is disabled
+
+; RUN: opt -attributor -enable-new-pm=0 -attributor-manifest-internal  -attributor-max-iterations-verify -attributor-annotate-decl-cs -attributor-max-iterations=1 -S < %s | FileCheck %s --check-prefixes=CHECK,NOT_CGSCC_NPM,NOT_CGSCC_OPM,NOT_TUNIT_NPM,IS__TUNIT____,IS________OPM,IS__TUNIT_OPM
+
+; RUN: opt -aa-pipeline=basic-aa -passes=attributor -attributor-manifest-internal  -attributor-max-iterations-verify -attributor-annotate-decl-cs -attributor-max-iterations=1 -S < %s | FileCheck %s --check-prefixes=CHECK,NOT_CGSCC_OPM,NOT_CGSCC_NPM,NOT_TUNIT_OPM,IS__TUNIT____,IS________NPM,IS__TUNIT_NPM
+
+; RUN: opt -attributor-cgscc -enable-new-pm=0 -attributor-manifest-internal  -attributor-annotate-decl-cs -S < %s | FileCheck %s --check-prefixes=CHECK,NOT_TUNIT_NPM,NOT_TUNIT_OPM,NOT_CGSCC_NPM,IS__CGSCC____,IS________OPM,IS__CGSCC_OPM
+
+; RUN: opt -aa-pipeline=basic-aa -passes=attributor-cgscc -attributor-manifest-internal  -attributor-annotate-decl-cs -S < %s | FileCheck %s --check-prefixes=CHECK,NOT_TUNIT_NPM,NOT_TUNIT_OPM,NOT_CGSCC_OPM,IS__CGSCC____,IS________NPM,IS__CGSCC_NPM
+
+define i32 @test_range(i32 %unknown) {
+; CHECK-LABEL: define {{[^@]+}}@test_range
+; CHECK-SAME: (i32 [[UNKNOWN:%.*]]) [[ATTR0:#.*]] {
+; CHECK-NEXT:    [[TMP1:%.*]] = icmp sgt i32 [[UNKNOWN]], 100
+; CHECK-NEXT:    [[TMP2:%.*]] = select i1 [[TMP1]], i32 100, i32 0
+; CHECK-NEXT:    ret i32 [[TMP2]]
+;
+  %1 = icmp sgt i32 %unknown, 100
+  %2 = select i1 %1, i32 100, i32 0
+  ret i32 %2
+}
+
+define i32 @test1(i32 %unknown, i32 %b) {
+; IS__TUNIT____-LABEL: define {{[^@]+}}@test1
+; IS__TUNIT____-SAME: (i32 [[UNKNOWN:%.*]], i32 [[B:%.*]]) [[ATTR0:#.*]] {
+; IS__TUNIT____-NEXT:    [[TMP1:%.*]] = call i32 @test_range(i32 [[UNKNOWN]]) [[ATTR0]], [[RNG0:!range !.*]]
+; IS__TUNIT____-NEXT:    [[TMP2:%.*]] = sub nsw i32 [[TMP1]], [[B]]
+; IS__TUNIT____-NEXT:    ret i32 [[TMP2]]
+;
+; IS__CGSCC____-LABEL: define {{[^@]+}}@test1
+; IS__CGSCC____-SAME: (i32 [[UNKNOWN:%.*]], i32 [[B:%.*]]) [[ATTR0:#.*]] {
+; IS__CGSCC____-NEXT:    [[TMP1:%.*]] = call i32 @test_range(i32 [[UNKNOWN]]) [[ATTR1:#.*]], [[RNG0:!range !.*]]
+; IS__CGSCC____-NEXT:    [[TMP2:%.*]] = sub nsw i32 [[TMP1]], [[B]]
+; IS__CGSCC____-NEXT:    ret i32 [[TMP2]]
+;
+  %1 = call i32 @test_range(i32 %unknown)
+  %2 = sub nsw i32 %1, %b
+  ret i32 %2
+}
+
+define i32 @test2(i32 %unknown, i32 %b) {
+; IS__TUNIT____-LABEL: define {{[^@]+}}@test2
+; IS__TUNIT____-SAME: (i32 [[UNKNOWN:%.*]], i32 [[B:%.*]]) [[ATTR0]] {
+; IS__TUNIT____-NEXT:    [[TMP1:%.*]] = call i32 @test_range(i32 [[UNKNOWN]]) [[ATTR0]], [[RNG0]]
+; IS__TUNIT____-NEXT:    [[TMP2:%.*]] = add nsw i32 [[TMP1]], [[B]]
+; IS__TUNIT____-NEXT:    ret i32 [[TMP2]]
+;
+; IS__CGSCC____-LABEL: define {{[^@]+}}@test2
+; IS__CGSCC____-SAME: (i32 [[UNKNOWN:%.*]], i32 [[B:%.*]]) [[ATTR0]] {
+; IS__CGSCC____-NEXT:    [[TMP1:%.*]] = call i32 @test_range(i32 [[UNKNOWN]]) [[ATTR1]], [[RNG0]]
+; IS__CGSCC____-NEXT:    [[TMP2:%.*]] = add nsw i32 [[TMP1]], [[B]]
+; IS__CGSCC____-NEXT:    ret i32 [[TMP2]]
+;
+  %1 = call i32 @test_range(i32 %unknown)
+  %2 = add nsw i32 %1, %b
+  ret i32 %2
+}
+
+; Positive checks
+
+define i32 @test1_pcheck(i32 %unknown) {
+; NOT_CGSCC_NPM-LABEL: define {{[^@]+}}@test1_pcheck
+; NOT_CGSCC_NPM-SAME: (i32 [[UNKNOWN:%.*]]) [[ATTR0:#.*]] {
+; NOT_CGSCC_NPM-NEXT:    [[TMP1:%.*]] = call i32 @test1(i32 [[UNKNOWN]], i32 noundef 20)
+; NOT_CGSCC_NPM-NEXT:    [[TMP2:%.*]] = icmp sle i32 [[TMP1]], 90
+; NOT_CGSCC_NPM-NEXT:    [[TMP3:%.*]] = zext i1 [[TMP2]] to i32
+; NOT_CGSCC_NPM-NEXT:    ret i32 [[TMP3]]
+;
+; IS__CGSCC_NPM-LABEL: define {{[^@]+}}@test1_pcheck
+; IS__CGSCC_NPM-SAME: (i32 [[UNKNOWN:%.*]]) [[ATTR0:#.*]] {
+; IS__CGSCC_NPM-NEXT:    [[TMP1:%.*]] = call i32 @test1(i32 [[UNKNOWN]], i32 noundef 20) [[ATTR1:#.*]], [[RNG1:!range !.*]]
+; IS__CGSCC_NPM-NEXT:    [[TMP2:%.*]] = icmp sle i32 [[TMP1]], 90
+; IS__CGSCC_NPM-NEXT:    [[TMP3:%.*]] = zext i1 [[TMP2]] to i32
+; IS__CGSCC_NPM-NEXT:    ret i32 [[TMP3]]
+;
+  %1 = call i32 @test1(i32 %unknown, i32 20)
+  %2 = icmp sle i32 %1, 90
+  %3 = zext i1 %2 to i32
+  ret i32 %3
+}
+
+define i32 @test2_pcheck(i32 %unknown) {
+; CHECK-LABEL: define {{[^@]+}}@test2_pcheck
+; CHECK-SAME: (i32 [[UNKNOWN:%.*]]) [[ATTR0]] {
+; CHECK-NEXT:    [[TMP1:%.*]] = call i32 @test2(i32 [[UNKNOWN]], i32 noundef 20)
+; CHECK-NEXT:    [[TMP2:%.*]] = icmp sge i32 [[TMP1]], 20
+; CHECK-NEXT:    [[TMP3:%.*]] = zext i1 [[TMP2]] to i32
+; CHECK-NEXT:    ret i32 [[TMP3]]
+;
+  %1 = call i32 @test2(i32 %unknown, i32 20)
+  %2 = icmp sge i32 %1, 20
+  %3 = zext i1 %2 to i32
+  ret i32 %3
+}
+
+; Negative checks
+
+define i32 @test1_ncheck(i32 %unknown) {
+; NOT_CGSCC_NPM-LABEL: define {{[^@]+}}@test1_ncheck
+; NOT_CGSCC_NPM-SAME: (i32 [[UNKNOWN:%.*]]) [[ATTR0]] {
+; NOT_CGSCC_NPM-NEXT:    [[TMP1:%.*]] = call i32 @test1(i32 [[UNKNOWN]], i32 noundef 20)
+; NOT_CGSCC_NPM-NEXT:    [[TMP2:%.*]] = icmp sle i32 [[TMP1]], 10
+; NOT_CGSCC_NPM-NEXT:    [[TMP3:%.*]] = zext i1 [[TMP2]] to i32
+; NOT_CGSCC_NPM-NEXT:    ret i32 [[TMP3]]
+;
+; IS__CGSCC_NPM-LABEL: define {{[^@]+}}@test1_ncheck
+; IS__CGSCC_NPM-SAME: (i32 [[UNKNOWN:%.*]]) [[ATTR0]] {
+; IS__CGSCC_NPM-NEXT:    [[TMP1:%.*]] = call i32 @test1(i32 [[UNKNOWN]], i32 noundef 20) [[ATTR1]], [[RNG1]]
+; IS__CGSCC_NPM-NEXT:    [[TMP2:%.*]] = icmp sle i32 [[TMP1]], 10
+; IS__CGSCC_NPM-NEXT:    [[TMP3:%.*]] = zext i1 [[TMP2]] to i32
+; IS__CGSCC_NPM-NEXT:    ret i32 [[TMP3]]
+;
+  %1 = call i32 @test1(i32 %unknown, i32 20)
+  %2 = icmp sle i32 %1, 10
+  %3 = zext i1 %2 to i32
+  ret i32 %3
+}
+
+define i32 @test2_ncheck(i32 %unknown) {
+; CHECK-LABEL: define {{[^@]+}}@test2_ncheck
+; CHECK-SAME: (i32 [[UNKNOWN:%.*]]) [[ATTR0]] {
+; CHECK-NEXT:    [[TMP1:%.*]] = call i32 @test2(i32 [[UNKNOWN]], i32 noundef 20)
+; CHECK-NEXT:    [[TMP2:%.*]] = icmp sge i32 [[TMP1]], 30
+; CHECK-NEXT:    [[TMP3:%.*]] = zext i1 [[TMP2]] to i32
+; CHECK-NEXT:    ret i32 [[TMP3]]
+;
+  %1 = call i32 @test2(i32 %unknown, i32 20)
+  %2 = icmp sge i32 %1, 30
+  %3 = zext i1 %2 to i32
+  ret i32 %3
+}

diff  --git a/llvm/test/Transforms/Attributor/cb_range_enabled.ll b/llvm/test/Transforms/Attributor/cb_range_enabled.ll
new file mode 100644
index 000000000000..c07c370ab217
--- /dev/null
+++ b/llvm/test/Transforms/Attributor/cb_range_enabled.ll
@@ -0,0 +1,126 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --function-signature --scrub-attributes
+; Call-site-specific deduction is enabled in every RUN line below (-attributor-enable-call-site-specific-deduction=true).
+
+; RUN: opt -attributor -enable-new-pm=0 -attributor-enable-call-site-specific-deduction=true -attributor-manifest-internal  -attributor-max-iterations-verify -attributor-annotate-decl-cs -attributor-max-iterations=1 -S < %s | FileCheck %s --check-prefixes=CHECK,NOT_CGSCC_NPM,NOT_CGSCC_OPM,NOT_TUNIT_NPM,IS__TUNIT____,IS________OPM,IS__TUNIT_OPM
+
+; RUN: opt -aa-pipeline=basic-aa -passes=attributor -attributor-enable-call-site-specific-deduction=true -attributor-manifest-internal  -attributor-max-iterations-verify -attributor-annotate-decl-cs -attributor-max-iterations=1 -S < %s | FileCheck %s --check-prefixes=CHECK,NOT_CGSCC_OPM,NOT_CGSCC_NPM,NOT_TUNIT_OPM,IS__TUNIT____,IS________NPM,IS__TUNIT_NPM
+
+; RUN: opt -attributor-cgscc -attributor-enable-call-site-specific-deduction=true -enable-new-pm=0 -attributor-manifest-internal  -attributor-annotate-decl-cs -S < %s | FileCheck %s --check-prefixes=CHECK,NOT_TUNIT_NPM,NOT_TUNIT_OPM,NOT_CGSCC_NPM,IS__CGSCC____,IS________OPM,IS__CGSCC_OPM
+
+; RUN: opt -aa-pipeline=basic-aa -passes=attributor-cgscc -attributor-enable-call-site-specific-deduction=true -attributor-manifest-internal  -attributor-annotate-decl-cs -S < %s | FileCheck %s --check-prefixes=CHECK,NOT_TUNIT_NPM,NOT_TUNIT_OPM,NOT_CGSCC_OPM,IS__CGSCC____,IS________NPM,IS__CGSCC_NPM
+
+define i32 @test_range(i32 %unknown) {
+; CHECK-LABEL: define {{[^@]+}}@test_range
+; CHECK-SAME: (i32 [[UNKNOWN:%.*]]) [[ATTR0:#.*]] {
+; CHECK-NEXT:    [[TMP1:%.*]] = icmp sgt i32 [[UNKNOWN]], 100
+; CHECK-NEXT:    [[TMP2:%.*]] = select i1 [[TMP1]], i32 100, i32 0
+; CHECK-NEXT:    ret i32 [[TMP2]]
+;
+  %1 = icmp sgt i32 %unknown, 100 ; true iff %unknown is greater than 100 (signed)
+  %2 = select i1 %1, i32 100, i32 0 ; so the returned value is always either 100 or 0
+  ret i32 %2
+}
+
+define i32 @test1(i32 %unknown, i32 %b) {
+; IS__TUNIT____-LABEL: define {{[^@]+}}@test1
+; IS__TUNIT____-SAME: (i32 [[UNKNOWN:%.*]], i32 [[B:%.*]]) [[ATTR0:#.*]] {
+; IS__TUNIT____-NEXT:    [[TMP1:%.*]] = call i32 @test_range(i32 [[UNKNOWN]]) [[ATTR0]], [[RNG0:!range !.*]]
+; IS__TUNIT____-NEXT:    [[TMP2:%.*]] = sub nsw i32 [[TMP1]], [[B]]
+; IS__TUNIT____-NEXT:    ret i32 [[TMP2]]
+;
+; IS__CGSCC____-LABEL: define {{[^@]+}}@test1
+; IS__CGSCC____-SAME: (i32 [[UNKNOWN:%.*]], i32 [[B:%.*]]) [[ATTR0:#.*]] {
+; IS__CGSCC____-NEXT:    [[TMP1:%.*]] = call i32 @test_range(i32 [[UNKNOWN]]) [[ATTR1:#.*]], [[RNG0:!range !.*]]
+; IS__CGSCC____-NEXT:    [[TMP2:%.*]] = sub nsw i32 [[TMP1]], [[B]]
+; IS__CGSCC____-NEXT:    ret i32 [[TMP2]]
+;
+  %1 = call i32 @test_range(i32 %unknown) ; yields 100 or 0; the expected output attaches !range metadata to this call
+  %2 = sub nsw i32 %1, %b ; returns the @test_range result minus %b, which is unknown inside this function
+  ret i32 %2
+}
+
+define i32 @test2(i32 %unknown, i32 %b) {
+; IS__TUNIT____-LABEL: define {{[^@]+}}@test2
+; IS__TUNIT____-SAME: (i32 [[UNKNOWN:%.*]], i32 [[B:%.*]]) [[ATTR0]] {
+; IS__TUNIT____-NEXT:    [[TMP1:%.*]] = call i32 @test_range(i32 [[UNKNOWN]]) [[ATTR0]], [[RNG0]]
+; IS__TUNIT____-NEXT:    [[TMP2:%.*]] = add nsw i32 [[TMP1]], [[B]]
+; IS__TUNIT____-NEXT:    ret i32 [[TMP2]]
+;
+; IS__CGSCC____-LABEL: define {{[^@]+}}@test2
+; IS__CGSCC____-SAME: (i32 [[UNKNOWN:%.*]], i32 [[B:%.*]]) [[ATTR0]] {
+; IS__CGSCC____-NEXT:    [[TMP1:%.*]] = call i32 @test_range(i32 [[UNKNOWN]]) [[ATTR1]], [[RNG0]]
+; IS__CGSCC____-NEXT:    [[TMP2:%.*]] = add nsw i32 [[TMP1]], [[B]]
+; IS__CGSCC____-NEXT:    ret i32 [[TMP2]]
+;
+  %1 = call i32 @test_range(i32 %unknown) ; yields 100 or 0; the expected output attaches !range metadata to this call
+  %2 = add nsw i32 %1, %b ; returns the @test_range result plus %b, which is unknown inside this function
+  ret i32 %2
+}
+
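+; Positive checks (the compare holds for every value the callee can return at this call site, so it folds to a constant)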
+
+define i32 @test1_pcheck(i32 %unknown) {
+; CHECK-LABEL: define {{[^@]+}}@test1_pcheck
+; CHECK-SAME: (i32 [[UNKNOWN:%.*]]) [[ATTR0]] {
+; CHECK-NEXT:    ret i32 1
+;
+  %1 = call i32 @test1(i32 %unknown, i32 20) ; with b = 20, @test1 returns 100 - 20 = 80 or 0 - 20 = -20
+  %2 = icmp sle i32 %1, 90 ; both 80 and -20 are <= 90, so this is always true
+  %3 = zext i1 %2 to i32 ; always 1; the expected output reduces the whole body to ret i32 1
+  ret i32 %3
+}
+
+define i32 @test2_pcheck(i32 %unknown) {
+; CHECK-LABEL: define {{[^@]+}}@test2_pcheck
+; CHECK-SAME: (i32 [[UNKNOWN:%.*]]) [[ATTR0]] {
+; CHECK-NEXT:    ret i32 1
+;
+  %1 = call i32 @test2(i32 %unknown, i32 20) ; with b = 20, @test2 returns 100 + 20 = 120 or 0 + 20 = 20
+  %2 = icmp sge i32 %1, 20 ; both 120 and 20 are >= 20, so this is always true
+  %3 = zext i1 %2 to i32 ; always 1; the expected output reduces the whole body to ret i32 1
+  ret i32 %3
+}
+
+; Negative checks (the compare holds for only some of the values the callee can return, so it must not fold)
+
+define i32 @test1_ncheck(i32 %unknown) {
+; IS__TUNIT____-LABEL: define {{[^@]+}}@test1_ncheck
+; IS__TUNIT____-SAME: (i32 [[UNKNOWN:%.*]]) [[ATTR0]] {
+; IS__TUNIT____-NEXT:    [[TMP1:%.*]] = call i32 @test1(i32 [[UNKNOWN]], i32 noundef 20) [[ATTR0]], [[RNG1:!range !.*]]
+; IS__TUNIT____-NEXT:    [[TMP2:%.*]] = icmp sle i32 [[TMP1]], 10
+; IS__TUNIT____-NEXT:    [[TMP3:%.*]] = zext i1 [[TMP2]] to i32
+; IS__TUNIT____-NEXT:    ret i32 [[TMP3]]
+;
+; IS__CGSCC____-LABEL: define {{[^@]+}}@test1_ncheck
+; IS__CGSCC____-SAME: (i32 [[UNKNOWN:%.*]]) [[ATTR0]] {
+; IS__CGSCC____-NEXT:    [[TMP1:%.*]] = call i32 @test1(i32 [[UNKNOWN]], i32 noundef 20) [[ATTR1]], [[RNG1:!range !.*]]
+; IS__CGSCC____-NEXT:    [[TMP2:%.*]] = icmp sle i32 [[TMP1]], 10
+; IS__CGSCC____-NEXT:    [[TMP3:%.*]] = zext i1 [[TMP2]] to i32
+; IS__CGSCC____-NEXT:    ret i32 [[TMP3]]
+;
+  %1 = call i32 @test1(i32 %unknown, i32 20) ; with b = 20, @test1 returns 80 or -20; expected output adds !range metadata here
+  %2 = icmp sle i32 %1, 10 ; true for -20 but false for 80, so the expected output keeps this compare
+  %3 = zext i1 %2 to i32 ; widen the i1 compare result to the i32 return type
+  ret i32 %3
+}
+
+define i32 @test2_ncheck(i32 %unknown) {
+; IS__TUNIT____-LABEL: define {{[^@]+}}@test2_ncheck
+; IS__TUNIT____-SAME: (i32 [[UNKNOWN:%.*]]) [[ATTR0]] {
+; IS__TUNIT____-NEXT:    [[TMP1:%.*]] = call i32 @test2(i32 [[UNKNOWN]], i32 noundef 20) [[ATTR0]], [[RNG2:!range !.*]]
+; IS__TUNIT____-NEXT:    [[TMP2:%.*]] = icmp sge i32 [[TMP1]], 30
+; IS__TUNIT____-NEXT:    [[TMP3:%.*]] = zext i1 [[TMP2]] to i32
+; IS__TUNIT____-NEXT:    ret i32 [[TMP3]]
+;
+; IS__CGSCC____-LABEL: define {{[^@]+}}@test2_ncheck
+; IS__CGSCC____-SAME: (i32 [[UNKNOWN:%.*]]) [[ATTR0]] {
+; IS__CGSCC____-NEXT:    [[TMP1:%.*]] = call i32 @test2(i32 [[UNKNOWN]], i32 noundef 20) [[ATTR1]], [[RNG2:!range !.*]]
+; IS__CGSCC____-NEXT:    [[TMP2:%.*]] = icmp sge i32 [[TMP1]], 30
+; IS__CGSCC____-NEXT:    [[TMP3:%.*]] = zext i1 [[TMP2]] to i32
+; IS__CGSCC____-NEXT:    ret i32 [[TMP3]]
+;
+  %1 = call i32 @test2(i32 %unknown, i32 20) ; with b = 20, @test2 returns 120 or 20; expected output adds !range metadata here
+  %2 = icmp sge i32 %1, 30 ; true for 120 but false for 20, so the expected output keeps this compare
+  %3 = zext i1 %2 to i32 ; widen the i1 compare result to the i32 return type
+  ret i32 %3
+}

diff  --git a/llvm/unittests/Transforms/IPO/AttributorTest.cpp b/llvm/unittests/Transforms/IPO/AttributorTest.cpp
index 7d496ed0d14f..e16fed11a375 100644
--- a/llvm/unittests/Transforms/IPO/AttributorTest.cpp
+++ b/llvm/unittests/Transforms/IPO/AttributorTest.cpp
@@ -21,6 +21,22 @@
 
 namespace llvm {
 
+TEST_F(AttributorTestBase, IRPPositionCallBaseContext) {
+  // One trivial function suffices: only IRPosition's context flag is queried.
+  const char *ModuleString = R"(
+    define i32 @foo(i32 %a) {
+    entry:
+      ret i32 %a
+    }
+  )";
+  parseModule(ModuleString);
+  // Arbitrary non-null sentinel; presumably never dereferenced -- TODO confirm.
+  const auto *FakeCB = reinterpret_cast<const llvm::CallBase *>(0xDEADBEEF);
+  IRPosition Pos = IRPosition::function(*M->getFunction("foo"), FakeCB);
+  // A position built with a context reports it; the stripped copy must not.
+  EXPECT_TRUE(Pos.hasCallBaseContext());
+  EXPECT_FALSE(Pos.stripCallBaseContext().hasCallBaseContext());
+}
+
 TEST_F(AttributorTestBase, TestCast) {
   const char *ModuleString = R"(
     define i32 @foo(i32 %a, i32 %b) {


        


More information about the llvm-commits mailing list