[llvm] r373985 - [Attributor] Use abstract call sites for call site callback

Johannes Doerfert via llvm-commits llvm-commits at lists.llvm.org
Mon Oct 7 16:14:58 PDT 2019


Author: jdoerfert
Date: Mon Oct  7 16:14:58 2019
New Revision: 373985

URL: http://llvm.org/viewvc/llvm-project?rev=373985&view=rev
Log:
[Attributor] Use abstract call sites for call site callback

Summary:
When we iterate over uses of functions and expect them to be call sites,
we now use abstract call sites to allow callback calls.

Reviewers: sstefan1, uenoku

Subscribers: hiraditya, bollu, hfinkel, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D67871

Added:
    llvm/trunk/test/Transforms/FunctionAttrs/callbacks.ll
Modified:
    llvm/trunk/include/llvm/IR/CallSite.h
    llvm/trunk/include/llvm/Transforms/IPO/Attributor.h
    llvm/trunk/lib/Transforms/IPO/Attributor.cpp

Modified: llvm/trunk/include/llvm/IR/CallSite.h
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/include/llvm/IR/CallSite.h?rev=373985&r1=373984&r2=373985&view=diff
==============================================================================
--- llvm/trunk/include/llvm/IR/CallSite.h (original)
+++ llvm/trunk/include/llvm/IR/CallSite.h Mon Oct  7 16:14:58 2019
@@ -854,6 +854,15 @@ public:
     return CI.ParameterEncoding[0];
   }
 
+  /// Return the use by which the underlying call passes the callback callee.
+  /// Only valid for callback calls! Asserts the operand index is in range.
+  const Use &getCalleeUseForCallback() const {
+    int CalleeArgIdx = getCallArgOperandNoForCallee();
+    assert(CalleeArgIdx >= 0 &&
+           unsigned(CalleeArgIdx) < getInstruction()->getNumOperands());
+    return getInstruction()->getOperandUse(CalleeArgIdx);
+  }
+
   /// Return the pointer to function that is being called.
   Value *getCalledValue() const {
     if (isDirectCall())

Modified: llvm/trunk/include/llvm/Transforms/IPO/Attributor.h
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/include/llvm/Transforms/IPO/Attributor.h?rev=373985&r1=373984&r2=373985&view=diff
==============================================================================
--- llvm/trunk/include/llvm/Transforms/IPO/Attributor.h (original)
+++ llvm/trunk/include/llvm/Transforms/IPO/Attributor.h Mon Oct  7 16:14:58 2019
@@ -216,6 +216,16 @@ struct IRPosition {
                                          ArgNo);
   }
 
+  /// Create a position for argument \p ArgNo of \p ACS (invalid if unmapped).
+  static const IRPosition callsite_argument(AbstractCallSite ACS,
+                                            unsigned ArgNo) {
+    int CSArgNo = ACS.getCallArgOperandNo(ArgNo);
+    if (CSArgNo >= 0)
+      return IRPosition::callsite_argument(
+          cast<CallBase>(*ACS.getInstruction()), CSArgNo);
+    return IRPosition();
+  }
+
   /// Create a position with function scope matching the "context" of \p IRP.
   /// If \p IRP is a call site (see isAnyCallSitePosition()) then the result
   /// will be a call site position, otherwise the function position of the
@@ -825,7 +835,7 @@ struct Attributor {
   /// This method will evaluate \p Pred on call sites and return
   /// true if \p Pred holds in every call sites. However, this is only possible
   /// all call sites are known, hence the function has internal linkage.
-  bool checkForAllCallSites(const function_ref<bool(CallSite)> &Pred,
+  bool checkForAllCallSites(const function_ref<bool(AbstractCallSite)> &Pred,
                             const AbstractAttribute &QueryingAA,
                             bool RequireAllCallSites);
 

Modified: llvm/trunk/lib/Transforms/IPO/Attributor.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Transforms/IPO/Attributor.cpp?rev=373985&r1=373984&r2=373985&view=diff
==============================================================================
--- llvm/trunk/lib/Transforms/IPO/Attributor.cpp (original)
+++ llvm/trunk/lib/Transforms/IPO/Attributor.cpp Mon Oct  7 16:14:58 2019
@@ -596,11 +596,16 @@ static void clampCallSiteArgumentStates(
   // The argument number which is also the call site argument number.
   unsigned ArgNo = QueryingAA.getIRPosition().getArgNo();
 
-  auto CallSiteCheck = [&](CallSite CS) {
-    const IRPosition &CSArgPos = IRPosition::callsite_argument(CS, ArgNo);
-    const AAType &AA = A.getAAFor<AAType>(QueryingAA, CSArgPos);
-    LLVM_DEBUG(dbgs() << "[Attributor] CS: " << *CS.getInstruction()
-                      << " AA: " << AA.getAsStr() << " @" << CSArgPos << "\n");
+  auto CallSiteCheck = [&](AbstractCallSite ACS) {
+    const IRPosition &ACSArgPos = IRPosition::callsite_argument(ACS, ArgNo);
+    // Check if a corresponding argument was found or if it is not associated
+    // (which can happen for callback calls).
+    if (ACSArgPos.getPositionKind() == IRPosition::IRP_INVALID)
+      return false;
+
+    const AAType &AA = A.getAAFor<AAType>(QueryingAA, ACSArgPos);
+    LLVM_DEBUG(dbgs() << "[Attributor] ACS: " << *ACS.getInstruction()
+                      << " AA: " << AA.getAsStr() << " @" << ACSArgPos << "\n");
     const StateType &AAS = static_cast<const StateType &>(AA.getState());
     if (T.hasValue())
       *T &= AAS;
@@ -3100,9 +3105,12 @@ struct AAValueSimplifyArgument final : A
   ChangeStatus updateImpl(Attributor &A) override {
     bool HasValueBefore = SimplifiedAssociatedValue.hasValue();
 
-    auto PredForCallSite = [&](CallSite CS) {
-      return checkAndUpdate(A, *this, *CS.getArgOperand(getArgNo()),
-                            SimplifiedAssociatedValue);
+    auto PredForCallSite = [&](AbstractCallSite ACS) {
+      // Check if we have an associated argument or not (which can happen for
+      // callback calls).
+      if (Value *ArgOp = ACS.getCallArgOperand(getArgNo()))
+        return checkAndUpdate(A, *this, *ArgOp, SimplifiedAssociatedValue);
+      return false;
     };
 
     if (!A.checkForAllCallSites(PredForCallSite, *this, true))
@@ -3914,9 +3922,9 @@ bool Attributor::isAssumedDead(const Abs
   return true;
 }
 
-bool Attributor::checkForAllCallSites(const function_ref<bool(CallSite)> &Pred,
-                                      const AbstractAttribute &QueryingAA,
-                                      bool RequireAllCallSites) {
+bool Attributor::checkForAllCallSites(
+    const function_ref<bool(AbstractCallSite)> &Pred,
+    const AbstractAttribute &QueryingAA, bool RequireAllCallSites) {
   // We can try to determine information from
   // the call sites. However, this is only possible all call sites are known,
   // hence the function has internal linkage.
@@ -3934,15 +3942,21 @@ bool Attributor::checkForAllCallSites(co
   }
 
   for (const Use &U : AssociatedFunction->uses()) {
-    Instruction *I = dyn_cast<Instruction>(U.getUser());
-    // TODO: Deal with abstract call sites here.
-    if (!I)
+    AbstractCallSite ACS(&U);
+    if (!ACS) {
+      LLVM_DEBUG(dbgs() << "[Attributor] Function "
+                        << AssociatedFunction->getName()
+                        << " has non call site use " << *U.get() << " in "
+                        << *U.getUser() << "\n");
       return false;
+    }
 
+    Instruction *I = ACS.getInstruction();
     Function *Caller = I->getFunction();
 
-    const auto &LivenessAA = getAAFor<AAIsDead>(
-        QueryingAA, IRPosition::function(*Caller), /* TrackDependence */ false);
+    const auto &LivenessAA =
+        getAAFor<AAIsDead>(QueryingAA, IRPosition::function(*Caller),
+                           /* TrackDependence */ false);
 
     // Skip dead calls.
     if (LivenessAA.isAssumedDead(I)) {
@@ -3952,22 +3966,22 @@ bool Attributor::checkForAllCallSites(co
       continue;
     }
 
-    CallSite CS(U.getUser());
-    if (!CS || !CS.isCallee(&U)) {
+    const Use *EffectiveUse =
+        ACS.isCallbackCall() ? &ACS.getCalleeUseForCallback() : &U;
+    if (!ACS.isCallee(EffectiveUse)) {
       if (!RequireAllCallSites)
         continue;
-
-      LLVM_DEBUG(dbgs() << "[Attributor] User " << *U.getUser()
+      LLVM_DEBUG(dbgs() << "[Attributor] User " << *EffectiveUse->getUser()
                         << " is an invalid use of "
                         << AssociatedFunction->getName() << "\n");
       return false;
     }
 
-    if (Pred(CS))
+    if (Pred(ACS))
       continue;
 
     LLVM_DEBUG(dbgs() << "[Attributor] Call site callback failed for "
-                      << *CS.getInstruction() << "\n");
+                      << *ACS.getInstruction() << "\n");
     return false;
   }
 
@@ -4319,7 +4333,7 @@ ChangeStatus Attributor::run(Module &M)
         const auto *LivenessAA =
             lookupAAFor<AAIsDead>(IRPosition::function(*F));
         if (LivenessAA &&
-            !checkForAllCallSites([](CallSite CS) { return false; },
+            !checkForAllCallSites([](AbstractCallSite ACS) { return false; },
                                   *LivenessAA, true))
           continue;
 

Added: llvm/trunk/test/Transforms/FunctionAttrs/callbacks.ll
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/Transforms/FunctionAttrs/callbacks.ll?rev=373985&view=auto
==============================================================================
--- llvm/trunk/test/Transforms/FunctionAttrs/callbacks.ll (added)
+++ llvm/trunk/test/Transforms/FunctionAttrs/callbacks.ll Mon Oct  7 16:14:58 2019
@@ -0,0 +1,63 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -S -passes=attributor -aa-pipeline='basic-aa' -attributor-disable=false -attributor-max-iterations-verify -attributor-max-iterations=1 < %s | FileCheck %s
+; ModuleID = 'callback_simple.c'
+target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
+
+; Test 0
+;
+; Make sure we propagate information from the caller to the callback callee but
+; only for arguments that are mapped through the callback metadata. Here, the
+; first two arguments of the call and the callback callee do not correspond to
+; each other but arguments 3-5 of the transitive call site in the caller match
+; arguments 2-4 of the callback callee (all indices 0-based). Here we should see
+; information and value transfer in both directions.
+; FIXME: The callee -> call site direction is not working yet.
+
+define void @t0_caller(i32* %a) {
+; CHECK:       @t0_caller(i32* [[A:%.*]])
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[B:%.*]] = alloca i32, align 32
+; CHECK-NEXT:    [[C:%.*]] = alloca i32*, align 64
+; CHECK-NEXT:    [[PTR:%.*]] = alloca i32, align 128
+; CHECK-NEXT:    [[TMP0:%.*]] = bitcast i32* [[B]] to i8*
+; CHECK-NEXT:    store i32 42, i32* [[B]], align 32
+; CHECK-NEXT:    store i32* [[B]], i32** [[C]], align 64
+; CHECK-NEXT:    call void (i32*, i32*, void (i32*, i32*, ...)*, ...) @t0_callback_broker(i32* null, i32* nonnull align 128 dereferenceable(4) [[PTR]], void (i32*, i32*, ...)* nonnull bitcast (void (i32*, i32*, i32*, i64, i32**)* @t0_callback_callee to void (i32*, i32*, ...)*), i32* [[A:%.*]], i64 99, i32** nonnull align 64 dereferenceable(8) [[C]])
+; CHECK-NEXT:    ret void
+;
+entry:
+  %b = alloca i32, align 32
+  %c = alloca i32*, align 64
+  %ptr = alloca i32, align 128
+  %0 = bitcast i32* %b to i8*
+  store i32 42, i32* %b, align 4
+  store i32* %b, i32** %c, align 8
+  call void (i32*, i32*, void (i32*, i32*, ...)*, ...) @t0_callback_broker(i32* null, i32* %ptr, void (i32*, i32*, ...)* bitcast (void (i32*, i32*, i32*, i64, i32**)* @t0_callback_callee to void (i32*, i32*, ...)*), i32* %a, i64 99, i32** %c)
+  ret void
+}
+
+; Note that the first two arguments are provided by the callback_broker itself (encoded as -1 in !1 below)!
+; The others are forwarded from the caller, so they may gain attributes (e.g., alignment) or have their uses replaced by constants passed to the call.
+define internal void @t0_callback_callee(i32* %is_not_null, i32* %ptr, i32* %a, i64 %b, i32** %c) {
+; CHECK:       @t0_callback_callee(i32* nocapture writeonly [[IS_NOT_NULL:%.*]], i32* nocapture readonly [[PTR:%.*]], i32* [[A:%.*]], i64 [[B:%.*]], i32** nocapture nonnull readonly align 64 dereferenceable(8) [[C:%.*]])
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[PTR_VAL:%.*]] = load i32, i32* [[PTR:%.*]], align 8
+; CHECK-NEXT:    store i32 [[PTR_VAL]], i32* [[IS_NOT_NULL:%.*]]
+; CHECK-NEXT:    [[TMP0:%.*]] = load i32*, i32** [[C:%.*]], align 64
+; CHECK-NEXT:    tail call void @t0_check(i32* align 256 [[A:%.*]], i64 99, i32* [[TMP0]])
+; CHECK-NEXT:    ret void
+;
+entry:
+  %ptr_val = load i32, i32* %ptr, align 8
+  store i32 %ptr_val, i32* %is_not_null
+  %0 = load i32*, i32** %c, align 8
+  tail call void @t0_check(i32* %a, i64 %b, i32* %0)
+  ret void
+}
+
+declare void @t0_check(i32* align 256, i64, i32*)
+
+declare !callback !0 void @t0_callback_broker(i32*, i32*, void (i32*, i32*, ...)*, ...)
+
+!0 = !{!1}
+!1 = !{i64 2, i64 -1, i64 -1, i1 true}




More information about the llvm-commits mailing list